File size: 3,159 Bytes
73e8b86
 
 
 
67812d2
73e8b86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from random import sample

from fastchat.serve import gradio_web_server
from fastchat.serve.gradio_web_server import bot_response
import gradio as gr

# TODO(#1): Add more models.
# Candidate pool from which two distinct models are sampled per session.
SUPPORTED_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo"]


def user(state_a, state_b, user_message):
  """Records the user's message into both conversation states.

  On the first call (both states are None) the two states are lazily
  initialized with a random pair of distinct supported models.

  Returns a list matching the prompt.submit() outputs: both states, both
  chatbot renderings, both model names, and an empty string that clears
  the prompt textbox.
  """
  if state_a is None and state_b is None:
    model_a, model_b = sample(SUPPORTED_MODELS, 2)
    state_a = gradio_web_server.State(model_a)
    state_b = gradio_web_server.State(model_b)

  for state in (state_a, state_b):
    conv = state.conv
    # Queue the user turn and a placeholder assistant turn to be filled
    # in by bot().
    conv.append_message(conv.roles[0], user_message)
    conv.append_message(conv.roles[1], None)
    state.skip_next = False

  return [
      state_a, state_b,
      state_a.to_gradio_chatbot(), state_b.to_gradio_chatbot(),
      state_a.model_name, state_b.model_name,
      ""  # Clears the prompt textbox.
  ]


def bot(state_a, state_b, request: gr.Request):
  """Streams both models' responses for the pending user turn.

  Yields [state_a, state_b, chatbot_a, chatbot_b] repeatedly as tokens
  arrive from either model, until both response generators are exhausted.

  Raises:
    RuntimeError: If either state is None (user() must have run first).
  """
  if state_a is None or state_b is None:
    raise RuntimeError(f"states cannot be None, got [{state_a}, {state_b}]")

  # Track the streaming states locally. BUG FIX: the original assigned
  # into the module-level `states` list (the gr.State UI components) and
  # yielded the stale pre-stream state_a/state_b, so updates from
  # bot_response were never propagated back to the session states.
  new_states = [state_a, state_b]

  generators = []
  for state in new_states:
    try:
      # TODO(#1): Allow user to set configuration.
      # bot_response returns a generator yielding states and chatbots.
      generator = bot_response(state,
                               temperature=0.9,
                               top_p=0.9,
                               max_new_tokens=100,
                               request=request)
      generators.append(generator)

    # TODO(#1): Narrow down the exception type.
    except Exception as e:  # pylint: disable=broad-except
      print(f"Error in bot_response: {e}")
      raise e

  new_chatbots = [None, None]
  while True:
    stop = True

    for i in range(2):
      try:
        # Each yielded item is a sequence whose first two elements are
        # the updated state and the chatbot rendering.
        result = next(generators[i])
        new_states[i], new_chatbots[i] = result[0], result[1]
        stop = False
      except StopIteration:
        # This model's stream is done; keep draining the other one.
        pass

    yield new_states + new_chatbots

    if stop:
      break


# Arena UI: one prompt is sent to two anonymous models whose answers are
# shown side by side; identities are hidden behind an accordion.
with gr.Blocks() as app:
  with gr.Row():
    # NOTE(review): these two controls are not wired into user()/bot()
    # yet — presumably pending TODO(#1) configuration support.
    response_type = gr.Radio(
        ["Summarization", "Translation"],
        value="Summarization",
        label="Response type",
        info="Choose the type of response you want from the model.")
    language = gr.Dropdown(["Korean", "English"],
                           value="Korean",
                           label="Language",
                           info="Choose the target language.")

  # Side-by-side chat panes for the two anonymous models.
  chatbots = [None, None]
  with gr.Row():
    chatbots[0] = gr.Chatbot(label="Model A")
    chatbots[1] = gr.Chatbot(label="Model B")

  # Model identities, revealed only when the accordion is opened.
  model_names = [None, None]
  with gr.Accordion("Show models", open=False):
    with gr.Row():
      model_names[0] = gr.Textbox(label="Model A", interactive=False)
      model_names[1] = gr.Textbox(label="Model B", interactive=False)

  prompt = gr.Textbox(label="Prompt")

  # Per-session conversation states, lazily populated by user().
  states = [gr.State(None), gr.State(None)]
  # user() records the prompt synchronously (queue=False), then bot()
  # streams both models' responses into the chatbots.
  prompt.submit(user,
                states + [prompt],
                states + chatbots + model_names + [prompt],
                queue=False).then(bot, states, states + chatbots)

if __name__ == "__main__":
  # Queueing is required because bot() is a generator (streaming output).
  app.queue()
  app.launch()