LLM.py

from typing import Any, List, Optional

import torch
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModelForCausalLM


class ChatGLM_LLM(LLM):
    # Custom LLM class wrapping a locally deployed ChatGLM3-6B model
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None

    def __init__(self, model_path: str):
        # model_path: local path to the ChatGLM3-6B weights
        # Initialize the model from the local checkpoint
        super().__init__()
        print("Loading model from local path...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True
        ).to(torch.bfloat16).cuda(device=1)
        self.model = self.model.eval()
        print("Finished loading the local model")

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        # Override the core call: run a single chat turn with an empty history
        response, history = self.model.chat(
            self.tokenizer, prompt, history=[], do_sample=False)
        return response

    @property
    def _llm_type(self) -> str:
        return "ChatGLM3-6B"
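A minimal usage sketch of the class above, assuming LangChain and transformers are installed, the ChatGLM3-6B weights have been downloaded locally (the path below is a placeholder), and a second GPU is available (the constructor hard-codes device=1). Depending on the LangChain version, the wrapper can be invoked as llm(prompt), llm.predict(prompt), or llm.invoke(prompt).

# usage_example.py — hypothetical driver script, not part of the original file
from LLM import ChatGLM_LLM

# Placeholder path; point this at the actual ChatGLM3-6B checkpoint directory.
llm = ChatGLM_LLM(model_path="/path/to/chatglm3-6b")

# The LangChain LLM base class routes this call through _call defined above.
print(llm.predict("Hello, please introduce yourself."))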