demo_chatbot / app.py
khanhdhq's picture
Duplicate from mandar100/chatbot_bloom3B
662a39e
raw
history blame
3.16 kB
#!/usr/bin/env python
# coding: utf-8
# In[ ]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr
import re
def cleaning_history_tuple(history):
    """Flatten the chat history into a single newline-separated transcript.

    Args:
        history: Iterable of conversation turns. Each turn is a tuple or a
            list of message strings (``predict`` stores
            ``(user_message, assistant_message)`` pairs).

    Returns:
        One string holding every message in order. Embedded newline
        characters and ``<p>`` / ``</p>`` tags are stripped from each
        message, and each message is followed by a single newline. An
        empty history yields an empty string.
    """
    cleaned = []
    for turn in history:
        # Iterating each turn (instead of ``sum(history, ())``) accepts
        # lists as well as tuples and avoids quadratic tuple concatenation.
        for message in turn:
            # Removal order matters: newlines go first so that a tag split
            # across a line break is still recognised and removed.
            for fragment in ("\n", "<p>", "</p>"):
                message = message.replace(fragment, "")
            cleaned.append(message + "\n")
    return "".join(cleaned)
def ai_output(string1, string2):
    """Extract the newly generated reply from the decoded model output.

    Args:
        string1: The prompt text that was fed to the model. Only its length
            is used, to skip past the echoed prompt.
        string2: The decoded model output, expected to begin with the prompt.

    Returns:
        The text of ``string2`` after the first ``len(string1)`` characters,
        truncated just before the first ``"A:"`` marker. If there is no
        ``"A:"``, it is truncated before the first ``"User"`` instead. If
        neither marker occurs, the whole remainder is returned.
    """
    continuation = string2[len(string1):]
    # "A:" takes priority over "User" even when "User" appears earlier;
    # this mirrors the original lookup order.
    for marker in ("A:", "User"):
        cut = continuation.find(marker)
        if cut != -1:
            return continuation[:cut]
    return continuation
# Model and tokenizer are created once, when the module is first executed.
# A single checkpoint name is shared so the two cannot drift apart.
_CHECKPOINT = "bigscience/bloom-3b"
model4 = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)
tokenizer4 = AutoTokenizer.from_pretrained(_CHECKPOINT)
def _coerce_flag(value):
    """Interpret a UI-supplied value as a boolean.

    Strings are compared case-insensitively against common "off" spellings,
    because ``bool("False")`` is ``True``. Any other non-empty string counts
    as true. Non-string values fall back to ordinary truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() not in {"", "false", "0", "no", "off", "none"}
    return bool(value)


def predict(input, initial_prompt, temperature=0.7, top_p=1, top_k=5,
            max_tokens=64, no_repeat_ngram_size=1, num_beams=6,
            do_sample=True, history=None):
    """Generate the assistant's next reply and append the turn to the history.

    The numeric settings may arrive as strings (the UI uses text boxes), so
    each one is cast to the type ``generate`` needs before use.

    Args:
        input: The user's new message. (The name shadows the builtin but is
            kept because it is part of the public signature.)
        initial_prompt: Text placed in front of the conversation transcript.
        temperature: Sampling temperature, cast with ``float``.
        top_p: Nucleus-sampling threshold, cast with ``float``.
        top_k: Top-k cutoff, cast with ``int``.
        max_tokens: Maximum number of newly generated tokens, cast with ``int``.
        no_repeat_ngram_size: Passed to ``generate`` after an ``int`` cast.
        num_beams: Number of beams, cast with ``int``.
        do_sample: Whether to sample. Accepts a bool or a string such as
            ``"True"`` / ``"False"``.
        history: List of ``(user_message, assistant_message)`` pairs from
            earlier turns, or ``None`` to start a new conversation.

    Returns:
        A ``(history, history)`` tuple: the updated history list, returned
        twice (once for the chatbot display and once for the session state).
    """
    # A mutable default would be shared between unrelated calls, and an
    # explicit None would crash while flattening the history.
    if history is None:
        history = []

    transcript = cleaning_history_tuple(history)
    transcript = transcript + "\n" + "User: " + input + "\n" + "Assistant: "
    full_prompt = initial_prompt + " " + transcript

    input_ids = tokenizer4.encode(str(full_prompt), return_tensors="pt")
    generated = model4.generate(
        input_ids,
        min_length=10,
        max_new_tokens=int(max_tokens),
        top_k=int(top_k),
        top_p=float(top_p),
        temperature=float(temperature),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        num_beams=int(num_beams),
        # bool("False") is True, so a text-box value needs real parsing.
        do_sample=_coerce_flag(do_sample),
    )
    decoded = tokenizer4.decode(generated[0])
    print("Response after decoding tokenizer: ", decoded)
    print("\n\n")

    # The decoded text echoes the prompt; keep only the new continuation.
    reply = ai_output(full_prompt, decoded)
    history.append(("User: " + input, "Assistant: " + reply))
    return history, history
# NOTE(review): `prompt` is not bound anywhere in the visible file, so using
# it directly raises NameError at startup. Reuse it if some other part of the
# module defines it; otherwise start with an empty initial prompt, which the
# user can still fill in through the text box.
DEFAULT_INITIAL_PROMPT = globals().get("prompt", "")

# (label, default value) for every text box, in the positional order of
# predict's parameters.
_TEXTBOX_FIELDS = [
    ("input", ""),
    ("initial_prompt", DEFAULT_INITIAL_PROMPT),
    ("temperature", 0.7),
    ("top_p", 1),
    ("top_k", 5),
    ("max_tokens", 64),
    ("no_repeat_ngram_size", 1),
    ("num_beams", 6),
    ("do_sample", "True"),
]

demo = gr.Interface(
    fn=predict,
    # The title names the checkpoint that is actually loaded (bloom-3b).
    title="BLOOM-3b",
    inputs=[
        gr.Textbox(label=label, lines=1, value=default)
        for label, default in _TEXTBOX_FIELDS
    ] + ["state"],
    outputs=["chatbot", "state"],
)
demo.launch()