Xianbao QIAN committed on
Commit
93e4eee
1 Parent(s): d76943e

update new ui

Browse files
Files changed (1) hide show
  1. app.py +88 -30
app.py CHANGED
@@ -7,41 +7,102 @@ import sambanova
7
  def generate(
8
  message: str,
9
  chat_history: list[tuple[str, str]],
10
- max_new_tokens: int = 1024,
 
11
  temperature: float = 0.6,
12
  top_p: float = 0.9,
13
  top_k: int = 50,
14
  repetition_penalty: float = 1.2,
15
  ) -> Iterator[str]:
16
- conversation = []
17
- for user, assistant in chat_history:
18
- conversation.extend(
19
- [
20
- {"role": "user", "content": user},
21
- {"role": "assistant", "content": assistant},
22
- ]
23
- )
24
- conversation.append({"role": "user", "content": message})
25
 
26
  outputs = []
27
- for text in sambanova.Streamer(conversation, new_tokens=max_new_tokens,
28
- temperature=temperature, top_k=top_k, top_p=top_p):
 
 
 
29
  outputs.append(text)
30
  yield "".join(outputs)
31
 
32
- MAX_MAX_NEW_TOKENS = 2048
33
- DEFAULT_MAX_NEW_TOKENS = 1024
 
34
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  chat_interface = gr.ChatInterface(
37
- fn=generate,
38
  additional_inputs=[
 
 
39
  gr.Slider(
40
- label="Max new tokens",
41
  minimum=1,
42
- maximum=MAX_MAX_NEW_TOKENS,
43
  step=1,
44
- value=DEFAULT_MAX_NEW_TOKENS,
45
  ),
46
  gr.Slider(
47
  label="Temperature",
@@ -64,29 +125,26 @@ chat_interface = gr.ChatInterface(
64
  step=1,
65
  value=50,
66
  ),
67
- gr.Slider(
68
- label="Repetition penalty",
69
- minimum=1.0,
70
- maximum=2.0,
71
- step=0.05,
72
- value=1.2,
73
- ),
74
  ],
75
- stop_btn=None,
76
- fill_height=True,
77
  examples=[
78
  ["Which one is bigger? 4.9 or 4.11"],
79
- ["Can you explain briefly to me what is the Python programming language?"],
 
 
80
  ["Explain the plot of Cinderella in a sentence."],
81
  ["How many hours does it take a man to eat a Helicopter?"],
82
- ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
 
 
83
  ],
84
  cache_examples=False,
85
  )
 
86
  with gr.Blocks() as demo:
87
  gr.Markdown('# Sambanova model inference LLAMA 405B')
88
 
89
  chat_interface.render()
90
-
91
  if __name__ == "__main__":
92
  demo.queue(max_size=20).launch()
 
7
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a chat completion, yielding the accumulated response text.

    Args:
        message: The user's latest message.
        chat_history: Prior ``(user, assistant)`` turns; empty halves of a
            turn are skipped.
        system_message: System prompt placed first in the conversation.
        max_tokens: Generation budget, forwarded to the streamer as
            ``new_tokens``.
        temperature: Sampling temperature forwarded to the streamer.
        top_p: Nucleus-sampling threshold forwarded to the streamer.
        top_k: Top-k sampling cutoff forwarded to the streamer.
        repetition_penalty: Accepted for interface compatibility but
            currently NOT forwarded to ``sambanova.Streamer``.

    Yields:
        The full response generated so far (grows with each streamed chunk).
    """
    conversation = [{"role": "system", "content": system_message}]

    for user_turn, assistant_turn in chat_history:
        if user_turn:
            conversation.append({"role": "user", "content": user_turn})
        if assistant_turn:
            conversation.append({"role": "assistant", "content": assistant_turn})

    # Bug fix: the current message must end the conversation; without this the
    # model is asked to respond while never seeing the latest user input.
    conversation.append({"role": "user", "content": message})

    outputs = []
    for text in sambanova.Streamer(conversation,
                                   new_tokens=max_tokens,
                                   temperature=temperature,
                                   top_k=top_k,
                                   top_p=top_p):
        outputs.append(text)
        yield "".join(outputs)
34
 
35
+
36
# Upper bound for the "Max tokens" UI slider.
MAX_MAX_TOKENS = 2048
# Default slider value; matches generate()'s max_tokens default of 1024.
DEFAULT_MAX_TOKENS = 1024
# Input-length cap, overridable via the MAX_INPUT_TOKEN_LENGTH env var.
# NOTE(review): not referenced anywhere in this view — confirm it is still used.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
39
 
40
+ # chat_interface = gr.ChatInterface(
41
+ # fn=generate,
42
+ # additional_inputs=[
43
+ # gr.Slider(
44
+ # label="Max new tokens",
45
+ # minimum=1,
46
+ # maximum=MAX_MAX_NEW_TOKENS,
47
+ # step=1,
48
+ # value=DEFAULT_MAX_NEW_TOKENS,
49
+ # ),
50
+ # gr.Slider(
51
+ # label="Temperature",
52
+ # minimum=0.1,
53
+ # maximum=4.0,
54
+ # step=0.1,
55
+ # value=0.6,
56
+ # ),
57
+ # gr.Slider(
58
+ # label="Top-p (nucleus sampling)",
59
+ # minimum=0.05,
60
+ # maximum=1.0,
61
+ # step=0.05,
62
+ # value=0.9,
63
+ # ),
64
+ # gr.Slider(
65
+ # label="Top-k",
66
+ # minimum=1,
67
+ # maximum=1000,
68
+ # step=1,
69
+ # value=50,
70
+ # ),
71
+ # gr.Slider(
72
+ # label="Repetition penalty",
73
+ # minimum=1.0,
74
+ # maximum=2.0,
75
+ # step=0.05,
76
+ # value=1.2,
77
+ # ),
78
+ # ],
79
+ # stop_btn=None,
80
+ # fill_height=True,
81
+ # examples=[
82
+ # ["Which one is bigger? 4.9 or 4.11"],
83
+ # [
84
+ # "Can you explain briefly to me what is the Python programming language?"
85
+ # ],
86
+ # ["Explain the plot of Cinderella in a sentence."],
87
+ # ["How many hours does it take a man to eat a Helicopter?"],
88
+ # [
89
+ # "Write a 100-word article on 'Benefits of Open-Source in AI research'"
90
+ # ],
91
+ # ],
92
+ # cache_examples=False,
93
+ # )
94
+
95
  chat_interface = gr.ChatInterface(
96
+ generate,
97
  additional_inputs=[
98
+ gr.Textbox(value="You are a friendly Chatbot.",
99
+ label="System message"),
100
  gr.Slider(
101
+ label="Max tokens",
102
  minimum=1,
103
+ maximum=MAX_MAX_TOKENS,
104
  step=1,
105
+ value=DEFAULT_MAX_TOKENS,
106
  ),
107
  gr.Slider(
108
  label="Temperature",
 
125
  step=1,
126
  value=50,
127
  ),
128
+
 
 
 
 
 
 
129
  ],
 
 
130
  examples=[
131
  ["Which one is bigger? 4.9 or 4.11"],
132
+ [
133
+ "Can you explain briefly to me what is the Python programming language?"
134
+ ],
135
  ["Explain the plot of Cinderella in a sentence."],
136
  ["How many hours does it take a man to eat a Helicopter?"],
137
+ [
138
+ "Write a 100-word article on 'Benefits of Open-Source in AI research'"
139
+ ],
140
  ],
141
  cache_examples=False,
142
  )
143
+
144
# Page layout: a title heading with the chat interface rendered beneath it.
with gr.Blocks() as demo:
    gr.Markdown('# Sambanova model inference LLAMA 405B')

    chat_interface.render()

# Enable request queuing (at most 20 waiting) and start the Gradio server.
if __name__ == "__main__":
    demo.queue(max_size=20).launch()