chroma新增、删除、知识库应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

76 lines
2.7 KiB

  1. # from chromadb.config import Settings
  2. from sentence_transformers import SentenceTransformer
  3. import chromadb
  4. from tqdm import tqdm
  5. from langchain.text_splitter import RecursiveCharacterTextSplitter
  6. import uuid
  7. class LangChainChroma:
  8. def __init__(self,collection_name):
  9. # dirPath="D:/ellie/project/2023/analyst_assistant/chromaDB/allField/"
  10. dirPath="./allField/"
  11. self.chroma_client = chromadb.PersistentClient(path=dirPath+collection_name)
  12. # chroma_client=chromadb.Client(Settings(allow_reset=True,persist_directory="../allField/demo/"))
  13. self.collection=self.chroma_client.get_or_create_collection(name=collection_name,metadata={"hnsw:space": "cosine"})
  14. model = SentenceTransformer('text_analysis/shibing624/text2vec-base-chinese')
  15. # model = SentenceTransformer('shibing624/text2vec-base-chinese')
  16. self.bge_model = model
  17. self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=200,chunk_overlap=0,separators=["\n\n", "\n", " ", ""])
  18. # chroma_client.reset()
  19. def db_close(self):
  20. self.chroma_client.clear_system_cache()
  21. def embedding_fn(self, paragraphs):
  22. '''文本向量化'''
  23. doc_vecs = [
  24. self.bge_model.encode(doc, normalize_embeddings=True).tolist()
  25. for doc in paragraphs
  26. ]
  27. return doc_vecs
  28. def add_documents(self,documents,dataid):
  29. data_id=[]
  30. ids=[]
  31. for _ in range(len(documents)):
  32. data_id.append({"dataId": dataid})
  33. ids.append(str(uuid.uuid1()))
  34. # embeddings=get_embeddings(documents)
  35. #向collection中添加文档与向量
  36. # ids = ["{}".format(uuid.uuid1()) for i in range(len(documents))]
  37. self.collection.add(
  38. embeddings=self.embedding_fn(documents),#每个文档的向量
  39. documents=documents,
  40. metadatas=data_id,
  41. ids=ids
  42. )
  43. c=self.collection.count()
  44. return ids,c
  45. def search(self,queryQ,top_n):
  46. results=self.collection.query(
  47. # query_texts=[query],
  48. query_embeddings=self.embedding_fn([queryQ]),
  49. n_results=top_n
  50. )
  51. return results
if __name__=="__main__":
    # Usage example kept for reference; requires a local Chroma store and the
    # text2vec model on disk, so it stays disabled.
    pass
    # vector_db=LangChainChroma("demo")
    # with open("policy_test2.txt", "r", encoding="utf8") as f:
    #     for line in tqdm(f):
    #         # Split each input line into <=200-char chunks before indexing.
    #         docs = text_splitter.split_text(line)
    #         a=vector_db.add_documents(docs)
    #         print(a)
    # print("over")
    # user_query="鲍炳章同志?"
    # results=vector_db.search(user_query,3)
    # for para in results['documents'][0]:
    #     print(para+'\n')