from typing import Any, List, Optional

import torch
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM


class ChatGLM_LLM(LLM):
    # Custom LLM class wrapping a locally deployed ChatGLM3-6B model
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None

    def __init__(self, model_path: str):
        # model_path: local path to the ChatGLM3-6B weights
        # Load the tokenizer and model from local files
        super().__init__()
        print("Loading model from local path...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True
        )
        self.model = (
            AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
            .to(torch.bfloat16)
            .cuda(device=1)  # move to GPU 1; adjust to match your hardware
        )
        self.model = self.model.eval()
        print("Finished loading the local model")

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Override the call method: run one stateless chat turn
        # (history is reset to [] on every call)
        response, history = self.model.chat(
            self.tokenizer, prompt, history=[], do_sample=False
        )
        return response

    @property
    def _llm_type(self) -> str:
        return "ChatGLM3-6B"