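"""Workplace Q&A chatbot.

A Gradio chat app: Dolly v2 (via transformers + LangChain) answers career
questions over a Chroma vector store of Workplace Stack Exchange answers,
with chat history summarized by t5-small.
"""
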
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from langchain import PromptTemplate
from langchain.llms import HuggingFacePipeline
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationSummaryBufferMemory
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma


def add_text(history, text):
    # Append the user's message (bot reply still pending) and clear/disable
    # the textbox until the response is ready.
    history = history + [[text, None]]
    return history, gr.update(value="", interactive=False)


def process_input(history):
    inp = history[-1][0]
    # response = "I have received your input, which is: \n" + inp
    if len(history) <= 2:
        # First user message after the welcome message: start a fresh context
        # (this rebuilds the chain, so it can take a while).
        chat_bot.reset_context()
    response = chat_bot.chat(inp)
    history[-1][1] = response
    return history
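
# Note: `history` is the Gradio Chatbot value, a list of [user_message, bot_html]
# pairs. After add_text() the last pair is [text, None]; process_input() fills in
# the None with the generated answer.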


def build_qa_chain():
    torch.cuda.empty_cache()

    # Define our prompt content.
    # langchain will inject our similar documents as {context}.
    template = """You are a chatbot having a conversation with a human. You are asked to answer career questions, and you are helping the human apply for jobs.
Given the following extracted parts of a long document and a question, answer the user's question. If you don't know, say that you do not know.

{context}

{chat_history}

{human_input}

Response:
"""
    prompt = PromptTemplate(input_variables=['context', 'human_input', 'chat_history'], template=template)
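
    # For illustration, a rendered prompt looks roughly like this (made-up values):
    #   You are a chatbot having a conversation with a human. ...
    #   <text of the retrieved documents>                  <- {context}
    #   <recent turns + t5-small summary of older ones>    <- {chat_history}
    #   How should I negotiate a raise?                    <- {human_input}
    #   Response: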

    # Increase max_new_tokens for a longer response.
    # Other settings might give better results! Play around.
    model_name = "databricks/dolly-v2-3b"  # dolly-v2-3b is the smallest and fastest; dolly-v2-7b or dolly-v2-12b trade speed for answer quality.
    instruct_pipeline = pipeline(model=model_name, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto",
                                 return_full_text=True, max_new_tokens=256, top_p=0.95, top_k=50)
    hf_pipe = HuggingFacePipeline(pipeline=instruct_pipeline)
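
    # HuggingFacePipeline wraps the transformers pipeline as a LangChain LLM, so
    # the chain can invoke it like any other LLM, e.g. hf_pipe("some prompt")
    # returns the generated text as a string.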

    # Add a summarizer to our conversation memory.
    # Let's make sure we don't summarize the discussion too much to avoid losing too much of the content.
    # Model we'll use to summarize our chat history:
    # we could use any model from https://huggingface.co/models?filter=summarization.
    # facebook/bart-large-cnn gives great results; we'll use t5-small for the memory here.
    summarize_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
    summarize_tokenizer = AutoTokenizer.from_pretrained("t5-small", padding_side="left", model_max_length=512)
    pipe_summary = pipeline("summarization", model=summarize_model, tokenizer=summarize_tokenizer)  # , max_new_tokens=500, min_new_tokens=300
    # langchain's HuggingFacePipeline doesn't support summarization yet; we added it as a temp fix in the companion notebook _resources/00-init.
    hf_summary = HuggingFacePipeline(pipeline=pipe_summary)
    # Keep up to 500 tokens, then ask for a summary. Remove the prefixes as our model isn't trained on specific chat prefixes and can get confused.
    memory = ConversationSummaryBufferMemory(llm=hf_summary, memory_key="chat_history", input_key="human_input", max_token_limit=500, human_prefix="", ai_prefix="")
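
    # ConversationSummaryBufferMemory keeps the most recent exchanges verbatim up
    # to max_token_limit tokens; older exchanges are condensed by the summarization
    # LLM (t5-small here) into a running summary, and both are injected as {chat_history}.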

    # Set verbose=True to see the full prompt:
    print("loading chain, this can take some time...")
    return load_qa_chain(llm=hf_pipe, chain_type="stuff", verbose=True, prompt=prompt, memory=memory)
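
# Note: chain_type="stuff" concatenates all input_documents into {context} of a
# single prompt. With k=2 documents plus up to 500 tokens of chat history, very
# long source documents can overflow the model's context window.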


class ChatBot:
    def __init__(self, db):
        self.reset_context()
        self.db = db

    def reset_context(self):
        self.sources = []
        self.discussion = []
        # Building the chain loads Dolly and can take some time depending on the model size and your GPU.
        self.qa_chain = build_qa_chain()

    def get_similar_docs(self, question, similar_doc_count):
        return self.db.similarity_search(question, k=similar_doc_count)
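
    # For illustration (made-up values): similarity_search("negotiating a raise", k=2)
    # returns a list of langchain Documents, each with .page_content and a .metadata
    # dict whose 'source' key holds the Stack Exchange answer id used below.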

    def chat(self, question):
        # Use the last 3 questions of the discussion to search for similar content.
        self.discussion.append(question)
        similar_docs = self.get_similar_docs(" \n".join(self.discussion[-3:]), similar_doc_count=2)
        # Skip documents whose source was already cited in the last answers (it's already in the history).
        similar_docs = [doc for doc in similar_docs if doc.metadata['source'] not in self.sources[-3:]]
        result = self.qa_chain({"input_documents": similar_docs, "human_input": question})
        # Clean up the answer for better display: uppercase only the first letter
        # (str.capitalize() would lowercase the rest of the text).
        answer = result['output_text'].strip()
        if answer:
            answer = answer[0].upper() + answer[1:]
        result_html = f"<p><blockquote style=\"font-size:18px\">{answer}</blockquote></p>"
        result_html += "<p><hr/></p>"
        for d in result["input_documents"]:
            source_id = d.metadata["source"]
            self.sources.append(source_id)
            result_html += f"<p>(Source: <a href=\"https://workplace.stackexchange.com/a/{source_id}\">{source_id}</a>)</p>"
        return result_html
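
# Illustrative use outside the UI (hypothetical, assuming a populated Chroma store `db`):
#   bot = ChatBot(db)
#   html_answer = bot.chat("How should I follow up after an interview?")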


with gr.Blocks() as demo:
    workplace_vector_db_path = "workplace_db"
    hf_embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    chroma_db = Chroma(collection_name="workplace_docs", embedding_function=hf_embed, persist_directory=workplace_vector_db_path)
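    # The "workplace_db" directory is assumed to have been populated ahead of time
    # (e.g. by a companion indexing notebook); if it's empty, similarity_search
    # returns no documents and the bot answers without supporting context.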
    chat_bot = ChatBot(chroma_db)

    with gr.Row():
        output_box = gr.Chatbot([[None, "Welcome! What can I help you with today?"]], show_label=False).style(height=450)
    with gr.Row():  # TODO: Box or Group instead of Row?
        with gr.Column(scale=7):
            input_box = gr.Textbox(show_label=False, placeholder="Ask something here and press enter...").style(container=False)
        with gr.Column(scale=1):
            clear_btn = gr.Button(value="Clear")
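
    # Wiring: on submit, add_text appends the message and disables the textbox;
    # process_input generates the answer; the final .then() re-enables the textbox.
    # queue=False lets these quick UI updates bypass the request queue.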
    txt_msg = input_box.submit(add_text, inputs=[output_box, input_box], outputs=[output_box, input_box],
                               queue=False).then(process_input, output_box, output_box)
    txt_msg.then(lambda: gr.update(interactive=True), inputs=None, outputs=input_box, queue=False)
    clear_btn.click(lambda: None, inputs=None, outputs=output_box, queue=False)

demo.launch()  # e.g. demo.launch(server_port=7860, show_api=False, share=False)