LeonardoEmili committed
Commit 5a5d474
Parent: 8430db2

Add seq2seq eval benchmark callback (#1274)

* Add CausalLMBenchEvalCallback for measuring seq2seq performance

* Fix code for pre-commit

* Fix typing and improve logging

* eval_sample_packing must be false with CausalLMBenchEvalCallback

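For context: CausalLMBenchEvalCallback decodes prompt/completion pairs from the eval dataloader, generates predictions with the trained model, and scores them with HF evaluate metrics (enabled via do_causal_lm_eval: true, with eval_sample_packing: false, in the axolotl config). The sketch below is illustrative only (the strings are made up) and shows why the callback's compute() helper tries two call shapes: sacrebleu/ter/chrf take one reference list per prediction and report a "score" key, while comet also needs the source texts and reports "mean_score".

# Illustrative sketch of the metric computation the callback performs on decoded text.
import evaluate

sources = ["Translate to German: Good morning."]   # prompts decoded from the eval batch
predictions = ["Guten Morgen."]                    # model generations
references = ["Guten Morgen."]                     # completions decoded from the labels

scores = {}
for name in ["sacrebleu", "ter", "chrf"]:
    metric = evaluate.load(name)
    # these metrics take a list of reference lists and return a dict with a "score" key
    scores[name] = metric.compute(
        predictions=predictions, references=[[r] for r in references]
    )["score"]

# comet additionally needs the sources and reports "mean_score" (it downloads a scoring model)
comet = evaluate.load("comet")
scores["comet"] = comet.compute(
    sources=sources, predictions=predictions, references=references
)["mean_score"]
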
README.md CHANGED
@@ -784,7 +784,8 @@ save_total_limit: # Checkpoints saved at a time
 max_steps:
 
 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
-eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
+eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
+eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
 
 loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
 loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
examples/llama-2/loftq.yml CHANGED
@@ -60,7 +60,7 @@ s2_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
examples/llama-2/lora.yml CHANGED
@@ -57,7 +57,7 @@ s2_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
examples/mamba/config.yml CHANGED
@@ -49,7 +49,7 @@ flash_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
examples/mistral/Mistral-7b-example/config.yml CHANGED
@@ -61,7 +61,7 @@ flash_attention: true
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 #default deepspeed, can use more aggresive if needed like zero2, zero3
examples/mistral/config.yml CHANGED
@@ -49,7 +49,7 @@ flash_attention: true
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
examples/mistral/mixtral.yml CHANGED
@@ -81,7 +81,7 @@ loss_watchdog_patience: 3
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed: deepspeed_configs/zero2.json
examples/mistral/qlora.yml CHANGED
@@ -68,7 +68,7 @@ loss_watchdog_patience: 3
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
examples/qwen/lora.yml CHANGED
@@ -58,7 +58,7 @@ flash_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
examples/qwen/qlora.yml CHANGED
@@ -58,7 +58,7 @@ flash_attention:
 warmup_steps: 10
 evals_per_epoch: 4
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 deepspeed:
examples/yi-34B-chat/qlora.yml CHANGED
@@ -29,7 +29,7 @@ num_epochs: 1
 val_set_size: 0.1
 evals_per_epoch: 5
 eval_table_size:
-eval_table_max_new_tokens: 128
+eval_max_new_tokens: 128
 eval_sample_packing: false
 eval_batch_size: 1
 
requirements.txt CHANGED
@@ -23,7 +23,7 @@ numba
 numpy>=1.24.4
 mlflow
 # qlora things
-evaluate==0.4.0
+evaluate==0.4.1
 scipy
 scikit-learn==1.2.2
 pynvml
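
Since the callback loads metrics at construction time via evaluate.load, the pin bump to evaluate==0.4.1 is what the new code path exercises. A quick sanity-check sketch (mirroring __maybe_load_metrics, which warns and skips any metric that fails to load):

# Sanity-check sketch: confirm the pinned evaluate release can load each default metric.
import evaluate

for name in ["sacrebleu", "comet", "ter", "chrf"]:
    try:
        evaluate.load(name)
        print(f"{name}: loaded")
    except Exception as exc:  # comet pulls in extra dependencies and a model checkpoint
        print(f"{name}: skipped ({exc})")
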
src/axolotl/core/trainer_builder.py CHANGED
@@ -38,6 +38,7 @@ from axolotl.utils.callbacks import (
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
+    causal_lm_bench_eval_callback_factory,
     log_prediction_callback_factory,
 )
 from axolotl.utils.collators import (
@@ -148,6 +149,9 @@ class AxolotlTrainingArguments(TrainingArguments):
     do_bench_eval: Optional[bool] = field(
         default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
     )
+    do_causal_lm_eval: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
+    )
     max_bench_samples: Optional[int] = field(
         default=None,
         metadata={
@@ -664,6 +668,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         if self.cfg.do_bench_eval:
             callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
+        if self.cfg.do_causal_lm_eval:
+            CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
+                trainer, self.tokenizer
+            )
+            callbacks.append(CausalLMBenchEvalCallback(self.cfg))
 
         if self.cfg.early_stopping_patience:
             early_stop_cb = EarlyStoppingCallback(
@@ -812,6 +821,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
             if self.cfg.bench_dataset:
                 training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
+        if self.cfg.do_causal_lm_eval:
+            training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
         if self.cfg.metric_for_best_model:
             training_arguments_kwargs[
                 "metric_for_best_model"
src/axolotl/utils/callbacks.py CHANGED
@@ -361,6 +361,187 @@ def bench_eval_callback_factory(trainer, tokenizer):
     return BenchEvalCallback
 
 
+def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
+    class CausalLMBenchEvalCallback(TrainerCallback):
+        """Callback to log prediction values during each evaluation"""
+
+        def __init__(self, cfg):
+            self.cfg = cfg
+            self.logged = False
+            self.metrics = self.__maybe_load_metrics()
+
+        def __maybe_load_metrics(self):
+            metrics = {}
+            for metric in self.cfg.eval_causal_lm_metrics:
+                try:
+                    metrics[metric] = evaluate.load(metric)
+                except Exception as exc:  # pylint: disable=broad-exception-caught
+                    LOG.warning(f"{metric}: {exc.args}")
+            return metrics
+
+        def on_evaluate(
+            self,
+            args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
+            state: TrainerState,
+            control: TrainerControl,
+            train_dataloader,  # pylint: disable=unused-argument
+            eval_dataloader,
+            **kwargs,  # pylint: disable=unused-argument
+        ):
+            trainer.model.eval()
+            device = torch.device(self.cfg.device)
+
+            # pylint: disable=duplicate-code
+            generation_config = GenerationConfig(
+                max_new_tokens=self.cfg.eval_max_new_tokens,
+                bos_token_id=tokenizer.bos_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+                pad_token_id=tokenizer.pad_token_id,
+                do_sample=False,
+                use_cache=True,
+                return_dict_in_generate=True,
+                output_attentions=False,
+                output_hidden_states=False,
+                output_scores=False,
+            )
+
+            def find_ranges(lst):
+                ranges = []
+                start = 0
+                for i in range(1, len(lst)):
+                    if lst[i] == 0:
+                        ranges.append((start, i - 1))
+                        start = i
+                end = len(lst) - 1
+                ranges.append((start, end))
+                return ranges
+
+            def compute(metric: evaluate.Metric, **kwargs):
+                # safely compute a metric and return the score if the format is correct
+                metric_score = None
+                try:
+                    metric_score = metric.compute(**kwargs)
+                    return (
+                        metric_score["score"]
+                        if "score" in metric_score
+                        else metric_score["mean_score"]
+                    )
+                except Exception:  # pylint: disable=broad-exception-caught
+                    LOG.debug(
+                        f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
+                    )
+                return metric_score
+
+            def evaluate_preds(sources, predictions, references):
+                scores = {}
+
+                for metric_name, metric in self.metrics.items():
+                    score = compute(
+                        metric,
+                        references=references,
+                        predictions=predictions,
+                        sources=sources,
+                    )
+                    score = score or compute(
+                        metric,
+                        references=[[r] for r in references],
+                        predictions=predictions,
+                    )
+                    scores[metric_name] = score
+                return scores
+
+            def predict_with_generate():
+                eval_src, eval_pred, eval_ref = [], [], []
+
+                for batch in tqdm(eval_dataloader):
+                    batch_labels = batch["labels"].to(device)
+                    batch_input_ids = batch["input_ids"].to(device)
+
+                    if "position_ids" in batch:
+                        batch_pos_ids = batch["position_ids"].tolist()
+                    else:
+                        batch_pos_ids = [None] * len(batch["input_ids"])
+
+                    prompt_token_ids_list = []
+                    completion_token_ids_list = []
+
+                    for input_ids_all, labels_all, pos_ids in zip(
+                        batch_input_ids,
+                        batch_labels,
+                        batch_pos_ids,
+                    ):
+                        if pos_ids is None:
+                            pos_ranges = [(0, len(input_ids_all) - 1)]
+                        else:
+                            pos_ranges = find_ranges(pos_ids)
+
+                        for pos_range in pos_ranges:
+                            start, end = pos_range
+                            if start == end:
+                                continue
+
+                            input_ids = input_ids_all[start : end + 1]
+                            labels = labels_all[start : end + 1]
+
+                            tokens_without_loss = labels == IGNORE_INDEX
+                            tokens_with_loss = labels != IGNORE_INDEX
+                            tokens_exclude_padding = input_ids != tokenizer.pad_token_id
+                            prompt_token_includes = (
+                                tokens_without_loss & tokens_exclude_padding
+                            )
+
+                            prompt_token_ids = input_ids[prompt_token_includes]
+                            prompt_token_ids_list.append(prompt_token_ids)
+
+                            completion_token_ids = input_ids[tokens_with_loss]
+                            completion_token_ids_list.append(completion_token_ids)
+
+                    prompt_texts = tokenizer.batch_decode(
+                        prompt_token_ids_list, skip_special_tokens=True
+                    )
+                    completion_texts = tokenizer.batch_decode(
+                        completion_token_ids_list, skip_special_tokens=True
+                    )
+
+                    with torch.no_grad():
+                        prompt_encoding = tokenizer(
+                            prompt_texts, padding=True, return_tensors="pt"
+                        ).to(self.cfg.device)
+                        predictions = trainer.model.generate(
+                            **prompt_encoding, generation_config=generation_config
+                        )
+
+                    prediction_all_tokens = predictions["sequences"].cpu().tolist()
+                    prediction_without_prompt_tokens_list = []
+                    for prompt_token_ids, prediction_tokens in zip(
+                        prompt_token_ids_list, prediction_all_tokens
+                    ):
+                        prediction_without_prompt_tokens = prediction_tokens[
+                            len(prompt_token_ids) :
+                        ]
+                        prediction_without_prompt_tokens_list.append(
+                            prediction_without_prompt_tokens
+                        )
+
+                    predicted_texts = tokenizer.batch_decode(
+                        prediction_without_prompt_tokens_list, skip_special_tokens=True
+                    )
+
+                    eval_src.extend(prompt_texts)
+                    eval_pred.extend(predicted_texts)
+                    eval_ref.extend(completion_texts)
+
+                return eval_src, eval_pred, eval_ref
+
+            if is_main_process():
+                eval_preds = predict_with_generate()
+                trainer.log(evaluate_preds(*eval_preds))
+
+            return control
+
+    return CausalLMBenchEvalCallback
+
+
 def log_prediction_callback_factory(trainer: Trainer, tokenizer):
     class LogPredictionCallback(TrainerCallback):
         """Callback to log prediction values during each evaluation"""
@@ -388,7 +569,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
 
             # pylint: disable=duplicate-code
             generation_config = GenerationConfig(
-                max_new_tokens=self.cfg.eval_table_max_new_tokens,
+                max_new_tokens=self.cfg.eval_max_new_tokens,
                 bos_token_id=tokenizer.bos_token_id,
                 eos_token_id=tokenizer.eos_token_id,
                 pad_token_id=tokenizer.pad_token_id,
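
One detail worth calling out from the callback above: find_ranges splits a packed row back into individual samples by looking for position_ids that reset to 0 (when no position_ids are present, the whole row is treated as one sample). A toy illustration of its behavior (the input values are made up):

def find_ranges(lst):
    # same logic as in the callback: a new range starts wherever position_ids drop back to 0
    ranges = []
    start = 0
    for i in range(1, len(lst)):
        if lst[i] == 0:
            ranges.append((start, i - 1))
            start = i
    ranges.append((start, len(lst) - 1))
    return ranges

# two samples packed into one row: positions 0-3 and 0-2
print(find_ranges([0, 1, 2, 3, 0, 1, 2]))  # -> [(0, 3), (4, 6)]
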
src/axolotl/utils/config.py CHANGED
@@ -56,7 +56,13 @@ def normalize_config(cfg):
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
     cfg.eval_table_size = cfg.eval_table_size or 0
-    cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
+    cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
+    cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
+        "sacrebleu",
+        "comet",
+        "ter",
+        "chrf",
+    ]
     choose_device(cfg)
     cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
     if cfg.ddp:
@@ -550,6 +556,21 @@ def validate_config(cfg):
     if cfg.fsdp and "bnb" in cfg.optimizer:
         raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
 
+    if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
+        raise ValueError(
+            "do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
+        )
+
+    if cfg.eval_causal_lm_metrics:
+        supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
+        if not isinstance(cfg.eval_causal_lm_metrics, list):
+            raise ValueError("eval_causal_lm_metrics must be a list")
+        # only ["sacrebleu", "comet", "ter", "chrf"] supported
+        if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
+            raise ValueError(
+                f"eval_causal_lm_metrics must be one of {supported_metrics}"
+            )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
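
Putting the new options together, a hedged illustration of config fragments that would pass or fail the new validation; the keys come from the diff above, the values are examples:

# Accepted: causal LM eval enabled, packing disabled, metrics drawn from the supported set.
good_cfg = {
    "do_causal_lm_eval": True,
    "eval_sample_packing": False,               # must be False when do_causal_lm_eval is set
    "eval_max_new_tokens": 128,                 # defaulted to 128 by normalize_config if unset
    "eval_causal_lm_metrics": ["sacrebleu", "chrf"],
}

# Rejected: validate_config raises for packing + causal LM eval, and for unsupported metrics.
bad_cfg = {
    "do_causal_lm_eval": True,
    "eval_sample_packing": True,                # ValueError: eval_sample_packing must be False
    "eval_causal_lm_metrics": ["bleu"],         # ValueError: not in ["sacrebleu", "comet", "ter", "chrf"]
}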