add functional sliders for hyperparameters
app_dialogue.py (CHANGED, +70 -14)
@@ -282,7 +282,15 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
                 interactive=True,
                 label="Top P",
             )
-            max_new_tokens = gr.Slider(
+            top_k = gr.Slider(
+                minimum=0.0,
+                maximum=100.0,
+                value=50.0,
+                step=1.0,
+                interactive=True,
+                label="Top K",
+            )
+            max_new_tokens = gr.Slider(
                 minimum=0,
                 maximum=1024,
                 value=512,
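For context on the pattern this hunk introduces, here is a minimal, self-contained sketch of a gr.Slider feeding a callback. It is illustrative only: the component and function names below are not from app_dialogue.py, and it assumes the Gradio Blocks API this Space already uses.

```python
import gradio as gr

with gr.Blocks() as sketch:
    # Same constructor arguments as the "Top K" slider added above.
    top_k = gr.Slider(
        minimum=0.0, maximum=100.0, value=50.0, step=1.0,
        interactive=True, label="Top K",
    )
    shown = gr.Textbox(label="Current value")
    apply_btn = gr.Button("Apply")
    # On click, Gradio reads the slider's current value and passes it to the
    # callback as a positional argument.
    apply_btn.click(fn=lambda k: f"top_k = {int(k)}", inputs=[top_k], outputs=[shown])

sketch.launch()
```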
@@ -290,6 +298,46 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
                 interactive=True,
                 label="Max output tokens",
             )
+            repetition_penalty = gr.Slider(
+                minimum=0.0,
+                maximum=10.0,
+                value=1.0,
+                step=0.1,
+                interactive=True,
+                label="Repetition penalty",
+            )
+            min_length = gr.Slider(
+                minimum=0.0,
+                maximum=50.0,
+                value=0.0,
+                step=1.0,
+                interactive=True,
+                label="No repeat ngram size",
+            )
+            length_penalty = gr.Slider(
+                minimum=0.0,
+                maximum=10.0,
+                value=1.0,
+                step=0.1,
+                interactive=True,
+                label="Length penalty",
+            )
+            no_repeat_ngram_size = gr.Slider(
+                minimum=0.0,
+                maximum=10.0,
+                value=0.0,
+                step=1.0,
+                interactive=True,
+                label="No repeat ngram size",
+            )
+            penalty_alpha = gr.Slider(
+                minimum=0.0,
+                maximum=10.0,
+                value=0.95,
+                step=1.0,
+                interactive=True,
+                label="Penalty alpha",
+            )
 
         with gr.Column(scale=6):
             chatbot = gr.Chatbot(
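These slider names mirror the standard keyword arguments of Hugging Face transformers' generate() (note that the min_length slider reuses the "No repeat ngram size" label, apparently a copy-paste slip in this commit). The sketch below shows where such values typically end up; it assumes a plain causal LM with gpt2 as a stand-in, not the IDEFICS-specific generation call in this app.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model choice
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("IDEFICS is", return_tensors="pt")
output_ids = model.generate(
    **inputs,
    max_new_tokens=512,       # "Max output tokens" slider
    min_length=16,            # min_length slider
    repetition_penalty=1.0,   # "Repetition penalty" slider
    no_repeat_ngram_size=0,   # "No repeat ngram size" slider
    length_penalty=1.0,       # "Length penalty" slider (affects beam search)
    top_k=50,                 # "Top K" slider
    top_p=0.95,               # "Top P" slider (only used with do_sample=True)
    penalty_alpha=0.95,       # "Penalty alpha" slider; together with top_k > 1
                              # this selects contrastive-search decoding
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```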
@@ -357,22 +405,30 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
     def model_inference(
         user_prompt,
         chat_history,
+        temperature = 1.0,
+        no_repeat_ngram_size = 0,
+        max_new_tokens = 512,
+        min_length = 16,
+        repetition_penalty = 1.0,
+        length_penalty = 1.0,
+        top_k = 50,
+        top_p = 0.95,
+        penalty_alpha = 0.95,
     ):
         global processor, model, tokenizer
-        temperature = 1.0
-        no_repeat_ngram_size = 0
-        max_new_tokens = 512
-        min_length = 16
+        # temperature = 1.0
+        # no_repeat_ngram_size = 0
+        # max_new_tokens = 512
+        # min_length = 16
         force_words = ""
-        repetition_penalty = 1.0
+        # repetition_penalty = 1.0
         hide_special_tokens = False
         decoding_strategy = "greedy"
         num_beams = 3
-        length_penalty = 1.0
-        top_k = 50
-        top_p = 0.95
-        penalty_alpha = 0.95
+        # length_penalty = 1.0
+        # top_k = 50
+        # top_p = 0.95
+        # penalty_alpha = 0.95
 
         formated_prompt = format_prompt_with_history_and_system_conditioning(
             current_user_prompt=user_prompt.strip(),
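The keyword defaults added here matter because Gradio calls a callback with one positional value per component in its `inputs` list; parameters beyond that list simply keep their Python defaults. That is why the shorter `inputs=[textbox, chatbot, temperature, ]` wiring in the next hunk still runs. A runnable stand-in (not the app's real callback):

```python
# Illustrative stand-in for the new signature; names mirror the diff above.
def model_inference(user_prompt, chat_history, temperature=1.0,
                    no_repeat_ngram_size=0, max_new_tokens=512):
    return (f"T={temperature}, ngram={no_repeat_ngram_size}, "
            f"max={max_new_tokens}"), chat_history

# inputs=[textbox, chatbot, temperature] is effectively this call; the
# parameters not covered by the inputs list keep their defaults:
print(model_inference("hi", [], 0.7))
# A full inputs list supplies every hyperparameter positionally:
print(model_inference("hi", [], 0.7, 3, 256))
```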
@@ -406,13 +462,13 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
 
     textbox.submit(
         fn=model_inference,
-        inputs=[textbox, chatbot],
+        inputs=[textbox, chatbot, temperature, ],
         outputs=[textbox, chatbot],
     )
     submit_btn.click(
         fn=model_inference,
-        inputs=[textbox, chatbot],
-        outputs=[textbox, chatbot],
+        inputs=[textbox, chatbot, temperature, no_repeat_ngram_size, max_new_tokens, min_length, repetition_penalty, length_penalty, top_k, top_p, penalty_alpha],
+        outputs=[textbox, chatbot, temperature, no_repeat_ngram_size, max_new_tokens, min_length, repetition_penalty, length_penalty, top_k, top_p, penalty_alpha],
     )
 
     demo.queue()
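One wiring detail in this hunk: Gradio expects a callback to return exactly one value per component in `outputs`, so listing all eleven components in submit_btn.click's `outputs` only works if model_inference returns eleven values. A minimal, self-contained sketch of symmetric wiring under that assumption, with illustrative names and a Gradio 3.x-style tuple chat history rather than the app's real code:

```python
import gradio as gr

def reply(message, history, temperature=1.0, top_k=50.0):
    history = history + [(message, f"(echo at T={temperature}, top_k={int(top_k)})")]
    return "", history  # two values -> the two components in `outputs`

with gr.Blocks() as sketch:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox()
    temperature = gr.Slider(0.0, 2.0, value=1.0, step=0.1, label="Temperature")
    top_k = gr.Slider(0.0, 100.0, value=50.0, step=1.0, label="Top K")
    submit_btn = gr.Button("Submit")
    generation_inputs = [textbox, chatbot, temperature, top_k]
    # Both triggers share one inputs list, and outputs match the return arity.
    textbox.submit(fn=reply, inputs=generation_inputs, outputs=[textbox, chatbot])
    submit_btn.click(fn=reply, inputs=generation_inputs, outputs=[textbox, chatbot])

sketch.queue()
sketch.launch()
```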