ysyyhhh 1 year ago
parent
commit
efdec3b13c
1 changed file with 18 additions and 12 deletions

+ 18 - 12
LLM.py

@@ -1,4 +1,3 @@
-# coding=gbk
 from langchain.llms.base import LLM
 from typing import Any, List, Optional
 from langchain.callbacks.manager import CallbackManagerForLLMRun
@@ -16,23 +15,30 @@ class ChatGLM_LLM(LLM):
         # Initialize the model from a local path
         super().__init__()
         print("正在从本地加载模型...")
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(torch.bfloat16).cuda(
-            device=1)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=True
+        )
+        self.model = (
+            AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+            .to(torch.bfloat16)
+            .cuda(device=1)
+        )
         self.model = self.model.eval()
         print("完成本地模型的加载")
 
-    def _call(self, prompt: str, stop: Optional[List[str]] = None,
-              run_manager: Optional[CallbackManagerForLLMRun] = None,
-              **kwargs: Any):
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
+    ):
         # Override the LLM call method
-        response, history = self.model.chat(self.tokenizer, prompt, history=[], do_sample=False)
+        response, history = self.model.chat(
+            self.tokenizer, prompt, history=[], do_sample=False
+        )
         return response
 
-
-
     @property
     def _llm_type(self) -> str:
         return "ChatGLM3-6B"
-
-
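For context, a minimal usage sketch of the ChatGLM_LLM class as it stands after this commit. The model_path constructor parameter is taken from the diff; the checkpoint path is hypothetical, and invoke assumes a LangChain version where LLM implements the Runnable interface (older versions call the instance directly instead):

from LLM import ChatGLM_LLM

# Hypothetical path to a local ChatGLM3-6B checkpoint; adjust to your setup.
# Note the class pins the model to GPU 1 via .cuda(device=1).
llm = ChatGLM_LLM(model_path="./chatglm3-6b")

# Routes through the overridden _call; do_sample=False in _call makes the
# reply deterministic for a given prompt.
print(llm.invoke("What is LangChain?"))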