# coding:utf8
"""Helper for adding text to per-collection Chroma vector stores.

Uses a local Chinese sentence-transformer model (text2vec-base-chinese)
as the embedding function for every collection.
"""
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document


class LangChainChroma():
    """Splits raw text and persists the chunks into named Chroma collections."""

    def __init__(self):
        # Embedding model is loaded once and shared across all addChroma calls.
        self.embedding_function = SentenceTransformerEmbeddings(
            model_name="text_analysis/shibing624/text2vec-base-chinese"
        )

    def addChroma(self, data, baseName, logging, chunkSize=500, chunkOverlap=0):
        """Split *data* into chunks and add them to the Chroma collection *baseName*.

        Parameters
        ----------
        data : str
            Raw text to index.
        baseName : str
            Collection name; also used as the persist-directory suffix
            under ``allField/``.
        logging : logging.Logger
            Caller-supplied logger (parameter name kept as ``logging`` for
            backward compatibility with existing callers).
        chunkSize : int, optional
            Maximum characters per chunk (default 500).
        chunkOverlap : int, optional
            Characters of overlap between adjacent chunks. Default 0,
            matching the original hard-coded behaviour.

        Returns
        -------
        list
            Document ids returned by ``Chroma.add_documents``.
        """
        fpath = 'allField/' + baseName
        db = Chroma(
            collection_name=baseName,
            embedding_function=self.embedding_function,
            persist_directory=fpath,
        )
        documents = [Document(page_content=data)]
        # Recursive splitter falls back through the separators in order;
        # "。" handles Chinese sentence boundaries when no newline exists.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunkSize,
            chunk_overlap=chunkOverlap,
            separators=["\n\n", "\n", " ", "。"],
        )
        docs = text_splitter.split_documents(documents)
        res = db.add_documents(documents=docs)
        # NOTE(review): `_collection` is a private attribute of the Chroma
        # wrapper; this langchain version exposes no public count() — confirm
        # this still works after any langchain/chromadb upgrade.
        db_count = db._collection.count()
        logging.info('当前数据划分{}个块,块大小{}。数据库{}共有{}个块'.format(len(docs), chunkSize, baseName, db_count))
        return res


if __name__ == "__main__":
    LC = LangChainChroma()
    # Open (or create) the 'policy' collection persisted next to this project.
    db = Chroma(
        collection_name='policy',
        embedding_function=LC.embedding_function,
        persist_directory='../policy',
    )