jmercat committed
Commit
0bb8dc5
Parent: 425b71f

added IT models for chat

Files changed (1): app.py (+144 -68)
app.py CHANGED
@@ -1,19 +1,17 @@
-from threading import Thread
-
 import gradio as gr
-from gradio.layouts import Accordion
-import spaces
+from threading import Thread
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import torch
-
 from open_lm.hf import *
 from open_lm.precision import get_autocast
-
+import spaces
 
 # Define model options
 MODEL_OPTIONS = {
     "TRI DCLM-1B": "TRI-ML/DCLM-1B",
-    "Apple DCLM-Baseline-7B": "apple/DCLM-Baseline-7B"
+    "Apple DCLM-Baseline-7B": "apple/DCLM-Baseline-7B",
+    "[IT] TRI DCLM-1B": "TRI-ML/DCLM-1B-IT",
+    "[IT] Apple DCLM-Baseline-7B": "mlfoundations/dclm-7b-it",
 }
 
 # Global variables for model and tokenizer
@@ -29,13 +27,13 @@ def load_model(model_name):
     return f"Loaded model: {model_name}"
 
 @spaces.GPU
-def generate(
+def generate_completion(
     prompt, model_choice, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
 ):
     global current_model, current_tokenizer
 
     if current_model is None or current_tokenizer is None:
-        return "Please load a model first."
+        return "Please select a model first."
 
     temperature = float(temperature)
     if temperature < 1e-2:
@@ -63,7 +61,6 @@ def generate(
     thread = Thread(target=current_model.generate, kwargs=generate_kwargs)
     thread.start()
 
-    # Write the prompt in blue
     output = "<span style='color: blue;'>" + prompt + "</span>"
     for new_text in streamer:
         if isinstance(new_text, torch.Tensor):
@@ -77,80 +74,159 @@ def generate(
     thread.join()
     return output
 
-additional_inputs=[
-    gr.Slider(
-        label="Temperature",
-        value=0.9,
-        minimum=0.0,
-        maximum=1.0,
-        step=0.05,
-        interactive=True,
-        info="Higher values produce more diverse outputs",
-    ),
-    gr.Slider(
-        label="Max new tokens",
-        value=256,
-        minimum=0,
-        maximum=1048,
-        step=64,
-        interactive=True,
-        info="The maximum numbers of new tokens",
-    ),
-    gr.Slider(
-        label="Top-p (nucleus sampling)",
-        value=0.90,
-        minimum=0.0,
-        maximum=1,
-        step=0.05,
-        interactive=True,
-        info="Higher values sample more low-probability tokens",
-    ),
-    gr.Slider(
-        label="Repetition penalty",
-        value=1.2,
-        minimum=1.0,
-        maximum=2.0,
-        step=0.05,
-        interactive=True,
-        info="Penalize repeated tokens",
-    )
-]
+def format_prompt(message, history):
+    prompt = ""
+    for user_prompt, bot_response in history:
+        prompt += f"User: {user_prompt}\nAssistant: {bot_response}\n"
+    prompt += f"User: {message}\nAssistant:"
+    return prompt
+
+@spaces.GPU
+def generate_chat(
+    message, chat_history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
+):
+    global current_model, current_tokenizer
+
+    if current_model is None or current_tokenizer is None:
+        yield chat_history + [("Error", "Please select a model first.")]
+        return
+
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
+    formatted_prompt = format_prompt(message, chat_history)
+    inputs = current_tokenizer(formatted_prompt, return_tensors="pt").to(current_model.device)
+
+    generate_kwargs = dict(
+        **inputs,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        pad_token_id=current_tokenizer.eos_token_id
+    )
+
+    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
+    streamer.stop_signal = current_tokenizer.decode(current_tokenizer.eos_token_id)
+    generate_kwargs["streamer"] = streamer
+
+    thread = Thread(target=current_model.generate, kwargs=generate_kwargs)
+    thread.start()
+
+    new_history = chat_history + [(message, "")]
+    for new_text in streamer:
+        if isinstance(new_text, torch.Tensor):
+            new_text = current_tokenizer.decode(new_text)
+        if streamer.stop_signal in new_text:
+            new_text = new_text.split(streamer.stop_signal)[0]
+            new_history[-1] = (message, new_history[-1][1] + new_text)
+            break
+        new_history[-1] = (message, new_history[-1][1] + new_text)
+        yield new_history
+
+    thread.join()
+
+additional_inputs = [
+    gr.Slider(
+        label="Temperature",
+        value=0.9,
+        minimum=0.0,
+        maximum=1.0,
+        step=0.05,
+        interactive=True,
+        info="Higher values produce more diverse outputs",
+    ),
+    gr.Slider(
+        label="Max new tokens",
+        value=256,
+        minimum=0,
+        maximum=1048,
+        step=64,
+        interactive=True,
+        info="The maximum number of new tokens",
+    ),
+    gr.Slider(
+        label="Top-p (nucleus sampling)",
+        value=0.90,
+        minimum=0.0,
+        maximum=1,
+        step=0.05,
+        interactive=True,
+        info="Higher values sample more low-probability tokens",
+    ),
+    gr.Slider(
+        label="Repetition penalty",
+        value=1.2,
+        minimum=1.0,
+        maximum=2.0,
+        step=0.05,
+        interactive=True,
+        info="Penalize repeated tokens",
+    )
+]
 
 with gr.Blocks() as demo:
     gr.Markdown(
         """
-        # DCLM Text Completion Demo
-        This demo allows you to generate text using a DCLM model.
-        These models are trained to predict the next word in a sequence of text, and can be used to generate text completions, they are not chatbots.
+        # DCLM Demo
+        This demo allows you to generate text using DCLM models in two modes:
+        1. Text Completion:
+            For non-instruction-tuned models, it generates the continuation of the input text.
+        2. Chatbot:
+            For instruction-tuned [IT] models, it generates responses to user messages as a chatbot.
 
-        First select a model from the dropdown and click "Load Model".
-        Then enter some text in the text box and click "Generate" to see the model's completion.
+        Select a model from the dropdown to start; it might take a few seconds to load.
+        The interface will automatically switch between Text Completion and Chatbot modes based on the selected model.
         """
     )
 
-
     with gr.Row():
         model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model")
+        model_status = gr.Textbox(label="Model Status")
+
+    # Text Completion interface
+    with gr.Row(visible=False) as completion_interface:
+        with gr.Column():
+            text_input = gr.Textbox(lines=3, label="Input Text")
+            text_output = gr.Markdown(label="Generated Text")
+            generate_button = gr.Button("Generate")
+
+    # Chatbot interface
+    with gr.Row(visible=False) as chat_interface:
+        with gr.Column():
+            chatbot = gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel")
+            msg = gr.Textbox(label="Message")
+            clear = gr.Button("Clear")
+
+    with gr.Accordion("Advanced Options", open=False):
+        for input_component in additional_inputs:
+            input_component.render()
+
+    def switch_interface(model_name):
+        is_it_model = model_name.startswith("[IT]")
+        status = load_model(model_name)
+        return (
+            gr.Row(visible=not is_it_model),  # completion_interface
+            gr.Row(visible=is_it_model),  # chat_interface
+            status  # model_status
+        )
 
-    model_dropdown.select(
-        load_model,
+    model_dropdown.change(
+        switch_interface,
         inputs=[model_dropdown],
-        outputs=[gr.Textbox(label="Model Status")]
+        outputs=[completion_interface, chat_interface, model_status]
     )
 
-    text_input = gr.Textbox(lines=3, label="Input Text")
-    text_output = gr.Markdown(label="Generated Text")
-
-    generate_button = gr.Button("Generate")
-
     generate_button.click(
-        generate,
+        generate_completion,
        inputs=[text_input, model_dropdown, *additional_inputs],
        outputs=[text_output]
    )
-    with Accordion(label="Advanced Options", open=False):
-        for input_component in additional_inputs:
-            if not input_component.is_rendered:
-                input_component.render()
 
-demo.launch()
+    msg.submit(generate_chat, [msg, chatbot, *additional_inputs], chatbot)
+    clear.click(lambda: None, None, chatbot, queue=False)
+
+demo.queue().launch()
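
A minimal sketch of the prompt that the new chat path builds: format_prompt below is copied from this commit's app.py, while the two-turn history and the question are made-up examples for illustration only.

# Prompt format used by generate_chat in this commit.
# format_prompt is verbatim from app.py; the sample history is hypothetical.
def format_prompt(message, history):
    prompt = ""
    for user_prompt, bot_response in history:
        prompt += f"User: {user_prompt}\nAssistant: {bot_response}\n"
    prompt += f"User: {message}\nAssistant:"
    return prompt

history = [("Hello!", "Hi! How can I help?")]
print(format_prompt("What is DCLM?", history))
# Output:
# User: Hello!
# Assistant: Hi! How can I help?
# User: What is DCLM?
# Assistant:

Note that this is a plain "User:"/"Assistant:" text template rather than a tokenizer chat template, which is why generate_chat detects the end of a reply by splitting on the decoded EOS string stored in streamer.stop_signal.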