Amitontheweb commited on
Commit
ae1e60f
·
verified ·
1 Parent(s): 6fdbb11

Upload app.py

Browse files

17-8-2024
v1.0

Files changed (1) hide show
  1. app.py +407 -0
app.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Gradio Params Playground
3
+
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ import torch
6
+ import gradio as gr
7
+
8
+
9
+ # Load default model as GPT2
10
+
11
+
12
# Device used by every `.to(torch_device)` call in this file. The name was
# referenced but never defined, which raised NameError as soon as the module
# was imported.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)


# Define functions

# Strategy currently selected in the UI. It is written by select_strategy()
# and read by popualate_beam_groups(). A bare module-level `global` statement
# does not create the name, so initialise it to the radio button's default.
chosen_strategy = "Sampling"
20
+
21
def generate(input_text, number_steps, number_beams, number_beam_groups, diversity_penalty, length_penalty, num_return_sequences, temperature, no_repeat_ngram_size, repetition_penalty, early_stopping, beam_temperature, top_p, top_k, penalty_alpha, top_p_box, top_k_box, strategy_selected, model_selected):
    """Generate a continuation of ``input_text`` with the selected decoding strategy.

    The module-level ``tokenizer`` and ``model`` are used for encoding,
    generation and decoding.

    Args:
        input_text: Prompt to continue.
        number_steps: Passed to ``model.generate`` as ``max_new_tokens``.
        number_beams: Number of beams (beam-based strategies only).
        number_beam_groups: Number of beam groups (diversity beam search only).
        diversity_penalty: Group diversity penalty (diversity beam search only).
        length_penalty: Length penalty (beam-based strategies only).
        num_return_sequences: Requested number of sequences; capped at the
            number of beams.
        temperature: Sampling temperature. Used by "Sampling", and by
            "Beam Search" only when ``beam_temperature`` is truthy.
        no_repeat_ngram_size: Forwarded unchanged to ``model.generate``.
        repetition_penalty: Forwarded when greater than 0, otherwise ``None``.
        early_stopping: Truthy to enable early stopping for beam strategies.
        beam_temperature: Truthy to turn on sampling during beam search.
        top_p: Nucleus sampling threshold, applied only when ``top_p_box`` is set.
        top_k: Top-k value, applied only when ``top_k_box`` is set.
        penalty_alpha: Degeneration penalty for contrastive search.
        top_p_box: Whether ``top_p`` is enabled.
        top_k_box: Whether ``top_k`` is enabled.
        strategy_selected: One of "Sampling", "Beam Search",
            "Diversity Beam Search" or "Contrastive".
        model_selected: Not used here; the active model is the module-level one.

    Returns:
        The decoded text. Beam-based strategies return a header followed by
        every returned sequence, joined by an "Option" separator.

    Raises:
        ValueError: If ``strategy_selected`` is not a known strategy.
    """
    chosen_strategy = strategy_selected

    # Encode the prompt once, and place it on whatever device the model lives
    # on rather than depending on a separately maintained device name.
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    # Arguments every strategy passes identically.
    shared_kwargs = dict(
        max_new_tokens=number_steps,
        return_dict_in_generate=False,
        no_repeat_ngram_size=no_repeat_ngram_size,
        repetition_penalty=repetition_penalty if repetition_penalty > 0 else None,
        output_scores=False,
    )

    if chosen_strategy == "Sampling":
        outputs = model.generate(
            **inputs,
            **shared_kwargs,
            temperature=temperature,
            # transformers rejects top_p above 1, so clamp what the UI sends.
            top_p=min(top_p, 1.0) if top_p_box else None,
            top_k=top_k if top_k_box else None,
            do_sample=True,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    if chosen_strategy == "Beam Search":
        outputs = model.generate(
            **inputs,
            **shared_kwargs,
            num_beams=number_beams,
            num_return_sequences=min(num_return_sequences, number_beams),
            length_penalty=length_penalty,
            temperature=temperature if beam_temperature else None,
            early_stopping=bool(early_stopping),
            # The "Beam Temperature" checkbox doubles as the sampling switch.
            do_sample=bool(beam_temperature),
        )
        decoded = [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in outputs]
        options = "\n\n - Option - \n".join(decoded)
        return "Beam Search Generation" + "\n" + "-" * 10 + "\n" + options

    if chosen_strategy == "Diversity Beam Search":
        # At least two groups are used, and the beam count is raised to the
        # group count when it is smaller.
        if number_beam_groups == 1:
            number_beam_groups = 2
        if number_beam_groups > number_beams:
            number_beams = number_beam_groups

        outputs = model.generate(
            **inputs,
            **shared_kwargs,
            num_beams=number_beams,
            num_beam_groups=number_beam_groups,
            diversity_penalty=float(diversity_penalty),
            num_return_sequences=min(num_return_sequences, number_beams),
            length_penalty=length_penalty,
            early_stopping=bool(early_stopping),
        )
        decoded = [tokenizer.decode(sequence, skip_special_tokens=True) for sequence in outputs]
        options = "\n\n ------ Option ------- \n".join(decoded)
        return "Diversity Beam Search Generation" + "\n" + "-" * 10 + "\n" + options

    if chosen_strategy == "Contrastive":
        # transformers only selects contrastive search when do_sample is False,
        # penalty_alpha > 0 and top_k > 1. With do_sample=True it falls back to
        # plain sampling and ignores penalty_alpha. Contrastive search is
        # deterministic, so temperature is not forwarded.
        fallback_top_k = 4  # used when the Top K checkbox is off
        outputs = model.generate(
            **inputs,
            **shared_kwargs,
            penalty_alpha=penalty_alpha,
            top_k=top_k if top_k_box else fallback_top_k,
            do_sample=False,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    raise ValueError(f"Unknown generation strategy: {chosen_strategy!r}")
125
+
126
+
127
+ #--------ON SELECTING MODEL------------------------
128
+
129
def load_model(model_selected):
    """Replace the module-level ``tokenizer`` and ``model`` with the selected model.

    Args:
        model_selected: "gpt2" or "Gemma 2". Any other value leaves the
            current tokenizer and model untouched.
    """
    # Without this declaration the assignments below only created locals, so
    # generate() kept using the previously loaded model.
    global tokenizer, model

    # Put the new model on the same device as the one it replaces.
    device = model.device

    if model_selected == "gpt2":
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        model = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id).to(device)
    elif model_selected == "Gemma 2":
        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
        model = AutoModelForCausalLM.from_pretrained("google/gemma-2b").to(device)
139
+
140
+
141
+
142
+ #--------ON SELECT NO. OF RETURN SEQUENCES----------
143
+
144
def change_num_return_sequences(n_beams, num_return_sequences):
    """Return a "Number of sequences" slider bounded by the beam count.

    The slider's maximum is ``n_beams`` and its value is the requested
    ``num_return_sequences``, lowered to ``n_beams`` when it exceeds it.
    """
    capped_value = min(num_return_sequences, n_beams)
    return gr.Slider(
        label="Number of sequences",
        minimum=1,
        maximum=n_beams,
        step=1,
        value=capped_value,
    )
152
+
153
+ #--------ON CHANGING NO OF BEAMS------------------
154
+
155
def popualate_beam_groups (n_beams):
    """Rebuild the beam-group choices after the number of beams changes.

    The choices are every divisor of ``n_beams`` from 2 up to ``n_beams``
    itself; the largest one is preselected. The dropdown is shown only for
    "Diversity Beam Search" and hidden for "Beam Search". The maximum of the
    return-sequences slider is set to ``n_beams``. For any other strategy
    nothing is returned.
    """
    group_choices = [size for size in range(2, n_beams + 1) if n_beams % size == 0]

    if chosen_strategy not in ("Diversity Beam Search", "Beam Search"):
        return None

    show_groups = chosen_strategy == "Diversity Beam Search"
    return {
        beam_groups: gr.Dropdown(
            group_choices,
            value=max(group_choices),
            label="Beam groups",
            info="Divide beams into equal groups",
            visible=show_groups,
        ),
        num_return_sequences: gr.Slider(maximum=n_beams),
    }
172
+
173
+ #-----------ON SELECTING TOP P / TOP K--------------
174
+
175
def top_p_switch(input_p_box):
    """Show the Top P slider when its checkbox is ticked, hide it otherwise."""
    return {top_p: gr.Slider(visible=bool(input_p_box))}
181
+
182
+
183
def top_k_switch(input_k_box):
    """Show the Top K slider when its checkbox is ticked, hide it otherwise."""
    return {top_k: gr.Slider(visible=bool(input_k_box))}
189
+
190
+
191
+ #-----------ON SELECTING BEAM TEMPERATURE--------------
192
+
193
def beam_temp_switch (input):
    """Show the temperature slider when "Beam Temperature" is ticked, hide it otherwise."""
    return {temperature: gr.Slider(visible=bool(input))}
199
+
200
+
201
+ #-----------ON CHOOSING STRATEGY: HIDE/DISPLAY PARAMS -----------
202
+
203
def select_strategy(input_strategy):
    """Record the chosen strategy and return visibility updates for its parameters.

    The selection is stored in the module-level ``chosen_strategy``. The
    return value maps UI components to updated components, showing only the
    controls relevant to the strategy. For an unrecognised strategy nothing
    is returned.

    Args:
        input_strategy: "Sampling", "Beam Search", "Diversity Beam Search"
            or "Contrastive".
    """
    global chosen_strategy
    chosen_strategy = input_strategy

    # NOTE: earlier versions had statements such as
    # `if top_k_box == True: {top_k: gr.Slider(visible=True)}` here. They
    # compared the component object (not its value) to True and threw the
    # dict away, so they never had any effect and have been removed.

    if chosen_strategy == "Beam Search":
        return {
            n_beams: gr.Slider(visible=True),
            num_return_sequences: gr.Slider(visible=True),
            beam_temperature: gr.Checkbox(visible=True),
            early_stopping: gr.Checkbox(visible=True),
            length_penalty: gr.Slider(visible=True),
            beam_groups: gr.Dropdown(visible=False),
            diversity_penalty: gr.Slider(visible=False),
            temperature: gr.Slider(visible=False),
            top_k: gr.Slider(visible=False),
            top_p: gr.Slider(visible=False),
            top_k_box: gr.Checkbox(visible=False),
            top_p_box: gr.Checkbox(visible=False),
            penalty_alpha: gr.Slider(visible=False),
        }

    if chosen_strategy == "Sampling":
        # Both checkboxes are reset to unticked, so their sliders start hidden.
        return {
            temperature: gr.Slider(visible=True),
            top_p: gr.Slider(visible=False),
            top_k: gr.Slider(visible=False),
            n_beams: gr.Slider(visible=False),
            beam_groups: gr.Dropdown(visible=False),
            diversity_penalty: gr.Slider(visible=False),
            num_return_sequences: gr.Slider(visible=False),
            beam_temperature: gr.Checkbox(visible=False),
            early_stopping: gr.Checkbox(visible=False),
            length_penalty: gr.Slider(visible=False),
            top_p_box: gr.Checkbox(visible=True, value=False),
            top_k_box: gr.Checkbox(visible=True, value=False),
            penalty_alpha: gr.Slider(visible=False),
        }

    if chosen_strategy == "Diversity Beam Search":
        return {
            n_beams: gr.Slider(visible=True),
            beam_groups: gr.Dropdown(visible=True),
            diversity_penalty: gr.Slider(visible=True),
            num_return_sequences: gr.Slider(visible=True),
            length_penalty: gr.Slider(visible=True),
            beam_temperature: gr.Checkbox(visible=False),
            early_stopping: gr.Checkbox(visible=True),
            temperature: gr.Slider(visible=False),
            top_k: gr.Slider(visible=False),
            top_p: gr.Slider(visible=False),
            top_k_box: gr.Checkbox(visible=False),
            top_p_box: gr.Checkbox(visible=False),
            penalty_alpha: gr.Slider(visible=False),
        }

    if chosen_strategy == "Contrastive":
        # The Top K slider is deliberately absent from this mapping, so its
        # visibility stays whatever it was before the switch.
        return {
            temperature: gr.Slider(visible=True),
            penalty_alpha: gr.Slider(visible=True),
            top_p: gr.Slider(visible=False),
            n_beams: gr.Slider(visible=False),
            beam_groups: gr.Dropdown(visible=False),
            diversity_penalty: gr.Slider(visible=False),
            num_return_sequences: gr.Slider(visible=False),
            beam_temperature: gr.Checkbox(visible=False),
            early_stopping: gr.Checkbox(visible=False),
            length_penalty: gr.Slider(visible=False),
            top_p_box: gr.Checkbox(visible=False),
            top_k_box: gr.Checkbox(visible=True),
        }
282
+
283
def clear():
    """Return an empty string, used to blank the output textbox."""
    # Previously this printed an empty line to stdout and returned None
    # implicitly; returning the new textbox value makes the intent explicit.
    return ""
285
+
286
+
287
+ #------------------MAIN BLOCKS DISPLAY---------------
288
+
289
# NOTE(review): the original nesting was lost in extraction; the layout below
# (three columns inside one row, events and buttons after the row) is a
# reconstruction and should be confirmed against the running app.
with gr.Blocks() as demo:

    No_beam_group_list = [2]
    text = gr.Textbox(
        label="Prompt",
        value="It's a rainy day today",
    )

    # Reload GPT-2 with the pad token set to the EOS token. The undefined
    # `cache_dir` argument, which raised NameError, has been dropped so the
    # default cache location is used.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id).to(torch_device)

    with gr.Row():

        with gr.Column(scale=0, min_width=200) as Models_Strategy:

            model_selected = gr.Radio(["gpt2", "Gemma 2"], label="ML Model", value="gpt2")
            strategy_selected = gr.Radio(["Sampling", "Beam Search", "Diversity Beam Search", "Contrastive"], label="Search strategy", value="Sampling", interactive=True)

        with gr.Column(scale=0, min_width=250) as Beam_Params:
            n_steps = gr.Slider(
                label="Number of steps/tokens", minimum=1, maximum=100, step=1, value=20
            )
            n_beams = gr.Slider(
                label="Number of beams", minimum=2, maximum=100, step=1, value=4, visible=False
            )

            #----------------Dropdown-----------------

            beam_groups = gr.Dropdown(
                No_beam_group_list, value=2, label="Beam groups", info="Divide beams into equal groups", visible=False
            )

            diversity_penalty = gr.Slider(
                label="Group diversity penalty", minimum=0.1, maximum=2, step=0.1, value=0.8, visible=False
            )

            num_return_sequences = gr.Slider(
                label="Number of return sequences", minimum=1, maximum=3, step=1, value=2, visible=False
            )
            temperature = gr.Slider(
                label="Temperature", minimum=0.1, maximum=3, step=0.1, value=0.6, visible=True
            )

            top_k = gr.Slider(
                label="Top_K", minimum=1, maximum=50, step=1, value=5, visible=False
            )
            # top_p is a probability mass; transformers rejects values above 1,
            # so the slider stops at 1 instead of 3.
            top_p = gr.Slider(
                label="Top_P", minimum=0.1, maximum=1, step=0.1, value=0.3, visible=False
            )

            penalty_alpha = gr.Slider(
                label="Contrastive penalty α", minimum=0.1, maximum=2, step=0.1, value=0.6, visible=False
            )

            top_p_box = gr.Checkbox(label="Top P", info="Turn on Top P", visible=True, interactive=True)
            top_k_box = gr.Checkbox(label="Top K", info="Turn on Top K", visible=True, interactive=True)

            early_stopping = gr.Checkbox(label="Early stopping", info="Stop with heuristically chosen good result", visible=False, interactive=True)
            beam_temperature = gr.Checkbox(label="Beam Temperature", info="Turn on sampling", visible=False, interactive=True)

        with gr.Column(scale=0, min_width=200):

            length_penalty = gr.Slider(
                label="Length penalty", minimum=-3, maximum=3, step=0.5, value=0, info="'+' more, '-' less no. of words", visible=False, interactive=True
            )

            no_repeat_ngram_size = gr.Slider(
                label="No repeat n-gram phrase size", minimum=0, maximum=8, step=1, value=4, info="Not to repeat 'n' words"
            )
            repetition_penalty = gr.Slider(
                label="Repetition penalty", minimum=0, maximum=3, step=1, value=float(0), info="Prior context based penalty for unique text"
            )

    #----------ON SELECTING/CHANGING: RETURN SEQUENCES/NO OF BEAMS/BEAM GROUPS/TEMPERATURE--------

    model_selected.change(
        fn=load_model, inputs=[model_selected], outputs=[]
    )

    n_beams.change(
        fn=popualate_beam_groups, inputs=[n_beams], outputs=[beam_groups, num_return_sequences]
    )

    strategy_selected.change(
        fn=select_strategy,
        inputs=strategy_selected,
        outputs=[n_beams, beam_groups, length_penalty, diversity_penalty, num_return_sequences, temperature, early_stopping, beam_temperature, penalty_alpha, top_p, top_k, top_p_box, top_k_box],
    )

    beam_temperature.change(fn=beam_temp_switch, inputs=beam_temperature, outputs=temperature)

    top_p_box.change(fn=top_p_switch, inputs=top_p_box, outputs=top_p)

    top_k_box.change(fn=top_k_switch, inputs=top_k_box, outputs=top_k)

    #-------------GENERATE BUTTON-------------------

    button = gr.Button("Generate")
    out_markdown = gr.Textbox()

    # The order of `inputs` must match the positional parameters of generate().
    button.click(
        fn=generate,
        inputs=[text, n_steps, n_beams, beam_groups, diversity_penalty, length_penalty, num_return_sequences, temperature, no_repeat_ngram_size, repetition_penalty, early_stopping, beam_temperature, top_p, top_k, penalty_alpha, top_p_box, top_k_box, strategy_selected, model_selected],
        outputs=[out_markdown],
    )

    cleared = gr.Button("Clear")
    cleared.click(fn=clear, inputs=[], outputs=[out_markdown])


demo.launch()
407
+