@@ -1,4 +1,3 @@
-# coding=gbk
 from langchain.llms.base import LLM
 from typing import Any, List, Optional
 from langchain.callbacks.manager import CallbackManagerForLLMRun
@@ -16,23 +15,30 @@ class ChatGLM_LLM(LLM):
         # Initialize the model from local files
         super().__init__()
         print("Loading the model from local files...")
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(torch.bfloat16).cuda(
-            device=1)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=True
+        )
+        self.model = (
+            AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+            .to(torch.bfloat16)
+            .cuda(device=1)
+        )
         self.model = self.model.eval()
         print("Finished loading the local model")
 
-    def _call(self, prompt: str, stop: Optional[List[str]] = None,
-              run_manager: Optional[CallbackManagerForLLMRun] = None,
-              **kwargs: Any):
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
+    ):
         # Override the call method
-        response, history = self.model.chat(self.tokenizer, prompt, history=[], do_sample=False)
+        response, history = self.model.chat(
+            self.tokenizer, prompt, history=[], do_sample=False
+        )
         return response
 
-
-
     @property
     def _llm_type(self) -> str:
         return "ChatGLM3-6B"
-
-
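For reference, a minimal usage sketch of the wrapper this patch touches; it is not part of the diff. It assumes the patched file is importable as LLM.py and that a ChatGLM3-6B checkpoint sits at the placeholder path, both of which are illustrative, and it requires that GPU 1 exists, since the class pins the model there via .cuda(device=1).

    # Usage sketch only; the module name and checkpoint path are assumptions.
    from LLM import ChatGLM_LLM

    llm = ChatGLM_LLM(model_path="/path/to/chatglm3-6b")  # placeholder local path
    print(llm._llm_type)              # -> "ChatGLM3-6B"
    print(llm("Introduce yourself"))  # LLM.__call__ dispatches to the overridden _call

Because _call passes do_sample=False to model.chat, decoding is greedy, so a given prompt yields the same response on every run.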