123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
import os

import chromadb
import erniebot
import pandas as pd
from langchain.text_splitter import RecursiveCharacterTextSplitter

# ERNIE Bot is called through the AI Studio backend. The access token is read
# from the environment instead of being hard-coded: a token that has ever been
# written into source code should be treated as leaked and rotated.
erniebot.api_type = "aistudio"
erniebot.access_token = os.environ.get("EB_ACCESS_TOKEN")
if not erniebot.access_token:
    raise RuntimeError(
        "EB_ACCESS_TOKEN is not set; export your AI Studio access token first.")

# On-disk vector store under data/chroma. get_or_create_collection reuses the
# collection on later runs, whereas create_collection raises when a collection
# with this name already exists in the persisted store.
chroma_client = chromadb.PersistentClient(path="data/chroma")
collection = chroma_client.get_or_create_collection(name="collection")
def get_files(dir_path, suffix=".csv"):
    """Recursively collect paths of files under *dir_path* whose names end in *suffix*.

    Args:
        dir_path: Root folder to search; ``os.walk`` descends into every subfolder.
        suffix: File-name ending to match, passed to ``str.endswith`` (so a tuple
            of endings also works). Matching is case-sensitive. Defaults to
            ``".csv"``, which was the only type the original supported.

    Returns:
        A list of paths built with ``os.path.join(folder, filename)``, in the
        order ``os.walk`` yields them. A non-existent *dir_path* yields ``[]``
        because ``os.walk`` ignores errors by default.
    """
    matches = []
    for folder, _subdirs, names in os.walk(dir_path):
        for name in names:
            # Select files by their name suffix (extension).
            if name.endswith(suffix):
                matches.append(os.path.join(folder, name))
    return matches
# Sanity check before indexing: print the first rows of every CSV under
# tar_dir together with the book name derived from its file name.
tar_dir = "files/"
file_list = get_files(tar_dir)
for file in file_list:
    df = pd.read_csv(file)
    print(df.head())
    # basename + splitext handles either path separator (os.walk joins with the
    # OS separator) and keeps dots that belong to the name ("vol.1.csv" -> "vol.1").
    print(os.path.splitext(os.path.basename(file))[0])
    print("---")
# Second pass: chunk every row of every CSV, embed the chunks with ERNIE and
# store them in the Chroma collection (one `add` call per file).
file_list = get_files(tar_dir)

# The splitter only holds configuration, so build it once rather than per row.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=0)

# The ERNIE embedding endpoint caps the number of inputs per request (16 at the
# time of writing -- confirm against the current API docs). Sending a long row's
# chunks in slices of this size keeps one oversized row from being skipped whole.
EMBED_BATCH_SIZE = 16

my_id = 1  # running counter so ids stay unique across all files in this run
for file in file_list:
    print(file)
    docs = []
    metadatas = []
    ids = []
    embeddings = []
    df = pd.read_csv(file)
    # basename + splitext works with either path separator and keeps dots that
    # are part of the name ("vol.1.csv" -> "vol.1").
    book_name = os.path.splitext(os.path.basename(file))[0]
    # Column 0 is used as the title and column 1 as the body text.
    for title, text in zip(df.iloc[:, 0], df.iloc[:, 1]):
        # Empty CSV cells are read as NaN (a float), which neither the splitter
        # nor string concatenation can handle.
        if pd.isna(text):
            continue
        # str() also keeps Chroma metadata values plain Python strings when the
        # title column is numeric.
        title = "" if pd.isna(title) else str(title)
        text = str(text)
        all_splits = [f"title: {title} content:{s}"
                      for s in text_splitter.split_text(text)]
        for start in range(0, len(all_splits), EMBED_BATCH_SIZE):
            batch = all_splits[start:start + EMBED_BATCH_SIZE]
            try:
                response = erniebot.Embedding.create(
                    model='ernie-text-embedding',
                    input=batch)
            except Exception as e:
                # Best effort: report the failed batch and keep indexing the rest.
                print(batch)
                print(e)
                continue
            for item in response.data:
                # NOTE(review): each chunk stores the FULL row text as its
                # document; only the embedding is per chunk. Store the matching
                # entry of `batch` instead if queries should return just the chunk.
                docs.append(text)
                metadatas.append({"book": book_name, "title": title})
                embeddings.append(item['embedding'])
                ids.append(f"id{my_id}")
                my_id += 1
    # Chroma rejects an `add` with an empty id list; without this guard a file
    # whose rows were all empty or failed to embed would abort the whole run.
    if ids:
        collection.add(documents=docs,
                       metadatas=metadatas,
                       ids=ids,
                       embeddings=embeddings)
print("Number of vectors in vectordb: ", collection.count())
|