buildVectors.py

import os

import chromadb
import erniebot
import pandas as pd
from langchain.text_splitter import RecursiveCharacterTextSplitter

# ERNIE Bot (AI Studio) credentials.
erniebot.api_type = "aistudio"
erniebot.access_token = "ff1531c8c0f429f92adbc2eaed2e23bfb5349e0f"

# Persistent Chroma store on disk; create_collection raises if the
# collection already exists, so start from a clean data/chroma directory.
chroma_client = chromadb.PersistentClient(path="data/chroma")
collection = chroma_client.create_collection(name="collection")


def get_files(dir_path):
    """Recursively collect all .csv files under dir_path."""
    file_list = []
    for filepath, dirnames, filenames in os.walk(dir_path):
        # os.walk traverses the directory tree recursively.
        for filename in filenames:
            # Filter by extension to keep only CSV files.
            if filename.endswith(".csv"):
                file_list.append(os.path.join(filepath, filename))
    return file_list


tar_dir = "files/"

# Quick sanity check: preview each CSV and the book name derived from its filename.
file_list = get_files(tar_dir)
for file in file_list:
    df = pd.read_csv(file)
    print(df.head())
    print(file.split("/")[-1].split(".")[0])
    print("---")

# Build the vector store: one embedding per text chunk.
file_list = get_files(tar_dir)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=0)
my_id = 1
for file in file_list:
    print(file)
    docs = []
    metadatas = []
    ids = []
    embeddings = []
    books = []
    df = pd.read_csv(file)
    book_name = file.split("/")[-1].split(".")[0]
    for index, row in df.iterrows():
        # First column: title of the entry; second column: body text.
        title = str(row.iloc[0])
        text = str(row.iloc[1])
        all_splits = text_splitter.split_text(text)
        all_splits = ["title: " + title + " content:" + s for s in all_splits]
        try:
            response = erniebot.Embedding.create(
                model='ernie-text-embedding',
                input=all_splits)
        except Exception as e:
            # Log the failing chunks and move on to the next row.
            print(all_splits)
            print(e)
            continue
        for i in range(len(all_splits)):
            # Store the full row text as the document for every chunk of that row.
            docs.append(text)
            metadatas.append({"book": book_name, "title": title})
            books.append(book_name)
            embeddings.append(response.data[i]['embedding'])
            ids.append(f"id{my_id}")
            my_id += 1
    collection.add(documents=docs,
                   metadatas=metadatas,
                   ids=ids,
                   embeddings=embeddings)

print("Number of vectors in vectordb: ", collection.count())