LLM.py

from typing import Any, List, Optional

import torch
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM


class ChatGLM_LLM(LLM):
    # Custom LLM class wrapping a locally deployed InternLM model
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None

    def __init__(self, model_path: str):
        # model_path: path to the local InternLM checkpoint
        # Initialize the model from local disk
        super().__init__()
        print("Loading model from local disk...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True
        )
        self.model = (
            AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
            .to(torch.bfloat16)       # cast weights to bfloat16
            .cuda(device=1)           # place the model on GPU 1
        )
        self.model = self.model.eval()  # inference mode: disables dropout
        print("Finished loading the local model")

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Override LangChain's call hook: delegate to the model's chat()
        # interface with an empty history and greedy decoding
        response, history = self.model.chat(
            self.tokenizer, prompt, history=[], do_sample=False
        )
        return response

    @property
    def _llm_type(self) -> str:
        return "ChatGLM3-6B"
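A minimal usage sketch, assuming the class above is importable and the checkpoint path is a placeholder (not from the source): once constructed, the wrapper behaves like any other LangChain LLM, so it can be called directly or passed into a chain.

    # Usage sketch; "/path/to/internlm" is a hypothetical local checkpoint path.
    llm = ChatGLM_LLM(model_path="/path/to/internlm")
    print(llm._llm_type)                  # "ChatGLM3-6B"
    print(llm("Hello, who are you?"))     # LLM.__call__ routes to _call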