File size: 2,123 Bytes
ca877b2
 
7d1962a
ab13bd6
 
ca877b2
7d1962a
ab13bd6
ca877b2
 
ab13bd6
 
 
 
ca877b2
0cfd320
 
ca877b2
 
 
0cfd320
032e12f
ca877b2
 
 
 
 
 
 
 
 
ab13bd6
ca877b2
 
 
 
 
 
 
 
 
 
 
 
ab13bd6
 
ca877b2
 
ab13bd6
ca877b2
 
ab13bd6
ca877b2
 
ab13bd6
ca877b2
 
 
 
 
 
 
 
ab13bd6
ca877b2
 
 
 
 
 
 
ab13bd6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import random
import time
import os
import gradio as gr
from text_generation import Client
from conversation import get_default_conv_template


# Base URL of the text-generation-inference endpoint, read from the environment.
endpoint_url = os.environ.get("ENDPOINT_URL")
if not endpoint_url:
    # Client() accepts None without complaint and only fails later on the first
    # request, so refuse to start with a clear message instead.
    raise RuntimeError(
        "ENDPOINT_URL environment variable must be set to the "
        "text-generation-inference server URL"
    )
# Shared client for all chat requests; generous timeout for long generations.
client = Client(endpoint_url, timeout=120)
# End-of-sequence marker string (not passed to the active streaming call).
eos_token = "</s>"



def generate_response(history, max_new_token=512, top_p=0.9, temperature=0.8, do_sample=True):
    """Stream the model's reply to a chat history, one token string at a time.

    ``history`` is a list of ``[user_message, bot_message]`` pairs. Every pair
    is rendered into a prompt using the "vicuna" conversation template: the
    user message under the template's first role and the bot message under its
    second role. The last pair's bot message is normally ``None`` (the turn
    being answered); presumably the template renders that as an open assistant
    slot -- confirm in ``conversation.py``.

    Args:
        history: chat turns as ``[user_message, bot_message]`` pairs.
        max_new_token: forwarded as ``max_new_tokens`` to the endpoint.
        top_p: nucleus-sampling threshold forwarded to the endpoint.
        temperature: sampling temperature forwarded to the endpoint.
        do_sample: whether the endpoint should sample instead of decode greedily.

    Yields:
        The text of each non-special token streamed back by ``client``.
    """
    template = get_default_conv_template("vicuna").copy()
    human_role, gpt_role = template.roles[0], template.roles[1]
    for user_text, bot_text in history:
        template.append_message(human_role, user_text)
        template.append_message(gpt_role, bot_text)
    prompt = template.get_prompt()

    stream = client.generate_stream(
        prompt,
        max_new_tokens=max_new_token,
        top_p=top_p,
        temperature=temperature,
        do_sample=do_sample,
    )
    for chunk in stream:
        token = chunk.token
        # Special tokens are control markers, not displayable text.
        if not token.special:
            yield token.text
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Append the submitted message as a new turn and clear the textbox.

        The new turn's bot slot is ``None``; ``bot`` fills it in afterwards.
        Returns ``("", new_history)`` for the ``[msg, chatbot]`` outputs.
        """
        # Treat a missing history (e.g. right after clearing) as empty.
        return "", (history or []) + [[user_message, None]]

    def bot(history):
        """Stream the model's reply into the last turn of ``history``.

        ``history`` is the chatbot's list of ``[user_message, bot_message]``
        pairs. ``generate_response`` yields bare token strings, but the
        Chatbot output needs the whole history, so each token is appended to
        the last turn's reply and the full history is yielded after it.
        """
        if not history:
            return
        # generate_response is a lazy generator; give it a snapshot so the
        # partial reply written below is never mixed into the prompt.
        tokens = generate_response([list(turn) for turn in history])
        history[-1][1] = ""
        for token in tokens:
            history[-1][1] += token
            yield history

    # Show the user's message immediately (unqueued), then stream the reply.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

# The queue is required for generator (streaming) event handlers.
demo.queue()
demo.launch()