buildVectors2.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # coding=utf-8
  2. from langchain.document_loaders import UnstructuredFileLoader
  3. from langchain.document_loaders import UnstructuredMarkdownLoader
  4. from langchain.text_splitter import RecursiveCharacterTextSplitter
  5. from langchain.vectorstores import Chroma
  6. from langchain.embeddings.huggingface import HuggingFaceEmbeddings
  7. from langchain_community.document_loaders import PyPDFLoader
  8. from langchain_community.document_loaders import MathpixPDFLoader
  9. from langchain_community.document_loaders import UnstructuredPDFLoader
  10. from langchain_community.document_loaders import AzureAIDocumentIntelligenceLoader
  11. from langchain_community.document_loaders import Docx2txtLoader
  12. from langchain_community.document_loaders.csv_loader import CSVLoader
  13. from langchain.text_splitter import MarkdownHeaderTextSplitter
  14. from tqdm import tqdm
  15. from sentence_transformers import SentenceTransformer, util
  16. import os
  17. import chardet
  18. import erniebot
  19. import numpy as np
  20. from langchain_embedding_ErnieBotSDK import ErnieEmbeddings
  # Credentials come from the environment instead of being committed to source.
# NOTE(review): the token previously hard-coded in this file is exposed in
# version history and should be revoked/rotated.
_ACCESS_TOKEN = os.environ.get("EB_ACCESS_TOKEN")
if not _ACCESS_TOKEN:
    raise RuntimeError(
        "Set the EB_ACCESS_TOKEN environment variable to your AI Studio access token."
    )

# Configure the erniebot SDK globally, then build the embedding model with the
# same token so both share a single source of truth.
erniebot.api_type = "aistudio"
erniebot.access_token = _ACCESS_TOKEN
# chunk_size=1: presumably one text per embedding request — confirm against
# ErnieEmbeddings before changing it.
embeddings = ErnieEmbeddings(access_token=_ACCESS_TOKEN, chunk_size=1)
  # Collect the paths of files that get_text can load.
def get_files(dir_path, extensions=(".md", ".txt", ".docx", ".csv")):
    """Recursively collect the paths of files with a supported extension.

    Args:
        dir_path: Root directory to walk recursively with ``os.walk``. A
            missing directory yields an empty list.
        extensions: File-name suffix (str) or iterable of suffixes to keep.
            Matching is case-sensitive. The default is the set of types the
            loader code handles; ``.pdf`` is deliberately excluded.

    Returns:
        List of paths built as ``os.path.join(dirpath, filename)``, in
        ``os.walk`` order. They are absolute only if *dir_path* is absolute.
    """
    # A bare string would otherwise be split into single-character suffixes.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(extensions)  # str.endswith accepts a tuple of options
    return [
        os.path.join(dirpath, filename)
        for dirpath, _dirnames, filenames in os.walk(dir_path)
        for filename in filenames
        if filename.endswith(suffixes)
    ]
  # Load documents from disk, choosing a loader per file type.
def _detect_encoding(path):
    """Return chardet's best guess at the text encoding of *path* (may be None)."""
    with open(path, 'rb') as f:
        return chardet.detect(f.read())['encoding']


def _make_loader(path, file_type):
    """Build the LangChain loader for *path*, or return None for unsupported types."""
    if file_type == 'docx':
        # .docx is a binary zip container, so byte-level encoding detection is
        # meaningless for it, and Docx2txtLoader takes only the file path —
        # passing encoding= raised TypeError and aborted the whole run.
        return Docx2txtLoader(path)
    if file_type not in ('md', 'txt', 'csv'):
        return None
    # Text formats: detect the encoding first so non-UTF-8 files decode.
    encoding = _detect_encoding(path)
    print(f"Detected encoding for {path}: {encoding}")
    if file_type == 'md':
        return UnstructuredMarkdownLoader(path, encoding=encoding)
    if file_type == 'txt':
        return UnstructuredFileLoader(path, encoding=encoding)
    return CSVLoader(path, encoding=encoding)


def get_text(dir_path):
    """Load every supported file under *dir_path* into LangChain documents.

    Files are discovered with ``get_files``. Markdown, plain-text and CSV
    files have their encoding guessed with ``chardet`` before loading;
    ``.docx`` files are handed to ``Docx2txtLoader`` unchanged. A file whose
    load raises ``UnicodeDecodeError`` is reported and skipped.

    Args:
        dir_path: Directory to search recursively.

    Returns:
        Flat list of the documents returned by each loader's ``load()``.
    """
    docs = []
    for one_file in tqdm(get_files(dir_path)):
        file_type = one_file.rsplit('.', 1)[-1]
        loader = _make_loader(one_file, file_type)
        if loader is None:
            continue
        try:
            docs.extend(loader.load())
        except UnicodeDecodeError as e:
            # A wrong encoding guess should skip one file, not abort the run.
            print(f"Failed to load {one_file} due to encoding error: {str(e)}")
    return docs
  # Directories to ingest (each searched recursively).
tar_dir = [
    "files/",
]

# Load each directory exactly once. Previously get_text ran twice per
# directory (once only to print the result), doubling all parsing work.
docs = []
for dir_path in tar_dir:
    dir_docs = get_text(dir_path)
    print(dir_docs)
    docs.extend(dir_docs)
print(docs)

if not docs:
    # Embedding an empty corpus is pointless (and the vector store may reject
    # an empty insert), so stop early with an actionable message.
    raise SystemExit(f"No loadable documents found under {tar_dir}; nothing to index.")

# Alternative local embedding model:
# embeddings = HuggingFaceEmbeddings(model_name="/mnt/sdb/zhaoyuan/rongrunxiang/acge_text_embedding")

# Build the vector database, writing it to this on-disk directory.
persist_directory = 'data_base/vector_db/chroma'
vectordb = Chroma.from_documents(
    documents=docs,
    embedding=embeddings,
    persist_directory=persist_directory,
)
# Flush the collection to disk.
# NOTE(review): newer Chroma releases persist automatically and deprecate
# persist(); confirm against the installed version.
vectordb.persist()