# buildVectors3.py
  1. # coding=utf-8
  2. import chromadb
  3. import pandas as pd
  4. from langchain.document_loaders import UnstructuredFileLoader
  5. from langchain.document_loaders import UnstructuredMarkdownLoader
  6. from langchain.text_splitter import RecursiveCharacterTextSplitter
  7. from langchain.vectorstores import Chroma
  8. from langchain.embeddings.huggingface import HuggingFaceEmbeddings
  9. from langchain_community.document_loaders import PyPDFLoader
  10. from langchain_community.document_loaders import MathpixPDFLoader
  11. from langchain_community.document_loaders import UnstructuredPDFLoader
  12. from langchain_community.document_loaders import AzureAIDocumentIntelligenceLoader
  13. from langchain_community.document_loaders import Docx2txtLoader
  14. from langchain_community.document_loaders.csv_loader import CSVLoader
  15. from langchain.text_splitter import MarkdownHeaderTextSplitter
  16. from tqdm import tqdm
  17. from sentence_transformers import SentenceTransformer, util
  18. import os
  19. import chardet
  20. import erniebot
  21. import numpy as np
  22. from langchain_embedding_ErnieBotSDK import ErnieEmbeddings
  23. erniebot.api_type = "aistudio"
  24. erniebot.access_token = "ff1531c8c0f429f92adbc2eaed2e23bfb5349e0f"
  25. embeddings=ErnieEmbeddings(access_token="ff1531c8c0f429f92adbc2eaed2e23bfb5349e0f", chunk_size=1)
  26. response = erniebot.Embedding.create(
  27. model='ernie-text-embedding',
  28. input=[
  29. "我是百度公司开发的人工智能语言模型,我的中文名是文心一言,英文名是ERNIE-Bot。",
  30. "2018年深圳市各区GDP"
  31. ])
  32. for embedding in response.get_result():
  33. embedding = np.array(embedding)
  34. print(embedding)
  35. chroma_client = chromadb.PersistentClient(path="data/chroma")
  36. collection = chroma_client.create_collection(name="collection")
  37. # 获取文件路径函数
  38. def get_files(dir_path):
  39. # args:dir_path,目标文件夹路径
  40. file_list = []
  41. for filepath, dirnames, filenames in os.walk(dir_path):
  42. # os.walk 函数将递归遍历指定文件夹
  43. for filename in filenames:
  44. # 通过后缀名判断文件类型是否满足要求
  45. if filename.endswith(".csv"):
  46. file_list.append(os.path.join(filepath, filename))
  47. return file_list
  48. # 加载文件函数
  49. def get_text(dir_path):
  50. file_lst = get_files(dir_path)
  51. docs = []
  52. metadatas = []
  53. ids = []
  54. embeddings = []
  55. for one_file in tqdm(file_lst):
  56. file_type = one_file.split('.')[-1]
  57. # 尝试检测文件编码
  58. with open(one_file, 'rb') as f:
  59. rawdata = f.read()
  60. encoding = chardet.detect(rawdata)['encoding']
  61. print(f"Detected encoding for {one_file}: {encoding}")
  62. if file_type == 'csv':
  63. df = pd.read_csv(one_file)
  64. for index, row in df.iterrows():
  65. output_str = ""
  66. text = row[1]
  67. text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=150)
  68. all_splits = text_splitter.split_text(text)
  69. for split_text in all_splits:
  70. split_text = row[0] + " " + split_text
  71. my_list = []
  72. my_list.append(output_str)
  73. response = erniebot.Embedding.create(
  74. model='ernie-text-embedding',
  75. input=my_list)
  76. embeddings.append(response.data[0].embedding)
  77. docs.append(output_str)
  78. metadatas.append({"source": one_file.split(".")[0]})
  79. ids.append(f"id{index}")
  80. else:
  81. continue
  82. return docs
  83. # 目标文件夹
  84. tar_dir = [
  85. "files/",
  86. ]
  87. # 加载目标文件
  88. docs = []
  89. for dir_path in tar_dir:
  90. print(get_text(dir_path))
  91. docs.extend(get_text(dir_path))
  92. print(docs)
  93. #embeddings = HuggingFaceEmbeddings(model_name="/mnt/sdb/zhaoyuan/rongrunxiang/acge_text_embedding")
  94. # 构建向量数据库
  95. # 定义持久化路径
  96. persist_directory = 'data_base/vector_db/chroma'
  97. # 加载数据库
  98. vectordb = Chroma.from_documents(
  99. documents=docs,
  100. embedding=embeddings,
  101. persist_directory=persist_directory # 允许我们将persist_directory目录保存到磁盘上
  102. )
  103. # 将加载的向量数据库持久化到磁盘上
  104. vectordb.persist()