查询知识库应用
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

65 lines
2.3 KiB

# from chromadb.config import Settings
import os
import uuid

import chromadb
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
class LangChainChroma:
    """Persistent chromadb collection paired with a SentenceTransformer embedder.

    Texts are embedded locally with a SentenceTransformer model, and the
    vectors are stored in and queried from a chromadb collection that is
    created with cosine distance (``{"hnsw:space": "cosine"}``).
    """

    # Root directory for the persistent stores. Each collection lives in its
    # own sub-directory named after the collection.
    DEFAULT_PERSIST_DIR = "D:/ellie/project/2023/analyst_assistant/chromaDB/allField/"
    # Default embedding model. A local copy such as
    # 'text_analysis/shibing624/text2vec-base-chinese' may be passed instead.
    DEFAULT_MODEL_NAME = "shibing624/text2vec-base-chinese"

    def __init__(self, collection_name, *, persist_dir=None, model_name=None):
        """Open (or create) *collection_name* and load the embedding model.

        Args:
            collection_name: Name of the chromadb collection. It is also used
                as the sub-directory of ``persist_dir`` that holds its data.
            persist_dir: Root directory for the persistent client. Defaults to
                ``DEFAULT_PERSIST_DIR``, the previously hard-coded path.
            model_name: SentenceTransformer model name or local path. Defaults
                to ``DEFAULT_MODEL_NAME``.
        """
        root = self.DEFAULT_PERSIST_DIR if persist_dir is None else persist_dir
        self.chroma_client = chromadb.PersistentClient(
            path=os.path.join(root, collection_name)
        )
        self.collection = self.chroma_client.get_or_create_collection(
            name=collection_name, metadata={"hnsw:space": "cosine"}
        )
        # The attribute name is kept for backward compatibility, even though
        # the default model is text2vec rather than BGE.
        self.bge_model = SentenceTransformer(
            self.DEFAULT_MODEL_NAME if model_name is None else model_name
        )
        # Not used inside this class. It is exposed for callers that chunk
        # long text before passing it to add_documents().
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=2000, chunk_overlap=0, separators=["\n\n", "\n", " ", ""]
        )

    def db_close(self):
        """Call ``clear_system_cache()`` on the chromadb client.

        NOTE(review): chromadb keeps this cache at process level, so the call
        presumably affects every client in the process, not only this one.
        Confirm before using it in multi-collection code.
        """
        self.chroma_client.clear_system_cache()

    def embedding_fn(self, paragraphs):
        """Embed texts into normalised vectors.

        Args:
            paragraphs: Iterable of strings.

        Returns:
            A list holding one ``list[float]`` vector per input text, in input
            order. Returns ``[]`` for empty input.
        """
        texts = list(paragraphs)
        if not texts:
            # Never hand the model an empty batch.
            return []
        # One batched encode call instead of a separate model call per text.
        return self.bge_model.encode(texts, normalize_embeddings=True).tolist()

    def add_documents(self, documents):
        """Embed *documents* and add them to the collection under fresh ids.

        Args:
            documents: Sequence of strings to store.

        Returns:
            A tuple ``(ids, count)``. ``ids`` holds the generated id for each
            document, in input order. ``count`` is the collection's document
            count after the insert. Empty input adds nothing and returns
            ``([], count)``.
        """
        documents = list(documents)
        if not documents:
            # Recent chromadb versions reject an empty id list, so skip the
            # add() call entirely.
            return [], self.collection.count()
        ids = [str(uuid.uuid1()) for _ in documents]
        self.collection.add(
            embeddings=self.embedding_fn(documents),
            documents=documents,
            ids=ids,
        )
        return ids, self.collection.count()

    def search(self, query, top_n):
        """Return the *top_n* stored documents nearest to *query*.

        The query is embedded with the same model used for the stored
        documents, and the raw result of chromadb ``Collection.query`` is
        returned unchanged.
        """
        # Pass precomputed embeddings rather than query_texts. The collection
        # was created without an embedding_function, so query_texts would use
        # chromadb's own default embedder instead of this model.
        return self.collection.query(
            query_embeddings=self.embedding_fn([query]),
            n_results=top_n,
        )
if __name__ == "__main__":
    # Ad-hoc maintenance entry point for the "gameStart" collection.
    db_name = "gameStart"
    vector_db = LangChainChroma(db_name)
    # Uncomment the snippet you need:
    # Count the stored documents:
    #     vec_count = vector_db.collection.count()
    #     print(vec_count)
    # Delete documents by id:
    #     vector_db.collection.delete(ids=["1bb71e6b-173b-11ef-9151-e4aaea9df84e"])
    # Fetch documents whose text contains a substring:
    #     res = vector_db.collection.get(where_document={"$contains": "政府"})
    #     print(res)
    # Clear the chromadb client cache before the script exits.
    vector_db.db_close()