123456789101112131415161718192021222324252627282930313233343536 |
- # coding=gbk
- from langchain.llms.base import LLM
- from typing import Any, List, Optional
- from langchain.callbacks.manager import CallbackManagerForLLMRun
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import torch
class ChatGLM_LLM(LLM):
    """LangChain ``LLM`` wrapper around a chat model loaded from a local path.

    The tokenizer and model are loaded with ``trust_remote_code=True`` and
    generation is delegated to the model's own ``chat`` method. ``_llm_type``
    reports the model as "ChatGLM3-6B". (An earlier comment here mentioned
    InternLM; that appears to be a leftover from another template.)
    """

    # Declared as class-level fields because LangChain's LLM base class is a
    # pydantic model, which rejects assignment to undeclared attributes.
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None

    def __init__(self, model_path: str, device: int = 1):
        """Load the tokenizer and model from ``model_path`` onto a CUDA device.

        Args:
            model_path: Path to the local model directory, passed to
                ``from_pretrained`` for both tokenizer and model.
            device: CUDA device index the model is moved to. Defaults to 1,
                the value that was previously hard-coded; pass 0 on a
                single-GPU machine.
        """
        super().__init__()
        print("正在从本地加载模型...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        # NOTE(review): the model is loaded at the default dtype and then cast
        # to bfloat16; passing torch_dtype=torch.bfloat16 to from_pretrained
        # would likely lower peak host memory -- confirm numerics first.
        self.model = (
            AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
            .to(torch.bfloat16)
            .cuda(device=device)
        )
        # Inference only: put the module in evaluation mode.
        self.model = self.model.eval()
        print("完成本地模型的加载")

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        """Generate a single-turn reply to ``prompt``.

        Each call starts from an empty chat history (no conversation state is
        kept between calls) and passes ``do_sample=False`` to the model.
        ``run_manager`` and extra ``kwargs`` are accepted for LangChain
        compatibility but are not used.

        Args:
            prompt: The user message sent to the model.
            stop: Optional stop strings. The reply is truncated just before
                the earliest occurrence of any non-empty stop string.
            run_manager: LangChain callback manager (unused).

        Returns:
            The model's reply text.
        """
        response, _history = self.model.chat(
            self.tokenizer, prompt, history=[], do_sample=False
        )
        if stop:
            # Honour LangChain's ``stop`` contract: cut at the first match of
            # any stop string. Empty strings are skipped so they cannot
            # truncate the whole reply.
            positions = [response.find(token) for token in stop if token]
            positions = [pos for pos in positions if pos != -1]
            if positions:
                response = response[:min(positions)]
        return response

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return "ChatGLM3-6B"
|