chroma新增、删除、知识库应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

47 lines
1.8 KiB

  1. #coding:utf8
  2. from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
  3. # from langchain.text_splitter import CharacterTextSplitter
  4. from langchain.vectorstores import Chroma
  5. # from langchain.document_loaders import TextLoader
  6. from langchain.text_splitter import RecursiveCharacterTextSplitter
  7. from langchain.docstore.document import Document
  8. class LangChainChroma():
  9. def __init__(self):
  10. self.embedding_function = SentenceTransformerEmbeddings(model_name="text_analysis/shibing624/text2vec-base-chinese")
  11. def addChroma(self,data,baseName,logging,chunkSize=500):
  12. fpath='allField/'+baseName
  13. db = Chroma(collection_name=baseName, embedding_function=self.embedding_function,persist_directory=fpath)
  14. documents = [Document(page_content=data)]
  15. # text_splitter = CharacterTextSplitter(separator=",",chunk_size=500,chunk_overlap=0,length_function=len)
  16. text_splitter = RecursiveCharacterTextSplitter(
  17. chunk_size=chunkSize,
  18. chunk_overlap=0,
  19. separators=["\n\n", "\n"," ", ""]
  20. )
  21. docs = text_splitter.split_documents(documents)
  22. res=db.add_documents(documents=docs)
  23. db_count = db._collection.count()
  24. logging.info('当前数据划分{}个块,块大小{}。数据库{}共有{}个块'.format(len(docs),chunkSize,baseName,db_count))
  25. return res
  26. if __name__=="__main__":
  27. LC=LangChainChroma()
  28. db = Chroma(collection_name='policy', embedding_function=LC.embedding_function,
  29. persist_directory='../policy')
  30. # db.delete('a5909489-5bc4-4b30-a949-9bd4bb06c477')
  31. # #创建数据库
  32. # with open("policy_test.txt", "r", encoding="utf8") as f:
  33. # for line in tqdm(f):
  34. # LC.createChroma(line,db)
  35. # print("over")