from langchain.llms.base import LLM
from typing import Any, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


class ChatGLM_LLM(LLM):
    # Custom LLM class wrapping a local ChatGLM3-6B model
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None

    def __init__(self, model_path: str):
        # model_path: local path to the ChatGLM3-6B weights
        # Load the tokenizer and model from local files
        super().__init__()
        print("Loading model from local path...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True
        )
        self.model = (
            AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
            .to(torch.bfloat16)
            .cuda(device=1)  # hard-coded to GPU 1; adjust for your setup
        )
        self.model = self.model.eval()
        print("Local model loaded")

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Override the core call: delegate to the model's chat() interface,
        # starting from an empty history and decoding greedily (do_sample=False)
        response, history = self.model.chat(
            self.tokenizer, prompt, history=[], do_sample=False
        )
        return response

    @property
    def _llm_type(self) -> str:
        return "ChatGLM3-6B"
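
# --- Usage sketch (illustrative, not part of the class above) ---
# A minimal example of instantiating and calling the wrapper. The model path
# below is a placeholder: point it at your local ChatGLM3-6B checkout. It also
# assumes a second GPU is visible as cuda:1, matching the hard-coded
# .cuda(device=1) in __init__; adjust both for your environment.
if __name__ == "__main__":
    llm = ChatGLM_LLM(model_path="/path/to/chatglm3-6b")  # placeholder path
    print(llm._llm_type)  # -> "ChatGLM3-6B"
    # Classic LangChain call path, which routes through _call();
    # on newer LangChain versions, llm.invoke(...) works as well.
    print(llm("Hello, please introduce yourself."))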