Spaces:
Running
Running
from langchain_community.chat_models import ChatOpenAI | |
from langchain.chains.retrieval_qa.base import RetrievalQA | |
from langchain_community.embeddings import OpenAIEmbeddings | |
from langchain.schema import HumanMessage, SystemMessage | |
import os | |
from langchain_community.document_loaders import DirectoryLoader | |
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.embeddings.huggingface import HuggingFaceEmbeddings | |
from langchain_community.embeddings import OpenAIEmbeddings | |
from langchain_community.vectorstores import Chroma | |
import gradio as gr | |
import requests | |
from langchain_core.prompts import PromptTemplate | |
from qwen_api import qwen_api | |
def load_documents(directory='../langchain-database', *, chunk_size=2048, chunk_overlap=200):
    """Load every file under *directory* and split the result into chunks.

    Files are read with ``DirectoryLoader`` using multithreading and a
    progress bar. Because ``silent_errors=True``, files that fail to load are
    skipped without raising.

    Args:
        directory: Root folder to scan for documents.
        chunk_size: Target chunk size passed to ``CharacterTextSplitter``.
        chunk_overlap: Overlap between consecutive chunks, passed to
            ``CharacterTextSplitter``.

    Returns:
        The list of split documents. The number of chunks is also printed.
    """
    loader = DirectoryLoader(directory, show_progress=True,
                             use_multithreading=True, silent_errors=True)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=chunk_size,
                                          chunk_overlap=chunk_overlap)
    split_docs = text_splitter.split_documents(documents)
    print(len(split_docs))
    return split_docs
def load_embedding_mode(*, model_name="/home/xiongwen/bge-m3", device='cuda',
                        normalize_embeddings=False):
    """Build the HuggingFace embedding model used for indexing and retrieval.

    Args:
        model_name: Local path or hub name of the embedding model.
        device: Torch device the model is loaded on.
        normalize_embeddings: Forwarded via ``encode_kwargs``. Keep this
            consistent with how any existing vector store was built;
            presumably changing it makes new query vectors incomparable
            with stored ones.

    Returns:
        A configured ``HuggingFaceEmbeddings`` instance.
    """
    encode_kwargs = {"normalize_embeddings": normalize_embeddings}
    model_kwargs = {"device": device}
    return HuggingFaceEmbeddings(model_name=model_name,
                                 model_kwargs=model_kwargs,
                                 encode_kwargs=encode_kwargs)
def store_chroma(docs, embedding, persist_directory='./VecterStore2'):
    """Embed *docs* into a new Chroma vector store.

    Args:
        docs: Documents to index (for example the output of
            ``load_documents``).
        embedding: Embedding function used to vectorize the documents.
        persist_directory: Directory where Chroma stores the collection.
            Pass ``None`` to build an in-memory store instead.

    Returns:
        The populated ``Chroma`` instance.
    """
    # Forward persist_directory so the collection is written to disk. Without
    # it, Chroma keeps everything in memory and the argument has no effect.
    # NOTE(review): with chromadb >= 0.4 persistence is automatic, so no
    # explicit db.persist() call is needed; confirm the installed version.
    db = Chroma.from_documents(docs, embedding,
                               persist_directory=persist_directory)
    return db
def chat(question, history):
    """Answer one user message for the Gradio ChatInterface.

    The opening message of a conversation (empty *history*) is answered by
    the retrieval QA chain ``qa``, returning only its ``'result'`` field.
    Later messages are handed to ``qwen_api`` together with the Gradio
    history.

    Note: ``qa`` is a module-level global assigned in the ``__main__``
    block, so this callback only works when the file is run as a script.
    """
    if len(history) > 0:
        # Follow-up turn: pass the conversation so far to qwen_api.
        return qwen_api(question, gradio_history=history)
    # First turn: answer from the retrieved background documents.
    return qa.invoke(question)['result']
if __name__ == '__main__':
    # Open the pre-built vector store. The embedding model must be the same
    # one that produced the stored vectors.
    embedding = load_embedding_mode()
    db = Chroma(
        persist_directory='/home/xiongwen/llama2-a40-ner/langchain-qwen/VecterStore2_512_txt/VecterStore2_512_txt',
        embedding_function=embedding,
    )

    # Point ChatOpenAI at a local OpenAI-compatible endpoint. The API key is a
    # placeholder; presumably the local server does not validate it.
    os.environ["OPENAI_API_BASE"] = 'http://localhost:8000/v1'
    os.environ["OPENAI_API_KEY"] = 'none'
    llm = ChatOpenAI(
        model="/home/xiongwen/Qwen1.5-110B-Chat",
        temperature=0.8,
    )

    prompt_template = """
{context}
The above content is a form of biological background knowledge. Please answer the questions according to the above content. Please be sure to answer the questions according to the background knowledge and attach the doi number of the information source when answering.
Question: {question}
Answer in English:"""
    PROMPT = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    chain_type_kwargs = {"prompt": PROMPT}

    retriever = db.as_retriever()

    # `qa` is deliberately assigned at module scope: chat() reads it as a
    # global. With the "stuff" chain type, retrieved documents fill the
    # {context} slot of PROMPT. Source documents are returned, but chat()
    # only uses the 'result' field.
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs=chain_type_kwargs,
        return_source_documents=True,
    )

    interface = gr.ChatInterface(
        fn=chat,
        chatbot=gr.Chatbot(height=800, bubble_full_width=False),
        theme=gr.themes.Default(spacing_size='sm', radius_size='sm'),
        examples=['which gene should be knocked in the process of producing ethanol in Saccharomyces cerevisiae?'],
    )
    interface.launch(inbrowser=True)