LLM.py

# coding=gbk
from langchain.llms.base import LLM
from typing import Any, List, Optional
from langchain.callbacks.manager import CallbackManagerForLLMRun
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


class ChatGLM_LLM(LLM):
    # Custom LLM class wrapping a local ChatGLM3-6B model
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None

    def __init__(self, model_path: str):
        # model_path: local path to the ChatGLM3-6B model
        # Initialize the model from the local path
        super().__init__()
        print("Loading model from local path...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True
        ).to(torch.bfloat16).cuda(device=1)
        self.model = self.model.eval()
        print("Finished loading the local model")

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        # Override the core call method: run one chat round with empty history
        response, history = self.model.chat(self.tokenizer, prompt, history=[], do_sample=False)
        return response

    @property
    def _llm_type(self) -> str:
        return "ChatGLM3-6B"
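

# Usage sketch (an assumption for illustration, not part of the original
# file): the wrapper can be driven like any legacy LangChain LLM, since
# LLM.__call__ dispatches to _call(). The model path below is hypothetical.
if __name__ == "__main__":
    llm = ChatGLM_LLM(model_path="/path/to/chatglm3-6b")
    print(llm._llm_type)                              # -> "ChatGLM3-6B"
    print(llm("Hello, please introduce yourself."))   # routed through _call()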