import pandas as pd
import chromadb
import os
import erniebot
# Use the AI Studio backend of the ERNIE Bot SDK.
erniebot.api_type = "aistudio"
# Read the access token from the environment instead of committing a secret to
# source control. NOTE(review): the token that used to be hard-coded here was
# stored in plain text and should be treated as compromised (revoke/rotate it).
_access_token = os.environ.get("EB_ACCESS_TOKEN")
if not _access_token:
    raise RuntimeError(
        "Set the EB_ACCESS_TOKEN environment variable to your AI Studio access token."
    )
erniebot.access_token = _access_token

# On-disk Chroma store. get_or_create_collection lets the script be re-run
# against an existing store instead of failing because "collection" exists.
chroma_client = chromadb.PersistentClient(path="data/chroma")
collection = chroma_client.get_or_create_collection(name="collection")

from langchain.text_splitter import RecursiveCharacterTextSplitter

def get_files(dir_path):
    """Recursively collect the paths of all ``.csv`` files under *dir_path*.

    Args:
        dir_path: Root folder to search; ``os.walk`` descends into every
            sub-folder.

    Returns:
        A list of paths built with ``os.path.join(folder, name)``, in the
        order ``os.walk`` yields them. The suffix check is case-sensitive,
        so only names ending in exactly ``.csv`` are kept.
    """
    return [
        os.path.join(folder, name)
        for folder, _subfolders, names in os.walk(dir_path)
        for name in names
        if name.endswith(".csv")
    ]

# Root folder scanned (recursively) for CSV files.
tar_dir = "files/"

# Sanity check before ingesting: show the first rows of every CSV and the book
# name derived from its file name.
file_list = get_files(tar_dir)
for file in file_list:
    df = pd.read_csv(file)
    print(df.head())
    # File stem via os.path: handles the platform's path separators and keeps
    # names containing extra dots intact ("a.b.csv" -> "a.b").
    print(os.path.splitext(os.path.basename(file))[0])
    print("---")

# Build the vector store: split each CSV row's text into chunks, embed the
# chunks with ERNIE, and add them to the Chroma collection.
file_list = get_files(tar_dir)
my_id = 1  # running counter used to build unique ids: "id1", "id2", ...

# The splitter's configuration is the same for every row, so build it once.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=0)
# Max inputs sent per embedding request.
# NOTE(review): ernie-text-embedding is believed to accept at most 16 inputs
# per call — confirm against the ERNIE Bot / AI Studio API limits.
EMBED_BATCH_SIZE = 16

for file in file_list:
    print(file)
    df = pd.read_csv(file)
    # File stem without directory or extension; os.path handles the platform's
    # separators and names that contain extra dots ("a.b.csv" -> "a.b").
    book_name = os.path.splitext(os.path.basename(file))[0]

    for row in df.itertuples(index=False):
        # Column 0 is used as the title and column 1 as the body text.
        title, text = row[0], row[1]
        # Empty CSV cells arrive as NaN (a float), which split_text() cannot
        # handle; skip rows that have no body text.
        if pd.isna(text) or not str(text).strip():
            continue
        title = "" if pd.isna(title) else str(title)
        text = str(text)

        all_splits = [
            f"title: {title} content:{s}" for s in text_splitter.split_text(text)
        ]

        try:
            # Embed in slices so long rows do not exceed the per-request cap.
            vectors = []
            for start in range(0, len(all_splits), EMBED_BATCH_SIZE):
                response = erniebot.Embedding.create(
                    model='ernie-text-embedding',
                    input=all_splits[start:start + EMBED_BATCH_SIZE])
                vectors.extend(item['embedding'] for item in response.data)
        except Exception as e:
            # Best-effort ingestion: report the failing row and move on.
            print(all_splits)
            print(e)
            continue

        n = len(all_splits)
        ids = [f"id{my_id + i}" for i in range(n)]
        my_id += n

        # Add only this row's chunks, so each id is sent to Chroma exactly once
        # and the batch size stays bounded by the size of a single row.
        collection.add(
            # NOTE(review): every chunk stores the full row text as its
            # document, not the chunk it was embedded from. Keep this if
            # small-to-big retrieval is intended; otherwise store the chunk.
            documents=[text] * n,
            metadatas=[{"book": book_name, "title": title} for _ in range(n)],
            ids=ids,
            embeddings=vectors,
        )

print("Number of vectors in vectordb: ", collection.count())