# from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
import chromadb
from tqdm import tqdm
from langchain.text_splitter import RecursiveCharacterTextSplitter
import uuid
class LangChainChroma:
    """Thin wrapper around a persistent Chroma collection with text2vec embeddings."""

    def __init__(self, collection_name):
        # dirPath = "D:/ellie/project/2023/analyst_assistant/chromaDB/allField/"
        dirPath = "./allField/"
        self.chroma_client = chromadb.PersistentClient(path=dirPath + collection_name)
        # chroma_client = chromadb.Client(Settings(allow_reset=True, persist_directory="../allField/demo/"))
        self.collection = self.chroma_client.get_or_create_collection(
            name=collection_name, metadata={"hnsw:space": "cosine"}
        )
        model = SentenceTransformer('text_analysis/shibing624/text2vec-base-chinese')
        # model = SentenceTransformer('shibing624/text2vec-base-chinese')
        self.bge_model = model
        # The splitter is exposed for callers that chunk raw text before add_documents().
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=200, chunk_overlap=0, separators=["\n\n", "\n", " ", "。"]
        )
        # chroma_client.reset()
    def db_close(self):
        # Release client state held in chromadb's shared system cache.
        self.chroma_client.clear_system_cache()
    def embedding_fn(self, paragraphs):
        """Embed a list of text paragraphs into normalized vectors."""
        doc_vecs = [
            self.bge_model.encode(doc, normalize_embeddings=True).tolist()
            for doc in paragraphs
        ]
        return doc_vecs
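    # Design note: SentenceTransformer.encode also accepts a list of texts, so
    # the per-document loop above could be collapsed into one batched call,
    # which is usually faster (kept per-document here to match the original):
    #     self.bge_model.encode(paragraphs, normalize_embeddings=True).tolist()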
    def add_documents(self, documents, dataid):
        data_id = []
        ids = []
        for _ in range(len(documents)):
            data_id.append({"dataId": dataid})
            ids.append(str(uuid.uuid1()))
        # embeddings = get_embeddings(documents)
        # Add the documents and their vectors to the collection
        # ids = ["{}".format(uuid.uuid1()) for i in range(len(documents))]
        self.collection.add(
            embeddings=self.embedding_fn(documents),  # one vector per document
            documents=documents,
            metadatas=data_id,
            ids=ids
        )
        c = self.collection.count()
        return ids, c
    def search(self, queryQ, top_n):
        results = self.collection.query(
            # query_texts=[query],
            query_embeddings=self.embedding_fn([queryQ]),
            n_results=top_n
        )
        return results
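    # collection.query returns a dict keyed by "ids", "documents", "metadatas"
    # and "distances"; each value holds one list per query embedding, so the
    # hits for the single query above live at index 0:
    #     results = vector_db.search("...", 3)
    #     for doc in results["documents"][0]: ...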
if __name__ == "__main__":
    pass
    # vector_db = LangChainChroma("demo")
    # with open("policy_test2.txt", "r", encoding="utf8") as f:
    #     for line in tqdm(f):
    #         # documents = [Document(page_content=line)]
    #         docs = vector_db.text_splitter.split_text(line)
    #         a = vector_db.add_documents(docs, dataid="policy_test2")  # dataid value is arbitrary here
    #         print(a)
    # print("over")

    # user_query = "鲍炳章同志?"
    # results = vector_db.search(user_query, 3)
    # for para in results['documents'][0]:
    #     print(para + '\n')
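    # Minimal self-contained sketch (assumes the embedding model path used in
    # __init__ is available locally; the texts and dataid below are placeholders):
    # db = LangChainChroma("demo")
    # ids, count = db.add_documents(["这是第一段示例文本。", "这是第二段示例文本。"], dataid="demo-1")
    # print(ids, count)
    # hits = db.search("第一段", 1)
    # print(hits["documents"][0])
    # db.db_close()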