查询知识库应用
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

65 lines
2.3 KiB

  1. # from chromadb.config import Settings
  2. from sentence_transformers import SentenceTransformer
  3. import chromadb
  4. from tqdm import tqdm
  5. from langchain.text_splitter import RecursiveCharacterTextSplitter
  6. import uuid
  7. class LangChainChroma:
  8. def __init__(self,collection_name):
  9. dirPath="D:/ellie/project/2023/analyst_assistant/chromaDB/allField/"
  10. # dirPath="./allField/"
  11. self.chroma_client = chromadb.PersistentClient(path=dirPath+collection_name)
  12. # chroma_client=chromadb.Client(Settings(allow_reset=True,persist_directory="../allField/demo/"))
  13. self.collection=self.chroma_client.get_or_create_collection(name=collection_name,metadata={"hnsw:space": "cosine"})
  14. # model = SentenceTransformer('text_analysis/shibing624/text2vec-base-chinese')
  15. model = SentenceTransformer('shibing624/text2vec-base-chinese')
  16. self.bge_model = model
  17. self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000,chunk_overlap=0,separators=["\n\n", "\n", " ", ""])
  18. def db_close(self):
  19. self.chroma_client.clear_system_cache()
  20. def embedding_fn(self, paragraphs):
  21. '''文本向量化'''
  22. doc_vecs = [
  23. self.bge_model.encode(doc, normalize_embeddings=True).tolist()
  24. for doc in paragraphs
  25. ]
  26. return doc_vecs
  27. def add_documents(self,documents):
  28. # embeddings=get_embeddings(documents)
  29. #向collection中添加文档与向量
  30. ids = ["{}".format(uuid.uuid1()) for i in range(len(documents))]
  31. self.collection.add(
  32. embeddings=self.embedding_fn(documents),#每个文档的向量
  33. documents=documents,
  34. ids=ids
  35. )
  36. c=self.collection.count()
  37. return ids,c
  38. def search(self,query,top_n):
  39. results=self.collection.query(
  40. # query_texts=[query],
  41. query_embeddings=self.embedding_fn([query]),
  42. n_results=top_n
  43. )
  44. return results
  45. if __name__=="__main__":
  46. db_name="gameStart"
  47. vector_db=LangChainChroma(db_name)
  48. #计数
  49. # vec_count=vector_db.collection.count()
  50. # print(res2)
  51. #删除
  52. # vector_db.collection.delete(ids=["1bb71e6b-173b-11ef-9151-e4aaea9df84e"])
  53. #查询
  54. # res = vector_db.collection.get(where_document={"$contains": "政府"})
  55. # print(res)