integrate model in app.py
app.py CHANGED
@@ -1,5 +1,12 @@
 import gradio as gr
-
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForSeq2SeqLM
+from langchain import PromptTemplate
+from langchain.llms import HuggingFacePipeline
+from langchain.chains.question_answering import load_qa_chain
+from langchain.memory import ConversationSummaryBufferMemory
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.vectorstores import Chroma
 
 def add_text(history, text):
     history = history + [[text, None]]
@@ -7,11 +14,91 @@ def add_text(history, text):
 
 def process_input(history):
     inp = history[-1][0]
-    response = "I have received your input, which is: \n" + inp
+    # response = "I have received your input, which is: \n" + inp
+    response = chat_bot.chat(inp)
     history[-1][1] = response
     return history
 
+def build_qa_chain():
+    torch.cuda.empty_cache()
+    # Define our prompt content.
+    # langchain will load our similar documents as {context}.
+    template = """You are a chatbot having a conversation with a human. You are asked to answer career questions, and you are helping the human apply for jobs.
+    Given the following extracted parts of a long document and a question, answer the user's question. If you don't know, say that you do not know.
+
+    {context}
+
+    {chat_history}
+
+    {human_input}
+
+    Response:
+    """
+    prompt = PromptTemplate(input_variables=['context', 'human_input', 'chat_history'], template=template)
+
+    # Increase max_new_tokens for a longer response.
+    # Other settings might give better results! Play around.
+    model_name = "databricks/dolly-v2-3b"  # dolly-v2-3b, dolly-v2-7b or dolly-v2-12b; the smaller models give faster inference.
+    instruct_pipeline = pipeline(model=model_name, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto",
+                                 return_full_text=True, max_new_tokens=256, top_p=0.95, top_k=50)
+    hf_pipe = HuggingFacePipeline(pipeline=instruct_pipeline)
+
+    # Add a summarizer to our conversation memory.
+    # Make sure we don't summarize the discussion too much, to avoid losing too much of the content.
+
+    # Model we'll use to summarize our chat history.
+    # We could use any of these models: https://huggingface.co/models?filter=summarization. facebook/bart-large-cnn gives great results; we'll use t5-small for memory.
+    summarize_model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
+    summarize_tokenizer = AutoTokenizer.from_pretrained("t5-small", padding_side="left", model_max_length=512)
+    pipe_summary = pipeline("summarization", model=summarize_model, tokenizer=summarize_tokenizer)  # max_new_tokens=500, min_new_tokens=300
+    # The langchain pipeline doesn't support summarization yet; we added it as a temporary fix in the companion notebook _resources/00-init.
+    hf_summary = HuggingFacePipeline(pipeline=pipe_summary)
+    # Keep 500 tokens, then ask for a summary. Remove the chat prefixes, as our model isn't trained on a specific chat prefix and can get confused.
+    memory = ConversationSummaryBufferMemory(llm=hf_summary, memory_key="chat_history", input_key="human_input", max_token_limit=500, human_prefix="", ai_prefix="")
+
+    # Set verbose=True to see the full prompt:
+    print("loading chain, this can take some time...")
+    return load_qa_chain(llm=hf_pipe, chain_type="stuff", verbose=True, prompt=prompt, memory=memory)
+
+class ChatBot():
+    def __init__(self, db):
+        self.reset_context()
+        self.db = db
+
+    def reset_context(self):
+        self.sources = []
+        self.discussion = []
+        # Building the chain loads Dolly and can take some time depending on the model size and your GPU.
+        self.qa_chain = build_qa_chain()
+
+    def get_similar_docs(self, question, similar_doc_count):
+        return self.db.similarity_search(question, k=similar_doc_count)
+
+    def chat(self, question):
+        # Keep the last 3 discussion turns to search for similar content.
+        self.discussion.append(question)
+        similar_docs = self.get_similar_docs(" \n".join(self.discussion[-3:]), similar_doc_count=2)
+        # Drop similar docs that were already returned for the last questions (they are already in the history).
+        similar_docs = [doc for doc in similar_docs if doc.metadata['source'] not in self.sources[-3:]]
+
+        result = self.qa_chain({"input_documents": similar_docs, "human_input": question})
+        # Clean up the answer for better display:
+        answer = result['output_text'].strip().capitalize()
+        result_html = f"<p><blockquote style=\"font-size:18px\">{answer}</blockquote></p>"
+        result_html += "<p><hr/></p>"
+        for d in result["input_documents"]:
+            source_id = d.metadata["source"]
+            self.sources.append(source_id)
+            result_html += f"<p>(Source: <a href=\"https://workplace.stackexchange.com/a/{source_id}\">{source_id}</a>)</p>"
+        return result_html
+
 with gr.Blocks() as demo:
+    global chat_bot
+    workplace_vector_db_path = "workplace_db"
+
+    hf_embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+    chroma_db = Chroma(collection_name="workplace_docs", embedding_function=hf_embed, persist_directory=workplace_vector_db_path)
+    chat_bot = ChatBot(chroma_db)
     gr.Markdown('''
     ## **CareerPal**
     here to ease your anxiety about your future
@@ -31,6 +118,3 @@ with gr.Blocks() as demo:
     clear_btn.click(lambda: None, inputs=None, outputs=output_box, queue=False)
 
 demo.launch()  # server_port=7860, show_api=False, share=False, inline=True  # share=True, inline=True
-
-# set FLASK_APP=app.py
-# flask run -h localhost -p 7860
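The unchanged UI wiring between the gr.Markdown header and the clear_btn.click line (new lines 105-117) is collapsed in the diff. For orientation, here is a minimal sketch of how that section presumably looks, assuming the usual gr.Chatbot pattern: output_box and clear_btn are taken from the visible clear_btn.click(...) call, while everything else (the component choices, the input_box name, the placeholder text, and the assumption that add_text also returns an empty string to clear the textbox) is hypothetical and may differ from the real file.

with gr.Blocks() as demo:
    # ... gr.Markdown('''## **CareerPal** ...''') header as shown in the diff ...
    output_box = gr.Chatbot()  # renders the [user, bot] pairs built by add_text / process_input
    input_box = gr.Textbox(show_label=False, placeholder="Ask a career question and press Enter")  # hypothetical name
    clear_btn = gr.Button("Clear")

    # add_text appends [text, None] to the history; process_input then fills in the bot turn
    # via chat_bot.chat(). Assumes add_text returns (history, "") so the textbox is cleared.
    input_box.submit(add_text, inputs=[output_box, input_box], outputs=[output_box, input_box]) \
             .then(process_input, inputs=output_box, outputs=output_box)

    clear_btn.click(lambda: None, inputs=None, outputs=output_box, queue=False)

demo.launch()

Splitting the handlers and chaining them with .then() is what makes the user's message appear immediately, with the slower Dolly generation happening in the second step.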
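Note that this commit only reads an existing vector store: Chroma(collection_name="workplace_docs", ..., persist_directory="workplace_db") assumes the workplace_db directory is already present in the Space, and nothing here builds it. A hypothetical one-off indexing script that would produce a compatible store could look like the sketch below; the JSONL file name and its fields are made-up placeholders, and only the collection name, persist directory, embedding model, and the meaning of metadata["source"] are dictated by the code above.

# Hypothetical one-off indexing script; app.py only opens the persisted DB.
import json

from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

docs = []
with open("workplace_answers.jsonl") as f:  # assumed export of workplace.stackexchange.com Q&A pairs
    for line in f:
        rec = json.loads(line)
        docs.append(Document(
            page_content=rec["question"] + "\n" + rec["answer"],
            # ChatBot.chat() links sources as https://workplace.stackexchange.com/a/<source>,
            # so "source" must be the Stack Exchange answer id.
            metadata={"source": rec["answer_id"]},
        ))

hf_embed = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
db = Chroma.from_documents(
    documents=docs,
    embedding=hf_embed,
    collection_name="workplace_docs",  # must match the collection app.py opens
    persist_directory="workplace_db",  # must match workplace_vector_db_path
)
db.persist()

Whatever the real ingestion pipeline is, each document's metadata["source"] has to be the Stack Exchange answer id, and the documents must stay short enough to fit into the "stuff" chain's prompt, since load_qa_chain(chain_type="stuff") pastes them verbatim into {context}.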