AppleSwing committed
Commit a3cdaa8
1 Parent(s): c2dbb45

fix generation issue

src/backend/hflm_with_measurement.py CHANGED
@@ -294,7 +294,7 @@ class HFLMWithMeasurement(HFLM):
 
         return re_ord.get_original(res)
 
-    def _model_generate(self, context, max_length, stop, **generation_kwargs):
+    def _model_generate(self, context, max_tokens, stop, **generation_kwargs):
         # temperature = 0.0 if not set
         # if do_sample is false and temp==0.0:
         # remove temperature, as do_sample=False takes care of this
@@ -302,7 +302,7 @@ class HFLMWithMeasurement(HFLM):
         generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
         do_sample = generation_kwargs.get("do_sample", None)
 
-        is_gsm8k = generation_kwargs.get("is_gsm8k", False)
+        # is_gsm8k = generation_kwargs.get("is_gsm8k", False)
 
         # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
         if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
@@ -311,12 +311,12 @@ class HFLMWithMeasurement(HFLM):
         if do_sample is False and generation_kwargs.get("temperature") == 0.0:
             generation_kwargs.pop("temperature")
 
-        if is_gsm8k:
-            generation_kwargs.pop("is_gsm8k")
+        # if is_gsm8k:
+        #     generation_kwargs.pop("is_gsm8k")
 
         context_length = context.shape[1]
         model_config = self.model.config
-
+
         if not self.precision:
             if model_config.quantization_config._load_in_4bit:
                 self.precision = "4bit"
@@ -325,38 +325,21 @@ class HFLMWithMeasurement(HFLM):
         else:
             raise ValueError("Unknown precision")
 
-        if not is_gsm8k:
-            # build stopping criteria
-            print("Using normal stopping criteria")
-            stopping_criteria = stop_sequences_criteria(
-                self.tokenizer, stop, context.shape[1], context.shape[0]
-            )
-            stop_watch = StopWatch(self.tokenizer)
-            start = time()
-            res = self.model.generate(
-                input_ids=context,
-                max_length=max_length,
-                stopping_criteria=stopping_criteria,
-                pad_token_id=self.tokenizer.pad_token_id,
-                use_cache=True,
-                streamer=stop_watch,
-                **generation_kwargs,
-            )
-            end = time()
-        else:
-            # print("Using GSM8K")
-            stop_watch = StopWatch(self.tokenizer)
-            start = time()
-            res = self.model.generate(
-                input_ids=context,
-                max_length=max_length,
-                eos_token_id=stop,
-                pad_token_id=self.tokenizer.pad_token_id,
-                use_cache=True,
-                streamer=stop_watch,
-                **generation_kwargs,
-            )
-            end = time()
+        stopping_criteria = stop_sequences_criteria(
+            self.tokenizer, stop, context.shape[1], context.shape[0]
+        )
+        stop_watch = StopWatch(self.tokenizer)
+        start = time()
+        res = self.model.generate(
+            input_ids=context,
+            max_new_tokens=max_tokens,
+            stopping_criteria=stopping_criteria,
+            pad_token_id=self.tokenizer.pad_token_id,
+            use_cache=True,
+            streamer=stop_watch,
+            **generation_kwargs,
+        )
+        end = time()
 
         batch_size = context.shape[0]
         output_length = stop_watch.decoding_iterations
@@ -498,15 +481,18 @@ class HFLMWithMeasurement(HFLM):
                 f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
             )
             # add EOS token to stop sequences
-            eos = self.tok_decode(self.eot_token_id)
+            eos = "<|eot_id|>"
            if not until:
                 until = [eos]
            else:
                 until.append(eos)
 
-            is_gsm8k = kwargs.get("is_gsm8k", False)
-            if is_gsm8k:
-                until = [self.tokenizer.eos_token_id, self.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
+            # is_gsm8k = kwargs.get("is_gsm8k", False)
+            # if is_gsm8k:
+            #     until = ["Question:", "Question", "</s>"]
+            #     eos_ids = [self.tokenizer.eos_token_id,
+            #                self.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
+
 
            if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
@@ -532,8 +518,8 @@ class HFLMWithMeasurement(HFLM):
             context_enc = context_enc.to(self.device)
             attn_masks = attn_masks.to(self.device)
 
-            if "max_length" not in kwargs:
-                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks
+            if "max_tokens" not in kwargs:
+                kwargs["max_tokens"] = max_gen_toks
 
             # perform batched generation
             cont, end_to_end_time, prefilling_time, token_per_sec, mfu, mbu = self._model_generate(
@@ -551,17 +537,16 @@ class HFLMWithMeasurement(HFLM):
                 cont_toks = cont_toks[context_enc.shape[1] :]
 
                 s = self.tok_decode(cont_toks)
+
+                # # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
+                # if not is_gsm8k:
+                for term in until:
+                    if len(term) > 0:
+                        # ignore '' separator,
+                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
+                        s = s.split(term)[0]
 
                 # print(s)
-
-                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
-                if not is_gsm8k:
-                    for term in until:
-                        if len(term) > 0:
-                            # ignore '' separator,
-                            # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
-                            s = s.split(term)[0]
-
                 res.append((s, end_to_end_time, prefilling_time, token_per_sec, mfu, mbu))
 
                 self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
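
Note on this change: the separate GSM8K branch in `_model_generate` is gone; every request now goes through one path that builds stop-sequence criteria, and the token budget switches from `max_length` (prompt plus generated tokens) to `max_new_tokens` (generated tokens only), which is why `generate_until` now sets `kwargs["max_tokens"] = max_gen_toks` instead of adding `context_enc.shape[1]`. A minimal sketch of the consolidated path, assuming a stock `transformers` causal LM (the model name, stop strings, and the `StopOnSequence` helper are illustrative, not part of this repo):

from time import time

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

class StopOnSequence(StoppingCriteria):
    """Stop once any stop string appears in the newly generated text."""

    def __init__(self, stops, prompt_len):
        self.stops = stops
        self.prompt_len = prompt_len

    def __call__(self, input_ids, scores, **kwargs):
        new_text = tokenizer.decode(input_ids[0, self.prompt_len :])
        return any(s in new_text for s in self.stops)

context = tokenizer("Question: 2 + 2 = ?\nAnswer:", return_tensors="pt").input_ids
max_tokens = 64  # budget for *new* tokens only

start = time()
out = model.generate(
    input_ids=context,
    # max_new_tokens counts generated tokens only; the removed max_length
    # counted prompt + generation, so callers had to add the context length.
    max_new_tokens=max_tokens,
    stopping_criteria=StoppingCriteriaList(
        [StopOnSequence(["Question:", "</s>"], context.shape[1])]
    ),
    pad_token_id=tokenizer.eos_token_id,
    use_cache=True,
)
end = time()
print(f"{out.shape[1] - context.shape[1]} new tokens in {end - start:.2f}s")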
src/backend/tasks/gsm8k/gsm8k-custom.yaml CHANGED
@@ -22,18 +22,21 @@ metric_list:
     - "\\.$"
 generation_kwargs:
   until:
-    - "<|eot_id|>"
+    - "Question:"
+    - "Question"
+    - "</s>"
+    - "<|im_end|>"
   do_sample: false
   temperature: 0.0
-  is_gsm8k: true
+  # is_gsm8k: true
 repeats: 1
 num_fewshot: 5
 filter_list:
-  # - name: "strict-match"
-  #   filter:
-  #     - function: "regex"
-  #       regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
-  #     - function: "take_first"
+  - name: "strict-match"
+    filter:
+      - function: "regex"
+        regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
+      - function: "take_first"
   - name: "flexible-extract"
     filter:
       - function: "regex"
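
Note on the task config: generation now stops on the few-shot delimiters ("Question:", "Question", "</s>", "<|im_end|>") rather than relying on Llama-3's <|eot_id|> alone, and the strict-match filter that extracts the final "#### <number>" answer is re-enabled alongside flexible-extract. A small self-contained check of how the stop strings and the regex interact (the sample completion is made up for illustration):

import re

until = ["Question:", "Question", "</s>", "<|im_end|>"]
strict_match = re.compile(r"#### (\-?[0-9\.\,]+)")

raw = (
    "Natalia sold 48 clips in April and half as many in May, "
    "so 48 + 24 = 72.\n#### 72\n"
    "Question: the next few-shot item the model rambled into..."
)

# post-hoc trim, mirroring `s = s.split(term)[0]` in generate_until
s = raw
for term in until:
    if len(term) > 0:
        s = s.split(term)[0]

print(strict_match.search(s).group(1))  # -> 72

Because "Question" is a prefix of "Question:", either string cuts the completion at the same point; listing both, plus "</s>" and "<|im_end|>", simply covers models with different chat/EOS conventions.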