langchain_embedding_ErnieBotSDK.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import asyncio
  2. import logging
  3. import threading
  4. from functools import partial
  5. from typing import Dict, List, Optional
  6. import requests
  7. from langchain.pydantic_v1 import BaseModel, root_validator
  8. from langchain.schema.embeddings import Embeddings
  9. from langchain.utils import get_from_dict_or_env
  10. import erniebot
  11. import numpy as np
  12. import time
  13. import os
  14. ## 注意不要用翻墙
  15. ## https://python.langchain.com/docs/integrations/chat/ernie
  16. logger = logging.getLogger(__name__)
  17. class ErnieEmbeddings(BaseModel, Embeddings):
  18. """`Ernie Embeddings V1` embedding models."""
  19. ernie_api_base: Optional[str] = None
  20. ernie_client_id: Optional[str] = None
  21. ernie_client_secret: Optional[str] = None
  22. access_token: Optional[str] = None#erniebot.access_token = '<access-token-for-aistudio>'
  23. chunk_size: int = 16
  24. model_name = "ErnieBot-Embedding-V1"
  25. _lock = threading.Lock()
  26. '''
  27. kevin modify:
  28. '''
  29. @root_validator()
  30. def validate_environment(cls, values: Dict) -> Dict:
  31. # values["ernie_api_base"] = get_from_dict_or_env(
  32. # values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
  33. # )
  34. values["access_token"] = get_from_dict_or_env(
  35. values,
  36. "access_token",
  37. "ACCESS_TOKEN",
  38. )
  39. values["api_type"] = 'aistudio'
  40. erniebot.api_type = values["api_type"]
  41. erniebot.access_token = values["access_token"]
  42. return values
  43. # def _embedding(self, json: object) -> dict:
  44. # base_url = (
  45. # f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"
  46. # )
  47. # resp = requests.post(
  48. # f"{base_url}/embedding-v1",
  49. # headers={
  50. # "Content-Type": "application/json",
  51. # },
  52. # params={"access_token": self.access_token},
  53. # json=json,
  54. # )
  55. # return resp.json()
  56. '''
  57. kevin modify:
  58. '''
  59. def _embedding(self, json: object) -> dict:
  60. inputs=json['input']
  61. def erniebotSDK(inputs):
  62. response = erniebot.Embedding.create(
  63. model='ernie-text-embedding',
  64. input=inputs)
  65. time.sleep(1)
  66. return response
  67. try:
  68. response=erniebotSDK(inputs)
  69. except:
  70. print('connect erniebot error...wait 2s to retry(kevin)')
  71. time.sleep(2)
  72. response=erniebotSDK(inputs)
  73. return response
  74. def _refresh_access_token_with_lock(self) -> None:
  75. with self._lock:
  76. logger.debug("Refreshing access token")
  77. base_url: str = f"{self.ernie_api_base}/oauth/2.0/token"
  78. resp = requests.post(
  79. base_url,
  80. headers={
  81. "Content-Type": "application/json",
  82. "Accept": "application/json",
  83. },
  84. params={
  85. "grant_type": "client_credentials",
  86. "client_id": self.ernie_client_id,
  87. "client_secret": self.ernie_client_secret,
  88. },
  89. )
  90. self.access_token = str(resp.json().get("access_token"))
  91. def embed_documents(self, texts: List[str]) -> List[List[float]]:
  92. """Embed search docs.
  93. Args:
  94. texts: The list of texts to embed
  95. Returns:
  96. List[List[float]]: List of embeddings, one for each text.
  97. """
  98. if not self.access_token:
  99. self._refresh_access_token_with_lock()
  100. text_in_chunks = [
  101. texts[i : i + self.chunk_size]
  102. for i in range(0, len(texts), self.chunk_size)
  103. ]
  104. lst = []
  105. for chunk in text_in_chunks:
  106. resp = self._embedding({"input": [text for text in chunk]})
  107. if resp.get("error_code"):
  108. if resp.get("error_code") == 111:
  109. self._refresh_access_token_with_lock()
  110. resp = self._embedding({"input": [text for text in chunk]})
  111. else:
  112. raise ValueError(f"Error from Ernie: {resp}")
  113. lst.extend([i["embedding"] for i in resp["data"]])
  114. return lst
  115. def embed_query(self, text: str) -> List[float]:
  116. """Embed query text.
  117. Args:
  118. text: The text to embed.
  119. Returns:
  120. List[float]: Embeddings for the text.
  121. """
  122. if not self.access_token:
  123. self._refresh_access_token_with_lock()
  124. resp = self._embedding({"input": [text]})
  125. if resp.get("error_code"):
  126. if resp.get("error_code") == 111:
  127. self._refresh_access_token_with_lock()
  128. resp = self._embedding({"input": [text]})
  129. else:
  130. raise ValueError(f"Error from Ernie: {resp}")
  131. return resp["data"][0]["embedding"]
  132. async def aembed_query(self, text: str) -> List[float]:
  133. """Asynchronous Embed query text.
  134. Args:
  135. text: The text to embed.
  136. Returns:
  137. List[float]: Embeddings for the text.
  138. """
  139. return await asyncio.get_running_loop().run_in_executor(
  140. None, partial(self.embed_query, text)
  141. )
  142. async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
  143. """Asynchronous Embed search docs.
  144. Args:
  145. texts: The list of texts to embed
  146. Returns:
  147. List[List[float]]: List of embeddings, one for each text.
  148. """
  149. result = await asyncio.gather(*[self.aembed_query(text) for text in texts])
  150. return list(result)