# Streaming chat demo: Gradio UI backed by a text-generation-inference endpoint
# serving a Vicuna-style model.
import os
import random
import time

import gradio as gr
from text_generation import Client
from conversation import get_default_conv_template

# URL of the text-generation-inference server; must be set in the environment.
endpoint_url = os.environ.get("ENDPOINT_URL")
# Generous timeout: streaming a full reply can take well over the default.
client = Client(endpoint_url, timeout=120)
eos_token = "</s>"
def generate_response(history, max_new_token=512, top_p=0.9, temperature=0.8, do_sample=True):
    """Stream the assistant's reply for the latest user turn in *history*.

    Args:
        history: Gradio Chatbot history, a list of ``[user_message, bot_message]``
            pairs; the last pair has ``bot_message`` set to ``None`` (placeholder
            appended by ``user()``).
        max_new_token: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        temperature: Sampling temperature.
        do_sample: Whether to sample (vs. greedy decoding).

    Yields:
        The full history with the last bot message extended by each newly
        streamed token — the shape the Chatbot output component expects.
    """
    conv = get_default_conv_template("vicuna").copy()
    # Map chat roles onto the template's (USER, ASSISTANT) role names.
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    for user_msg, bot_msg in history:
        conv.append_message(roles["human"], user_msg)
        # For the final turn bot_msg is None, which the template leaves open
        # as the generation cue.
        conv.append_message(roles["gpt"], bot_msg)
    prompt = conv.get_prompt()

    # Accumulate streamed tokens into the pending bot slot and re-yield the
    # whole history, since this generator feeds the Chatbot component directly.
    history[-1][1] = ""
    for response in client.generate_stream(
        prompt,
        max_new_tokens=max_new_token,
        top_p=top_p,
        temperature=temperature,
        do_sample=do_sample,
    ):
        # Skip special tokens (e.g. EOS) so they never appear in the chat text.
        if not response.token.special:
            history[-1][1] += response.token.text
            yield history
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Append the submitted message as a new (user, pending) turn and clear the textbox."""
        return "", history + [[user_message, None]]

    def bot(history):
        """Demo-only responder that streams a canned reply character by character.

        Not wired into the UI below (``generate_response`` is the live handler);
        kept as a local smoke-test stand-in that needs no model endpoint.
        """
        bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
        history[-1][1] = ""
        for character in bot_message:
            history[-1][1] += character
            time.sleep(0.05)
            yield history

    # queue=False on the first step so the textbox clears immediately;
    # the model call then streams its tokens into the chatbot.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        generate_response, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
demo.launch()