|
@@ -0,0 +1,179 @@
|
|
|
+import asyncio
|
|
|
+import logging
|
|
|
+import threading
|
|
|
+from functools import partial
|
|
|
+from typing import Dict, List, Optional
|
|
|
+
|
|
|
+import requests
|
|
|
+
|
|
|
+from langchain.pydantic_v1 import BaseModel, root_validator
|
|
|
+from langchain.schema.embeddings import Embeddings
|
|
|
+from langchain.utils import get_from_dict_or_env
|
|
|
+import erniebot
|
|
|
+import numpy as np
|
|
|
+import time
|
|
|
+import os
|
|
|
+## 注意不要用翻墙
|
|
|
+## https://python.langchain.com/docs/integrations/chat/ernie
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
+
|
|
|
+
|
|
|
+class ErnieEmbeddings(BaseModel, Embeddings):
|
|
|
+ """`Ernie Embeddings V1` embedding models."""
|
|
|
+
|
|
|
+ ernie_api_base: Optional[str] = None
|
|
|
+ ernie_client_id: Optional[str] = None
|
|
|
+ ernie_client_secret: Optional[str] = None
|
|
|
+ access_token: Optional[str] = None#erniebot.access_token = '<access-token-for-aistudio>'
|
|
|
+
|
|
|
+ chunk_size: int = 16
|
|
|
+
|
|
|
+ model_name = "ErnieBot-Embedding-V1"
|
|
|
+
|
|
|
+ _lock = threading.Lock()
|
|
|
+ '''
|
|
|
+ kevin modify:
|
|
|
+ '''
|
|
|
+ @root_validator()
|
|
|
+ def validate_environment(cls, values: Dict) -> Dict:
|
|
|
+ # values["ernie_api_base"] = get_from_dict_or_env(
|
|
|
+ # values, "ernie_api_base", "ERNIE_API_BASE", "https://aip.baidubce.com"
|
|
|
+ # )
|
|
|
+ values["access_token"] = get_from_dict_or_env(
|
|
|
+ values,
|
|
|
+ "access_token",
|
|
|
+ "ACCESS_TOKEN",
|
|
|
+ )
|
|
|
+ values["api_type"] = 'aistudio'
|
|
|
+
|
|
|
+ erniebot.api_type = values["api_type"]
|
|
|
+ erniebot.access_token = values["access_token"]
|
|
|
+ return values
|
|
|
+
|
|
|
+ # def _embedding(self, json: object) -> dict:
|
|
|
+ # base_url = (
|
|
|
+ # f"{self.ernie_api_base}/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings"
|
|
|
+ # )
|
|
|
+ # resp = requests.post(
|
|
|
+ # f"{base_url}/embedding-v1",
|
|
|
+ # headers={
|
|
|
+ # "Content-Type": "application/json",
|
|
|
+ # },
|
|
|
+ # params={"access_token": self.access_token},
|
|
|
+ # json=json,
|
|
|
+ # )
|
|
|
+ # return resp.json()
|
|
|
+ '''
|
|
|
+ kevin modify:
|
|
|
+ '''
|
|
|
+ def _embedding(self, json: object) -> dict:
|
|
|
+ inputs=json['input']
|
|
|
+ def erniebotSDK(inputs):
|
|
|
+ response = erniebot.Embedding.create(
|
|
|
+ model='ernie-text-embedding',
|
|
|
+ input=inputs)
|
|
|
+ time.sleep(1)
|
|
|
+ return response
|
|
|
+ try:
|
|
|
+ response=erniebotSDK(inputs)
|
|
|
+ except:
|
|
|
+ print('connect erniebot error...wait 2s to retry(kevin)')
|
|
|
+ time.sleep(2)
|
|
|
+ response=erniebotSDK(inputs)
|
|
|
+ return response
|
|
|
+
|
|
|
+ def _refresh_access_token_with_lock(self) -> None:
|
|
|
+ with self._lock:
|
|
|
+ logger.debug("Refreshing access token")
|
|
|
+ base_url: str = f"{self.ernie_api_base}/oauth/2.0/token"
|
|
|
+ resp = requests.post(
|
|
|
+ base_url,
|
|
|
+ headers={
|
|
|
+ "Content-Type": "application/json",
|
|
|
+ "Accept": "application/json",
|
|
|
+ },
|
|
|
+ params={
|
|
|
+ "grant_type": "client_credentials",
|
|
|
+ "client_id": self.ernie_client_id,
|
|
|
+ "client_secret": self.ernie_client_secret,
|
|
|
+ },
|
|
|
+ )
|
|
|
+ self.access_token = str(resp.json().get("access_token"))
|
|
|
+
|
|
|
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
|
+ """Embed search docs.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ texts: The list of texts to embed
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ List[List[float]]: List of embeddings, one for each text.
|
|
|
+ """
|
|
|
+
|
|
|
+ if not self.access_token:
|
|
|
+ self._refresh_access_token_with_lock()
|
|
|
+ text_in_chunks = [
|
|
|
+ texts[i : i + self.chunk_size]
|
|
|
+ for i in range(0, len(texts), self.chunk_size)
|
|
|
+ ]
|
|
|
+ lst = []
|
|
|
+ for chunk in text_in_chunks:
|
|
|
+ resp = self._embedding({"input": [text for text in chunk]})
|
|
|
+ if resp.get("error_code"):
|
|
|
+ if resp.get("error_code") == 111:
|
|
|
+ self._refresh_access_token_with_lock()
|
|
|
+ resp = self._embedding({"input": [text for text in chunk]})
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Error from Ernie: {resp}")
|
|
|
+ lst.extend([i["embedding"] for i in resp["data"]])
|
|
|
+ return lst
|
|
|
+
|
|
|
+ def embed_query(self, text: str) -> List[float]:
|
|
|
+ """Embed query text.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ text: The text to embed.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ List[float]: Embeddings for the text.
|
|
|
+ """
|
|
|
+
|
|
|
+ if not self.access_token:
|
|
|
+ self._refresh_access_token_with_lock()
|
|
|
+ resp = self._embedding({"input": [text]})
|
|
|
+ if resp.get("error_code"):
|
|
|
+ if resp.get("error_code") == 111:
|
|
|
+ self._refresh_access_token_with_lock()
|
|
|
+ resp = self._embedding({"input": [text]})
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Error from Ernie: {resp}")
|
|
|
+ return resp["data"][0]["embedding"]
|
|
|
+
|
|
|
+ async def aembed_query(self, text: str) -> List[float]:
|
|
|
+ """Asynchronous Embed query text.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ text: The text to embed.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ List[float]: Embeddings for the text.
|
|
|
+ """
|
|
|
+
|
|
|
+ return await asyncio.get_running_loop().run_in_executor(
|
|
|
+ None, partial(self.embed_query, text)
|
|
|
+ )
|
|
|
+
|
|
|
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
|
+ """Asynchronous Embed search docs.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ texts: The list of texts to embed
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ List[List[float]]: List of embeddings, one for each text.
|
|
|
+ """
|
|
|
+
|
|
|
+ result = await asyncio.gather(*[self.aembed_query(text) for text in texts])
|
|
|
+
|
|
|
+ return list(result)
|