Leyo commited on
Commit
dafb0ab
1 Parent(s): 5167a8a

black formmatting

Browse files
Files changed (1) hide show
  1. app_dialogue.py +70 -21
app_dialogue.py CHANGED
@@ -32,7 +32,12 @@ EOS_TOKENS = "</s>;User"
32
  import logging
33
 
34
  from accelerate.utils import get_max_memory
35
- from transformers import AutoTokenizer, AutoProcessor, AutoConfig, IdeficsForVisionText2Text
 
 
 
 
 
36
 
37
 
38
  TOKENIZER_FAST = True
@@ -56,7 +61,9 @@ def load_processor_tokenizer_model(model_name):
56
  )
57
  # tokenizer.padding_side = "left" -> we don't need that, do we?
58
 
59
- config = AutoConfig.from_pretrained(model_name, use_auth_token=os.getenv("HF_AUTH_TOKEN", True))
 
 
60
  max_memory_map = get_max_memory()
61
 
62
  for key in max_memory_map.keys():
@@ -92,6 +99,7 @@ def split_prompt_into_list(prompt_str):
92
  prompt_list.append(ps)
93
  return prompt_list
94
 
 
95
  def model_generation(
96
  prompt,
97
  processor,
@@ -117,7 +125,8 @@ def model_generation(
117
  [split_prompt_into_list(prompt)],
118
  eval_mode=True,
119
  truncation=True,
120
- max_length=MAX_SEQ_LEN - 512, # TODO: replace the 512 value with `max_new_tokens`
 
121
  padding=True,
122
  )
123
  for k, v in input_args.items():
@@ -145,7 +154,9 @@ def model_generation(
145
  if len(eos_tokens) > 0:
146
  eos_token_ids = []
147
  for eos_token in eos_tokens:
148
- tokenized_eos_token = tokenizer(eos_token, add_special_tokens=False).input_ids
 
 
149
  if len(tokenized_eos_token) > 1:
150
  raise ValueError(
151
  f"eos_tokens should be one token, here {eos_token} is {len(tokenized_eos_token)} tokens:"
@@ -203,13 +214,17 @@ def model_generation(
203
 
204
  tokens = tokenizer.convert_ids_to_tokens(generated_tokens[0])
205
  decoded_skip_special_tokens = repr(
206
- tokenizer.batch_decode(generated_tokens, skip_special_tokens=hide_special_tokens)[0]
 
 
207
  )
208
 
209
  actual_generated_tokens = generated_tokens[:, input_args["input_ids"].shape[-1] :]
210
  first_end_token = len(actual_generated_tokens[0])
211
  actual_generated_tokens = actual_generated_tokens[:, :first_end_token]
212
- generated_text = tokenizer.batch_decode(actual_generated_tokens, skip_special_tokens=hide_special_tokens)[0]
 
 
213
 
214
  logger.info(
215
  "Result: \n"
@@ -252,7 +267,9 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
252
  show_label=False,
253
  container=False,
254
  )
255
- processor, tokenizer, model = load_processor_tokenizer_model(model_selector.value)
 
 
256
 
257
  imagebox = gr.Image(
258
  type="pil",
@@ -394,26 +411,30 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
394
  # inputs = [chatbot]
395
  # )
396
 
397
- def format_prompt_with_history_and_system_conditioning(current_user_prompt, history):
 
 
398
  resulting_text = SYSTEM_PROMPT
399
  for turn in history:
400
  user_utterance, assistant_utterance = turn
401
- resulting_text += f"\nUser: {user_utterance}</s>\nAssistant: {assistant_utterance}</s>"
 
 
402
  resulting_text += f"\nUser: {current_user_prompt}</s>\nAssistant:"
403
  return resulting_text
404
 
405
  def model_inference(
406
  user_prompt,
407
  chat_history,
408
- temperature = 1.0,
409
- no_repeat_ngram_size = 0,
410
- max_new_tokens = 512,
411
- min_length = 16,
412
- repetition_penalty = 1.0,
413
- length_penalty = 1.0,
414
- top_k = 50,
415
- top_p = 0.95,
416
- penalty_alpha = 0.95,
417
  ):
418
  global processor, model, tokenizer
419
  # temperature = 1.0
@@ -462,13 +483,41 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
462
 
463
  textbox.submit(
464
  fn=model_inference,
465
- inputs=[textbox, chatbot, temperature, ],
 
 
 
 
466
  outputs=[textbox, chatbot],
467
  )
468
  submit_btn.click(
469
  fn=model_inference,
470
- inputs=[textbox, chatbot, temperature, no_repeat_ngram_size, max_new_tokens, min_length, repetition_penalty, length_penalty, top_k, top_p, penalty_alpha],
471
- outputs=[textbox, chatbot, temperature, no_repeat_ngram_size, max_new_tokens, min_length, repetition_penalty, length_penalty, top_k, top_p, penalty_alpha],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  )
473
 
474
  demo.queue()
 
32
  import logging
33
 
34
  from accelerate.utils import get_max_memory
35
+ from transformers import (
36
+ AutoTokenizer,
37
+ AutoProcessor,
38
+ AutoConfig,
39
+ IdeficsForVisionText2Text,
40
+ )
41
 
42
 
43
  TOKENIZER_FAST = True
 
61
  )
62
  # tokenizer.padding_side = "left" -> we don't need that, do we?
63
 
64
+ config = AutoConfig.from_pretrained(
65
+ model_name, use_auth_token=os.getenv("HF_AUTH_TOKEN", True)
66
+ )
67
  max_memory_map = get_max_memory()
68
 
69
  for key in max_memory_map.keys():
 
99
  prompt_list.append(ps)
100
  return prompt_list
101
 
102
+
103
  def model_generation(
104
  prompt,
105
  processor,
 
125
  [split_prompt_into_list(prompt)],
126
  eval_mode=True,
127
  truncation=True,
128
+ max_length=MAX_SEQ_LEN
129
+ - 512, # TODO: replace the 512 value with `max_new_tokens`
130
  padding=True,
131
  )
132
  for k, v in input_args.items():
 
154
  if len(eos_tokens) > 0:
155
  eos_token_ids = []
156
  for eos_token in eos_tokens:
157
+ tokenized_eos_token = tokenizer(
158
+ eos_token, add_special_tokens=False
159
+ ).input_ids
160
  if len(tokenized_eos_token) > 1:
161
  raise ValueError(
162
  f"eos_tokens should be one token, here {eos_token} is {len(tokenized_eos_token)} tokens:"
 
214
 
215
  tokens = tokenizer.convert_ids_to_tokens(generated_tokens[0])
216
  decoded_skip_special_tokens = repr(
217
+ tokenizer.batch_decode(
218
+ generated_tokens, skip_special_tokens=hide_special_tokens
219
+ )[0]
220
  )
221
 
222
  actual_generated_tokens = generated_tokens[:, input_args["input_ids"].shape[-1] :]
223
  first_end_token = len(actual_generated_tokens[0])
224
  actual_generated_tokens = actual_generated_tokens[:, :first_end_token]
225
+ generated_text = tokenizer.batch_decode(
226
+ actual_generated_tokens, skip_special_tokens=hide_special_tokens
227
+ )[0]
228
 
229
  logger.info(
230
  "Result: \n"
 
267
  show_label=False,
268
  container=False,
269
  )
270
+ processor, tokenizer, model = load_processor_tokenizer_model(
271
+ model_selector.value
272
+ )
273
 
274
  imagebox = gr.Image(
275
  type="pil",
 
411
  # inputs = [chatbot]
412
  # )
413
 
414
+ def format_prompt_with_history_and_system_conditioning(
415
+ current_user_prompt, history
416
+ ):
417
  resulting_text = SYSTEM_PROMPT
418
  for turn in history:
419
  user_utterance, assistant_utterance = turn
420
+ resulting_text += (
421
+ f"\nUser: {user_utterance}</s>\nAssistant: {assistant_utterance}</s>"
422
+ )
423
  resulting_text += f"\nUser: {current_user_prompt}</s>\nAssistant:"
424
  return resulting_text
425
 
426
  def model_inference(
427
  user_prompt,
428
  chat_history,
429
+ temperature=1.0,
430
+ no_repeat_ngram_size=0,
431
+ max_new_tokens=512,
432
+ min_length=16,
433
+ repetition_penalty=1.0,
434
+ length_penalty=1.0,
435
+ top_k=50,
436
+ top_p=0.95,
437
+ penalty_alpha=0.95,
438
  ):
439
  global processor, model, tokenizer
440
  # temperature = 1.0
 
483
 
484
  textbox.submit(
485
  fn=model_inference,
486
+ inputs=[
487
+ textbox,
488
+ chatbot,
489
+ temperature,
490
+ ],
491
  outputs=[textbox, chatbot],
492
  )
493
  submit_btn.click(
494
  fn=model_inference,
495
+ inputs=[
496
+ textbox,
497
+ chatbot,
498
+ temperature,
499
+ no_repeat_ngram_size,
500
+ max_new_tokens,
501
+ min_length,
502
+ repetition_penalty,
503
+ length_penalty,
504
+ top_k,
505
+ top_p,
506
+ penalty_alpha,
507
+ ],
508
+ outputs=[
509
+ textbox,
510
+ chatbot,
511
+ temperature,
512
+ no_repeat_ngram_size,
513
+ max_new_tokens,
514
+ min_length,
515
+ repetition_penalty,
516
+ length_penalty,
517
+ top_k,
518
+ top_p,
519
+ penalty_alpha,
520
+ ],
521
  )
522
 
523
  demo.queue()