```python
# coding=utf-8
from typing import Any, List, Optional

import torch
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM


class ChatGLM_LLM(LLM):
    # Custom LLM class wrapping a locally deployed ChatGLM3-6B model
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None

    def __init__(self, model_path: str):
        # model_path: local path to the ChatGLM3-6B weights
        # Initialize tokenizer and model from the local directory
        super().__init__()
        print("Loading model from local path...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True
        ).to(torch.bfloat16).cuda(device=1)  # place the model on GPU 1
        self.model = self.model.eval()
        print("Finished loading the local model")

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        # Override LangChain's call hook: run one round of ChatGLM3's chat interface
        response, history = self.model.chat(self.tokenizer, prompt, history=[], do_sample=False)
        return response

    @property
    def _llm_type(self) -> str:
        return "ChatGLM3-6B"
```
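A minimal usage sketch of the wrapper, assuming the ChatGLM3-6B weights have already been downloaded locally (the `model_path` below is a placeholder) and that a second GPU is available, since the constructor pins the model to `cuda:1`:

```python
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

# Hypothetical local path; replace with the directory holding the ChatGLM3-6B weights
llm = ChatGLM_LLM(model_path="/path/to/chatglm3-6b")

# A plain-string call goes through the LLM base class and ends up in _call
print(llm("Hello, please introduce yourself."))

# The wrapper also composes with standard LangChain components, e.g. an LLMChain
prompt = PromptTemplate.from_template("Answer the question concisely: {question}")
chain = LLMChain(llm=llm, prompt=prompt)
print(chain.run(question="What is retrieval-augmented generation?"))
```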