123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- import asyncio
- import logging
- import threading
- from functools import partial
- from typing import Dict, List, Optional
- import requests
- from langchain.pydantic_v1 import BaseModel, root_validator
- from langchain.schema.embeddings import Embeddings
- from langchain.utils import get_from_dict_or_env
- import erniebot
- import numpy as np
- import time
- import os
- ## Note: do not run this behind a VPN / firewall-circumvention proxy.
- ## https://python.langchain.com/docs/integrations/chat/ernie
- logger = logging.getLogger(__name__)
class ErnieEmbeddings(BaseModel, Embeddings):
    """`Ernie Embeddings V1` embedding models, called through the `erniebot` SDK.

    Construction configures the module-level `erniebot` client (API type
    ``aistudio`` plus the access token), so every instance shares that global
    SDK state.

    Attributes:
        ernie_api_base: Base URL used only for the OAuth token refresh; when
            unset, ``https://aip.baidubce.com`` is used.
        ernie_client_id: Client id sent with the OAuth token refresh request.
        ernie_client_secret: Client secret sent with the OAuth token refresh
            request.
        access_token: Token handed to `erniebot`; resolved from the
            ``access_token`` argument or the ``ACCESS_TOKEN`` environment
            variable.
        chunk_size: Maximum number of texts sent in one embedding request.
        model_name: Descriptive model label (the SDK call itself uses the
            ``ernie-text-embedding`` model id).
    """

    ernie_api_base: Optional[str] = None
    ernie_client_id: Optional[str] = None
    ernie_client_secret: Optional[str] = None
    # Assigned to ``erniebot.access_token`` by ``validate_environment``.
    access_token: Optional[str] = None

    chunk_size: int = 16
    model_name = "ErnieBot-Embedding-V1"
    # Serialises token refreshes across threads (shared by all instances).
    _lock = threading.Lock()

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the access token and configure the global `erniebot` client.

        Args:
            values: Field values collected by pydantic for this model.

        Returns:
            The same mapping with ``access_token`` resolved and ``api_type``
            set to ``"aistudio"``.
        """
        values["access_token"] = get_from_dict_or_env(
            values,
            "access_token",
            "ACCESS_TOKEN",
        )
        values["api_type"] = "aistudio"

        # The SDK reads its credentials from module-level attributes.
        erniebot.api_type = values["api_type"]
        erniebot.access_token = values["access_token"]
        return values

    def _embedding(self, json: Dict) -> dict:
        """Request embeddings from the `erniebot` SDK, retrying once on failure.

        Args:
            json: Request payload; only its ``"input"`` entry (the list of
                texts) is used.

        Returns:
            The response object returned by ``erniebot.Embedding.create``.

        Raises:
            KeyError: If ``json`` has no ``"input"`` entry.
            Exception: Whatever the SDK raises if the retry fails as well.
        """
        inputs = json["input"]

        def call_sdk() -> dict:
            response = erniebot.Embedding.create(
                model="ernie-text-embedding",
                input=inputs,
            )
            # Pause after every successful call (presumably to stay under the
            # service's rate limit -- TODO confirm).
            time.sleep(1)
            return response

        try:
            return call_sdk()
        except Exception:
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt / SystemExit are never swallowed by the retry.
            logger.warning(
                "connect erniebot error...wait 2s to retry(kevin)", exc_info=True
            )
            time.sleep(2)
            return call_sdk()

    def _refresh_access_token_with_lock(self) -> None:
        """Fetch a new OAuth access token and store it on ``self.access_token``.

        The request is made while holding the class-level lock. If the
        response carries no ``access_token`` the current token is left
        untouched instead of being overwritten with the string ``"None"``.
        """
        with self._lock:
            logger.debug("Refreshing access token")
            # ``ernie_api_base`` defaults to None; fall back to the public
            # endpoint so the URL is never built from the literal "None".
            api_base = self.ernie_api_base or "https://aip.baidubce.com"
            resp = requests.post(
                f"{api_base}/oauth/2.0/token",
                headers={
                    "Content-Type": "application/json",
                    "Accept": "application/json",
                },
                params={
                    "grant_type": "client_credentials",
                    "client_id": self.ernie_client_id,
                    "client_secret": self.ernie_client_secret,
                },
            )
            token = resp.json().get("access_token")
            if token is None:
                logger.warning(
                    "Ernie token refresh returned no access_token; "
                    "keeping the current token"
                )
                return
            self.access_token = str(token)

    def _embed_with_retry(self, inputs: List[str]) -> dict:
        """Embed one batch, refreshing the token and retrying on error 111.

        Args:
            inputs: The texts to send in a single request.

        Returns:
            The embedding response for ``inputs``.

        Raises:
            ValueError: If the response reports an error code other than 111.
        """
        resp = self._embedding({"input": inputs})
        error_code = resp.get("error_code")
        if not error_code:
            return resp
        if error_code != 111:
            raise ValueError(f"Error from Ernie: {resp}")
        # Error 111 is treated as a stale token: refresh once, then retry.
        self._refresh_access_token_with_lock()
        return self._embedding({"input": inputs})

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs.

        Args:
            texts: The list of texts to embed.

        Returns:
            List[List[float]]: List of embeddings, one for each text.

        Raises:
            ValueError: If Ernie reports an error code other than 111.
        """
        if not self.access_token:
            self._refresh_access_token_with_lock()

        embeddings: List[List[float]] = []
        # Send at most ``chunk_size`` texts per request.
        for start in range(0, len(texts), self.chunk_size):
            chunk = list(texts[start : start + self.chunk_size])
            resp = self._embed_with_retry(chunk)
            embeddings.extend(item["embedding"] for item in resp["data"])
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        Args:
            text: The text to embed.

        Returns:
            List[float]: Embeddings for the text.

        Raises:
            ValueError: If Ernie reports an error code other than 111.
        """
        if not self.access_token:
            self._refresh_access_token_with_lock()
        resp = self._embed_with_retry([text])
        return resp["data"][0]["embedding"]

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text.

        Runs the blocking ``embed_query`` in the running loop's default
        executor.

        Args:
            text: The text to embed.

        Returns:
            List[float]: Embeddings for the text.
        """
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self.embed_query, text)
        )

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronous Embed search docs.

        Each text is embedded with its own ``aembed_query`` call and the
        calls are awaited together.

        Args:
            texts: The list of texts to embed.

        Returns:
            List[List[float]]: List of embeddings, one for each text.
        """
        result = await asyncio.gather(*[self.aembed_query(text) for text in texts])
        return list(result)
|