# coding=gbk
from typing import Any, List, Optional

import torch
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModelForCausalLM


class ChatGLM_LLM(LLM):
    # Custom LangChain LLM wrapper around a locally deployed ChatGLM3-6B model
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None

    def __init__(self, model_path: str):
        # model_path: path to the local ChatGLM3-6B weights
        # Initialize tokenizer and model from the local files
        super().__init__()
        print("Loading model from local path...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True
        )
        self.model = (
            AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
            .to(torch.bfloat16)
            .cuda(device=1)
        )
        self.model = self.model.eval()
        print("Finished loading the local model")

    def _call(self, prompt: str,
              stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        # Override the core generation call: run one round of chat with
        # an empty history and greedy decoding (do_sample=False)
        response, history = self.model.chat(
            self.tokenizer, prompt, history=[], do_sample=False
        )
        return response

    @property
    def _llm_type(self) -> str:
        return "ChatGLM3-6B"
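
# --- Usage sketch ---
# A minimal example of driving the wrapper above. The model path is an
# assumption (a hypothetical placeholder); point it at wherever your
# ChatGLM3-6B weights actually live. Instances of LangChain's legacy
# `LLM` base class are callable directly on a prompt string.
if __name__ == "__main__":
    llm = ChatGLM_LLM(model_path="/path/to/chatglm3-6b")  # hypothetical path
    print(llm._llm_type)                      # "ChatGLM3-6B"
    print(llm("Hello, please introduce yourself."))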