D2Cell-chatbot / app.py
kenghuoxiong's picture
Update app.py
98b2a1b verified
import gradio as gr
from huggingface_hub import InferenceClient
from langchain_community.chat_models import ChatOpenAI
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.schema import HumanMessage, SystemMessage
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
import requests
from langchain_core.prompts import PromptTemplate
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
import gradio as gr
from openai import OpenAI
import os
TOKEN = os.getenv("HF_TOKEN")
def load_embedding_mode():
# embedding_model_dict = {"m3e-base": "/home/xiongwen/m3e-base"}
encode_kwargs = {"normalize_embeddings": False}
model_kwargs = {"device": 'cpu'}
return HuggingFaceEmbeddings(model_name="BAAI/bge-m3",
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs)
client = OpenAI(
base_url="https://api-inference.huggingface.co/v1/",
api_key=TOKEN,
)
def qwen_api(user_message, top_p=0.9,temperature=0.7, system_message='', max_tokens=1024, gradio_history=[]):
history = []
if gradio_history:
for message in history:
if message:
history.append({"role": "user", "content": message[0]})
history.append({"role": "assistant", "content": message[1]})
if system_message!='':
history.append({'role': 'system', 'content': system_message})
history.append({"role": "user", "content": user_message})
response = ""
for message in client.chat.completions.create(
model="meta-llama/Meta-Llama-3-8B-Instruct",
# model="Qwen/Qwen1.5-4B-Chat",
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
messages=history,
):
token = message.choices[0].delta.content
response += token
return response
os.environ["OPENAI_API_BASE"] = "https://api-inference.huggingface.co/v1/"
os.environ["OPENAI_API_KEY"] = TOKEN
embedding = load_embedding_mode()
db = Chroma(persist_directory='./VecterStore2_512_txt/VecterStore2_512_txt', embedding_function=embedding)
prompt_template = """
{context}
The above content is a form of biological background knowledge. Please answer the questions according to the above content.
Question: {question}
Please be sure to answer the questions according to the background knowledge and attach the doi number of the information source when answering.
Answer in English:"""
PROMPT = PromptTemplate(
template=prompt_template, input_variables=["context", "question"]
)
chain_type_kwargs = {"prompt": PROMPT}
retriever = db.as_retriever()
def langchain_chat(message, temperature, top_p, max_tokens):
llm = ChatOpenAI(
model="meta-llama/Meta-Llama-3-8B-Instruct",
# model="Qwen/Qwen1.5-4B-Chat",
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens)
qa = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=retriever,
chain_type_kwargs=chain_type_kwargs,
return_source_documents=True
)
response = qa.invoke(message)['result']
return response
def chat(
message,
history: list[tuple[str, str]],
system_message,
max_tokens,
temperature,
top_p,
):
if len(history) == 0:
response = langchain_chat(message, temperature, top_p, max_tokens)
else:
response = qwen_api(message, gradio_history=history, max_tokens=max_tokens, top_p=top_p, temperature=temperature)
print(response)
yield response
return response
def respond(
message,
history: list[tuple[str, str]],
system_message,
max_tokens,
temperature,
top_p,
):
messages = [{"role": "system", "content": system_message}]
for val in history:
if val[0]:
messages.append({"role": "user", "content": val[0]})
if val[1]:
messages.append({"role": "assistant", "content": val[1]})
messages.append({"role": "user", "content": message})
response = ""
for message in client.chat.completions.create(
model="meta-llama/Meta-Llama-3-8B-Instruct",
# model="Qwen/Qwen1.5-4B-Chat",
max_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
messages=messages,
):
token = message.choices[0].delta.content
response += token
yield response
chatbot = gr.Chatbot(height=600)
demo = gr.ChatInterface(
fn=chat,
fill_height=True,
chatbot=chatbot,
additional_inputs=[
gr.Textbox(label="System message"),
gr.Slider(minimum=1, maximum=1024, value=512, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)",
),
],
)
if __name__ == "__main__":
demo.launch()