# NOTE(review): lines removed here were Gitea web-page residue captured when
# this file was scraped ("You can not select more than 25 topics...",
# "66 lines", "2.4 KiB") — they were never part of the module source.
# from chromadb.config import Settings
import os
import uuid

import chromadb
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
class LangChainChroma:
    """Persistent Chroma vector store wrapped with a SentenceTransformer
    embedder and a LangChain text splitter.

    The collection lives on disk under ``db_root/collection_name`` and is
    configured for cosine similarity, matching the normalized embeddings
    produced by :meth:`embedding_fn`.
    """

    # Defaults preserve the previously hard-coded values so existing callers
    # get exactly the same behavior.
    DEFAULT_DB_ROOT = "../chromaDB/allField/"
    DEFAULT_MODEL_PATH = "text_analysis/shibing624/text2vec-base-chinese"

    def __init__(self, collection_name, db_root=DEFAULT_DB_ROOT,
                 model_path=DEFAULT_MODEL_PATH):
        """Open (or create) the persistent collection and load the embedder.

        Args:
            collection_name: Name of the Chroma collection; also used as the
                on-disk directory name under ``db_root``.
            db_root: Root directory for the persistent Chroma databases
                (default keeps the original hard-coded location).
            model_path: Local path or hub id of the SentenceTransformer model
                (default keeps the original Chinese text2vec model).
        """
        # os.path.join is robust to db_root with or without a trailing slash.
        self.chroma_client = chromadb.PersistentClient(
            path=os.path.join(db_root, collection_name)
        )
        # Cosine distance matches the unit-normalized vectors we store.
        self.collection = self.chroma_client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"},
        )
        self.bge_model = SentenceTransformer(model_path)
        # "。" (Chinese full stop) is included so Chinese prose still splits
        # at sentence boundaries once the whitespace separators are exhausted.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=0,
            separators=["\n\n", "\n", " ", "。"],
        )

    def db_close(self):
        """Release cached Chroma client state.

        NOTE(review): ``clear_system_cache`` clears chromadb's shared client
        cache rather than closing one connection — confirm this is the
        intended shutdown behavior.
        """
        self.chroma_client.clear_system_cache()

    def embedding_fn(self, paragraphs):
        """Embed each paragraph of text into a vector.

        Args:
            paragraphs: Iterable of text strings.

        Returns:
            One ``list[float]`` per paragraph, L2-normalized so that cosine
            similarity reduces to a dot product.
        """
        return [
            self.bge_model.encode(text, normalize_embeddings=True).tolist()
            for text in paragraphs
        ]

    def add_documents(self, documents):
        """Embed ``documents`` and add them, with fresh ids, to the collection.

        Args:
            documents: List of text chunks to store.

        Returns:
            The list of generated ids (one per document); an empty list when
            ``documents`` is empty.
        """
        if not documents:
            # Nothing to embed; also avoids calling collection.add with
            # empty lists.
            return []
        # uuid4 instead of uuid1: purely random ids with the same
        # "id-<uuid>" shape, without leaking MAC address/timestamp.
        ids = [f"id-{uuid.uuid4()}" for _ in documents]
        self.collection.add(
            embeddings=self.embedding_fn(documents),  # one vector per chunk
            documents=documents,
            ids=ids,
        )
        return ids

    def search(self, queryQ, top_n):
        """Return the ``top_n`` documents nearest to the query text.

        Args:
            queryQ: Query string.
            top_n: Number of results to retrieve.

        Returns:
            The raw Chroma query result dict (keys such as ``'documents'``
            and ``'ids'``, each holding one list per query).
        """
        return self.collection.query(
            query_embeddings=self.embedding_fn([queryQ]),
            n_results=top_n,
        )
# Example usage (kept for reference):
# vector_db=LangChainChroma("demo")
# with open("policy_test2.txt", "r", encoding="utf8") as f:
#     for line in tqdm(f):
#         # documents = [Document(page_content=line)]
#         docs = text_splitter.split_text(line)
#         a=vector_db.add_documents(docs)
#         print(a)
# print("over")

# user_query="鲍炳章同志?"
# results=vector_db.search(user_query,3)
# for para in results['documents'][0]:
#     print(para+'\n')