You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
54 lines
2.1 KiB
54 lines
2.1 KiB
#coding:utf8
|
|
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
|
|
# from langchain.text_splitter import CharacterTextSplitter
|
|
from langchain.vectorstores import Chroma
|
|
# from langchain.document_loaders import TextLoader
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
from langchain.docstore.document import Document
|
|
|
|
|
|
class LangChainChroma():
    """Small helper around LangChain's Chroma vector store.

    Holds a single embedding function and uses it to add text to, and run
    similarity searches against, one Chroma collection per ``baseName``.

    By default writes go under ``DEFAULT_WRITE_ROOT`` (relative to the
    current working directory) while searches read from
    ``DEFAULT_SEARCH_ROOT`` (an absolute path).
    NOTE(review): these only point at the same store when the process runs
    from the matching working directory - confirm this is intended, or pass
    the same root for both.
    """

    # Defaults preserve the values that used to be hard-coded in the methods.
    DEFAULT_MODEL_NAME = "text_analysis/shibing624/text2vec-base-chinese"
    DEFAULT_WRITE_ROOT = "allField"
    DEFAULT_SEARCH_ROOT = "/opt/analyze/apps/chromaDB/allField"

    def __init__(self, model_name=DEFAULT_MODEL_NAME, embedding_function=None,
                 write_root=DEFAULT_WRITE_ROOT, search_root=DEFAULT_SEARCH_ROOT):
        """Create the helper.

        Args:
            model_name: Model name/path handed to
                ``SentenceTransformerEmbeddings``; ignored when
                ``embedding_function`` is supplied.
            embedding_function: Optional ready-made embedding object. When
                given, no model is loaded here, so one loaded model can be
                shared between helpers.
            write_root: Directory under which ``addChroma`` persists
                collections (one sub-directory per ``baseName``).
            search_root: Directory under which ``similarity_search`` looks
                for collections (one sub-directory per ``baseName``).
        """
        if embedding_function is None:
            embedding_function = SentenceTransformerEmbeddings(model_name=model_name)
        self.embedding_function = embedding_function
        self.write_root = write_root
        self.search_root = search_root

    def _open_collection(self, root, baseName):
        """Open the Chroma collection ``baseName`` persisted at ``root/baseName``."""
        return Chroma(
            collection_name=baseName,
            embedding_function=self.embedding_function,
            persist_directory=root + "/" + baseName,
        )

    def addChroma(self, data, baseName, logging, chunkSize=500):
        """Split ``data`` into chunks and add them to collection ``baseName``.

        Args:
            data: Text to index; it is wrapped in a single ``Document``.
            baseName: Collection name, also used as the sub-directory name
                under ``write_root``.
            logging: Logger-like object; only ``logging.info`` is called.
            chunkSize: ``chunk_size`` passed to the text splitter.

        Returns:
            The value returned by ``Chroma.add_documents`` for the new
            chunks, or an empty list when splitting produced no chunks.
        """
        db = self._open_collection(self.write_root, baseName)
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunkSize,
            chunk_overlap=0,
            separators=["\n\n", "\n", " ", "。"],
        )
        docs = text_splitter.split_documents([Document(page_content=data)])
        # Skip the insert when there is nothing to add (e.g. empty data);
        # presumably some Chroma versions reject an empty batch - an empty
        # result is returned instead.
        res = db.add_documents(documents=docs) if docs else []
        # Relies on the wrapper's private ``_collection`` handle to read the
        # total number of stored chunks for the log line below.
        db_count = db._collection.count()
        # Log message (Chinese): number of chunks produced, chunk size,
        # collection name, and total chunks now in the collection.
        logging.info('当前数据划分{}个块,块大小{}。数据库{}共有{}个块'.format(len(docs),chunkSize,baseName,db_count))
        return res

    def similarity_search(self, baseName, prompt, topn=3):
        """Return the ``topn`` closest matches for ``prompt`` in ``baseName``.

        Args:
            baseName: Collection name, also used as the sub-directory name
                under ``search_root``.
            prompt: Query text.
            topn: Passed as ``k`` to the search.

        Returns:
            The result of ``Chroma.similarity_search_with_score``.
        """
        db = self._open_collection(self.search_root, baseName)
        return db.similarity_search_with_score(prompt, k=topn)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual check: load the embedding model and open the 'policy'
    # collection persisted one directory up.
    LC = LangChainChroma()
    store_options = {
        "collection_name": 'policy',
        "embedding_function": LC.embedding_function,
        "persist_directory": '../policy',
    }
    db = Chroma(**store_options)

    # Example: delete a single entry by id.
    # db.delete('a5909489-5bc4-4b30-a949-9bd4bb06c477')

    # Example: create the database from a text file, one line at a time.
    # NOTE(review): createChroma is not defined on LangChainChroma in this
    # file and tqdm is not imported here - update before re-enabling.
    # with open("policy_test.txt", "r", encoding="utf8") as f:
    #     for line in tqdm(f):
    #         LC.createChroma(line, db)
    # print("over")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|