LeonardoEmili
commited on
Commit
•
5a5d474
1
Parent(s):
8430db2
Add seq2seq eval benchmark callback (#1274)
Browse files* Add CausalLMBenchEvalCallback for measuring seq2seq performance
* Fix code for pre-commit
* Fix typing and improve logging
* eval_sample_packing must be false with CausalLMBenchEvalCallback
- README.md +2 -1
- examples/llama-2/loftq.yml +1 -1
- examples/llama-2/lora.yml +1 -1
- examples/mamba/config.yml +1 -1
- examples/mistral/Mistral-7b-example/config.yml +1 -1
- examples/mistral/config.yml +1 -1
- examples/mistral/mixtral.yml +1 -1
- examples/mistral/qlora.yml +1 -1
- examples/qwen/lora.yml +1 -1
- examples/qwen/qlora.yml +1 -1
- examples/yi-34B-chat/qlora.yml +1 -1
- requirements.txt +1 -1
- src/axolotl/core/trainer_builder.py +11 -0
- src/axolotl/utils/callbacks.py +182 -1
- src/axolotl/utils/config.py +22 -1
README.md
CHANGED
@@ -784,7 +784,8 @@ save_total_limit: # Checkpoints saved at a time
|
|
784 |
max_steps:
|
785 |
|
786 |
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
787 |
-
|
|
|
788 |
|
789 |
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
790 |
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
|
|
784 |
max_steps:
|
785 |
|
786 |
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
787 |
+
eval_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
788 |
+
eval_causal_lm_metrics: # HF evaluate metrics used during evaluation. Default is ["sacrebleu", "comet", "ter", chrf]
|
789 |
|
790 |
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
791 |
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
examples/llama-2/loftq.yml
CHANGED
@@ -60,7 +60,7 @@ s2_attention:
|
|
60 |
warmup_steps: 10
|
61 |
evals_per_epoch: 4
|
62 |
eval_table_size:
|
63 |
-
|
64 |
saves_per_epoch: 1
|
65 |
debug:
|
66 |
deepspeed:
|
|
|
60 |
warmup_steps: 10
|
61 |
evals_per_epoch: 4
|
62 |
eval_table_size:
|
63 |
+
eval_max_new_tokens: 128
|
64 |
saves_per_epoch: 1
|
65 |
debug:
|
66 |
deepspeed:
|
examples/llama-2/lora.yml
CHANGED
@@ -57,7 +57,7 @@ s2_attention:
|
|
57 |
warmup_steps: 10
|
58 |
evals_per_epoch: 4
|
59 |
eval_table_size:
|
60 |
-
|
61 |
saves_per_epoch: 1
|
62 |
debug:
|
63 |
deepspeed:
|
|
|
57 |
warmup_steps: 10
|
58 |
evals_per_epoch: 4
|
59 |
eval_table_size:
|
60 |
+
eval_max_new_tokens: 128
|
61 |
saves_per_epoch: 1
|
62 |
debug:
|
63 |
deepspeed:
|
examples/mamba/config.yml
CHANGED
@@ -49,7 +49,7 @@ flash_attention:
|
|
49 |
warmup_steps: 10
|
50 |
evals_per_epoch: 4
|
51 |
eval_table_size:
|
52 |
-
|
53 |
saves_per_epoch: 1
|
54 |
debug:
|
55 |
deepspeed:
|
|
|
49 |
warmup_steps: 10
|
50 |
evals_per_epoch: 4
|
51 |
eval_table_size:
|
52 |
+
eval_max_new_tokens: 128
|
53 |
saves_per_epoch: 1
|
54 |
debug:
|
55 |
deepspeed:
|
examples/mistral/Mistral-7b-example/config.yml
CHANGED
@@ -61,7 +61,7 @@ flash_attention: true
|
|
61 |
warmup_steps: 10
|
62 |
evals_per_epoch: 4
|
63 |
eval_table_size:
|
64 |
-
|
65 |
saves_per_epoch: 1
|
66 |
debug:
|
67 |
#default deepspeed, can use more aggresive if needed like zero2, zero3
|
|
|
61 |
warmup_steps: 10
|
62 |
evals_per_epoch: 4
|
63 |
eval_table_size:
|
64 |
+
eval_max_new_tokens: 128
|
65 |
saves_per_epoch: 1
|
66 |
debug:
|
67 |
#default deepspeed, can use more aggresive if needed like zero2, zero3
|
examples/mistral/config.yml
CHANGED
@@ -49,7 +49,7 @@ flash_attention: true
|
|
49 |
warmup_steps: 10
|
50 |
evals_per_epoch: 4
|
51 |
eval_table_size:
|
52 |
-
|
53 |
saves_per_epoch: 1
|
54 |
debug:
|
55 |
deepspeed:
|
|
|
49 |
warmup_steps: 10
|
50 |
evals_per_epoch: 4
|
51 |
eval_table_size:
|
52 |
+
eval_max_new_tokens: 128
|
53 |
saves_per_epoch: 1
|
54 |
debug:
|
55 |
deepspeed:
|
examples/mistral/mixtral.yml
CHANGED
@@ -81,7 +81,7 @@ loss_watchdog_patience: 3
|
|
81 |
warmup_steps: 10
|
82 |
evals_per_epoch: 4
|
83 |
eval_table_size:
|
84 |
-
|
85 |
saves_per_epoch: 1
|
86 |
debug:
|
87 |
deepspeed: deepspeed_configs/zero2.json
|
|
|
81 |
warmup_steps: 10
|
82 |
evals_per_epoch: 4
|
83 |
eval_table_size:
|
84 |
+
eval_max_new_tokens: 128
|
85 |
saves_per_epoch: 1
|
86 |
debug:
|
87 |
deepspeed: deepspeed_configs/zero2.json
|
examples/mistral/qlora.yml
CHANGED
@@ -68,7 +68,7 @@ loss_watchdog_patience: 3
|
|
68 |
warmup_steps: 10
|
69 |
evals_per_epoch: 4
|
70 |
eval_table_size:
|
71 |
-
|
72 |
saves_per_epoch: 1
|
73 |
debug:
|
74 |
deepspeed:
|
|
|
68 |
warmup_steps: 10
|
69 |
evals_per_epoch: 4
|
70 |
eval_table_size:
|
71 |
+
eval_max_new_tokens: 128
|
72 |
saves_per_epoch: 1
|
73 |
debug:
|
74 |
deepspeed:
|
examples/qwen/lora.yml
CHANGED
@@ -58,7 +58,7 @@ flash_attention:
|
|
58 |
warmup_steps: 10
|
59 |
evals_per_epoch: 4
|
60 |
eval_table_size:
|
61 |
-
|
62 |
saves_per_epoch: 1
|
63 |
debug:
|
64 |
deepspeed:
|
|
|
58 |
warmup_steps: 10
|
59 |
evals_per_epoch: 4
|
60 |
eval_table_size:
|
61 |
+
eval_max_new_tokens: 128
|
62 |
saves_per_epoch: 1
|
63 |
debug:
|
64 |
deepspeed:
|
examples/qwen/qlora.yml
CHANGED
@@ -58,7 +58,7 @@ flash_attention:
|
|
58 |
warmup_steps: 10
|
59 |
evals_per_epoch: 4
|
60 |
eval_table_size:
|
61 |
-
|
62 |
saves_per_epoch: 1
|
63 |
debug:
|
64 |
deepspeed:
|
|
|
58 |
warmup_steps: 10
|
59 |
evals_per_epoch: 4
|
60 |
eval_table_size:
|
61 |
+
eval_max_new_tokens: 128
|
62 |
saves_per_epoch: 1
|
63 |
debug:
|
64 |
deepspeed:
|
examples/yi-34B-chat/qlora.yml
CHANGED
@@ -29,7 +29,7 @@ num_epochs: 1
|
|
29 |
val_set_size: 0.1
|
30 |
evals_per_epoch: 5
|
31 |
eval_table_size:
|
32 |
-
|
33 |
eval_sample_packing: false
|
34 |
eval_batch_size: 1
|
35 |
|
|
|
29 |
val_set_size: 0.1
|
30 |
evals_per_epoch: 5
|
31 |
eval_table_size:
|
32 |
+
eval_max_new_tokens: 128
|
33 |
eval_sample_packing: false
|
34 |
eval_batch_size: 1
|
35 |
|
requirements.txt
CHANGED
@@ -23,7 +23,7 @@ numba
|
|
23 |
numpy>=1.24.4
|
24 |
mlflow
|
25 |
# qlora things
|
26 |
-
evaluate==0.4.
|
27 |
scipy
|
28 |
scikit-learn==1.2.2
|
29 |
pynvml
|
|
|
23 |
numpy>=1.24.4
|
24 |
mlflow
|
25 |
# qlora things
|
26 |
+
evaluate==0.4.1
|
27 |
scipy
|
28 |
scikit-learn==1.2.2
|
29 |
pynvml
|
src/axolotl/core/trainer_builder.py
CHANGED
@@ -38,6 +38,7 @@ from axolotl.utils.callbacks import (
|
|
38 |
SaveAxolotlConfigtoWandBCallback,
|
39 |
SaveBetterTransformerModelCallback,
|
40 |
bench_eval_callback_factory,
|
|
|
41 |
log_prediction_callback_factory,
|
42 |
)
|
43 |
from axolotl.utils.collators import (
|
@@ -148,6 +149,9 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|
148 |
do_bench_eval: Optional[bool] = field(
|
149 |
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
150 |
)
|
|
|
|
|
|
|
151 |
max_bench_samples: Optional[int] = field(
|
152 |
default=None,
|
153 |
metadata={
|
@@ -664,6 +668,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
664 |
|
665 |
if self.cfg.do_bench_eval:
|
666 |
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
|
|
|
|
|
|
|
|
|
|
667 |
|
668 |
if self.cfg.early_stopping_patience:
|
669 |
early_stop_cb = EarlyStoppingCallback(
|
@@ -812,6 +821,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
812 |
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
813 |
if self.cfg.bench_dataset:
|
814 |
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
|
|
|
|
|
815 |
if self.cfg.metric_for_best_model:
|
816 |
training_arguments_kwargs[
|
817 |
"metric_for_best_model"
|
|
|
38 |
SaveAxolotlConfigtoWandBCallback,
|
39 |
SaveBetterTransformerModelCallback,
|
40 |
bench_eval_callback_factory,
|
41 |
+
causal_lm_bench_eval_callback_factory,
|
42 |
log_prediction_callback_factory,
|
43 |
)
|
44 |
from axolotl.utils.collators import (
|
|
|
149 |
do_bench_eval: Optional[bool] = field(
|
150 |
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
151 |
)
|
152 |
+
do_causal_lm_eval: Optional[bool] = field(
|
153 |
+
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
154 |
+
)
|
155 |
max_bench_samples: Optional[int] = field(
|
156 |
default=None,
|
157 |
metadata={
|
|
|
668 |
|
669 |
if self.cfg.do_bench_eval:
|
670 |
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
671 |
+
if self.cfg.do_causal_lm_eval:
|
672 |
+
CausalLMBenchEvalCallback = causal_lm_bench_eval_callback_factory(
|
673 |
+
trainer, self.tokenizer
|
674 |
+
)
|
675 |
+
callbacks.append(CausalLMBenchEvalCallback(self.cfg))
|
676 |
|
677 |
if self.cfg.early_stopping_patience:
|
678 |
early_stop_cb = EarlyStoppingCallback(
|
|
|
821 |
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
822 |
if self.cfg.bench_dataset:
|
823 |
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
|
824 |
+
if self.cfg.do_causal_lm_eval:
|
825 |
+
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
|
826 |
if self.cfg.metric_for_best_model:
|
827 |
training_arguments_kwargs[
|
828 |
"metric_for_best_model"
|
src/axolotl/utils/callbacks.py
CHANGED
@@ -361,6 +361,187 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|
361 |
return BenchEvalCallback
|
362 |
|
363 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
365 |
class LogPredictionCallback(TrainerCallback):
|
366 |
"""Callback to log prediction values during each evaluation"""
|
@@ -388,7 +569,7 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
|
388 |
|
389 |
# pylint: disable=duplicate-code
|
390 |
generation_config = GenerationConfig(
|
391 |
-
max_new_tokens=self.cfg.
|
392 |
bos_token_id=tokenizer.bos_token_id,
|
393 |
eos_token_id=tokenizer.eos_token_id,
|
394 |
pad_token_id=tokenizer.pad_token_id,
|
|
|
361 |
return BenchEvalCallback
|
362 |
|
363 |
|
364 |
+
def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
365 |
+
class CausalLMBenchEvalCallback(TrainerCallback):
|
366 |
+
"""Callback to log prediction values during each evaluation"""
|
367 |
+
|
368 |
+
def __init__(self, cfg):
|
369 |
+
self.cfg = cfg
|
370 |
+
self.logged = False
|
371 |
+
self.metrics = self.__maybe_load_metrics()
|
372 |
+
|
373 |
+
def __maybe_load_metrics(self):
|
374 |
+
metrics = {}
|
375 |
+
for metric in self.cfg.eval_causal_lm_metrics:
|
376 |
+
try:
|
377 |
+
metrics[metric] = evaluate.load(metric)
|
378 |
+
except Exception as exc: # pylint: disable=broad-exception-caught
|
379 |
+
LOG.warning(f"{metric}: {exc.args}")
|
380 |
+
return metrics
|
381 |
+
|
382 |
+
def on_evaluate(
|
383 |
+
self,
|
384 |
+
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
|
385 |
+
state: TrainerState,
|
386 |
+
control: TrainerControl,
|
387 |
+
train_dataloader, # pylint: disable=unused-argument
|
388 |
+
eval_dataloader,
|
389 |
+
**kwargs, # pylint: disable=unused-argument
|
390 |
+
):
|
391 |
+
trainer.model.eval()
|
392 |
+
device = torch.device(self.cfg.device)
|
393 |
+
|
394 |
+
# pylint: disable=duplicate-code
|
395 |
+
generation_config = GenerationConfig(
|
396 |
+
max_new_tokens=self.cfg.eval_max_new_tokens,
|
397 |
+
bos_token_id=tokenizer.bos_token_id,
|
398 |
+
eos_token_id=tokenizer.eos_token_id,
|
399 |
+
pad_token_id=tokenizer.pad_token_id,
|
400 |
+
do_sample=False,
|
401 |
+
use_cache=True,
|
402 |
+
return_dict_in_generate=True,
|
403 |
+
output_attentions=False,
|
404 |
+
output_hidden_states=False,
|
405 |
+
output_scores=False,
|
406 |
+
)
|
407 |
+
|
408 |
+
def find_ranges(lst):
|
409 |
+
ranges = []
|
410 |
+
start = 0
|
411 |
+
for i in range(1, len(lst)):
|
412 |
+
if lst[i] == 0:
|
413 |
+
ranges.append((start, i - 1))
|
414 |
+
start = i
|
415 |
+
end = len(lst) - 1
|
416 |
+
ranges.append((start, end))
|
417 |
+
return ranges
|
418 |
+
|
419 |
+
def compute(metric: evaluate.Metric, **kwargs):
|
420 |
+
# safely compute a metric and return the score if the format is correct
|
421 |
+
metric_score = None
|
422 |
+
try:
|
423 |
+
metric_score = metric.compute(**kwargs)
|
424 |
+
return (
|
425 |
+
metric_score["score"]
|
426 |
+
if "score" in metric_score
|
427 |
+
else metric_score["mean_score"]
|
428 |
+
)
|
429 |
+
except Exception: # pylint: disable=broad-exception-caught
|
430 |
+
LOG.debug(
|
431 |
+
f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
|
432 |
+
)
|
433 |
+
return metric_score
|
434 |
+
|
435 |
+
def evaluate_preds(sources, predictions, references):
|
436 |
+
scores = {}
|
437 |
+
|
438 |
+
for metric_name, metric in self.metrics.items():
|
439 |
+
score = compute(
|
440 |
+
metric,
|
441 |
+
references=references,
|
442 |
+
predictions=predictions,
|
443 |
+
sources=sources,
|
444 |
+
)
|
445 |
+
score = score or compute(
|
446 |
+
metric,
|
447 |
+
references=[[r] for r in references],
|
448 |
+
predictions=predictions,
|
449 |
+
)
|
450 |
+
scores[metric_name] = score
|
451 |
+
return scores
|
452 |
+
|
453 |
+
def predict_with_generate():
|
454 |
+
eval_src, eval_pred, eval_ref = [], [], []
|
455 |
+
|
456 |
+
for batch in tqdm(eval_dataloader):
|
457 |
+
batch_labels = batch["labels"].to(device)
|
458 |
+
batch_input_ids = batch["input_ids"].to(device)
|
459 |
+
|
460 |
+
if "position_ids" in batch:
|
461 |
+
batch_pos_ids = batch["position_ids"].tolist()
|
462 |
+
else:
|
463 |
+
batch_pos_ids = [None] * len(batch["input_ids"])
|
464 |
+
|
465 |
+
prompt_token_ids_list = []
|
466 |
+
completion_token_ids_list = []
|
467 |
+
|
468 |
+
for input_ids_all, labels_all, pos_ids in zip(
|
469 |
+
batch_input_ids,
|
470 |
+
batch_labels,
|
471 |
+
batch_pos_ids,
|
472 |
+
):
|
473 |
+
if pos_ids is None:
|
474 |
+
pos_ranges = [(0, len(input_ids_all) - 1)]
|
475 |
+
else:
|
476 |
+
pos_ranges = find_ranges(pos_ids)
|
477 |
+
|
478 |
+
for pos_range in pos_ranges:
|
479 |
+
start, end = pos_range
|
480 |
+
if start == end:
|
481 |
+
continue
|
482 |
+
|
483 |
+
input_ids = input_ids_all[start : end + 1]
|
484 |
+
labels = labels_all[start : end + 1]
|
485 |
+
|
486 |
+
tokens_without_loss = labels == IGNORE_INDEX
|
487 |
+
tokens_with_loss = labels != IGNORE_INDEX
|
488 |
+
tokens_exclude_padding = input_ids != tokenizer.pad_token_id
|
489 |
+
prompt_token_includes = (
|
490 |
+
tokens_without_loss & tokens_exclude_padding
|
491 |
+
)
|
492 |
+
|
493 |
+
prompt_token_ids = input_ids[prompt_token_includes]
|
494 |
+
prompt_token_ids_list.append(prompt_token_ids)
|
495 |
+
|
496 |
+
completion_token_ids = input_ids[tokens_with_loss]
|
497 |
+
completion_token_ids_list.append(completion_token_ids)
|
498 |
+
|
499 |
+
prompt_texts = tokenizer.batch_decode(
|
500 |
+
prompt_token_ids_list, skip_special_tokens=True
|
501 |
+
)
|
502 |
+
completion_texts = tokenizer.batch_decode(
|
503 |
+
completion_token_ids_list, skip_special_tokens=True
|
504 |
+
)
|
505 |
+
|
506 |
+
with torch.no_grad():
|
507 |
+
prompt_encoding = tokenizer(
|
508 |
+
prompt_texts, padding=True, return_tensors="pt"
|
509 |
+
).to(self.cfg.device)
|
510 |
+
predictions = trainer.model.generate(
|
511 |
+
**prompt_encoding, generation_config=generation_config
|
512 |
+
)
|
513 |
+
|
514 |
+
prediction_all_tokens = predictions["sequences"].cpu().tolist()
|
515 |
+
prediction_without_prompt_tokens_list = []
|
516 |
+
for prompt_token_ids, prediction_tokens in zip(
|
517 |
+
prompt_token_ids_list, prediction_all_tokens
|
518 |
+
):
|
519 |
+
prediction_without_prompt_tokens = prediction_tokens[
|
520 |
+
len(prompt_token_ids) :
|
521 |
+
]
|
522 |
+
prediction_without_prompt_tokens_list.append(
|
523 |
+
prediction_without_prompt_tokens
|
524 |
+
)
|
525 |
+
|
526 |
+
predicted_texts = tokenizer.batch_decode(
|
527 |
+
prediction_without_prompt_tokens_list, skip_special_tokens=True
|
528 |
+
)
|
529 |
+
|
530 |
+
eval_src.extend(prompt_texts)
|
531 |
+
eval_pred.extend(predicted_texts)
|
532 |
+
eval_ref.extend(completion_texts)
|
533 |
+
|
534 |
+
return eval_src, eval_pred, eval_ref
|
535 |
+
|
536 |
+
if is_main_process():
|
537 |
+
eval_preds = predict_with_generate()
|
538 |
+
trainer.log(evaluate_preds(*eval_preds))
|
539 |
+
|
540 |
+
return control
|
541 |
+
|
542 |
+
return CausalLMBenchEvalCallback
|
543 |
+
|
544 |
+
|
545 |
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
546 |
class LogPredictionCallback(TrainerCallback):
|
547 |
"""Callback to log prediction values during each evaluation"""
|
|
|
569 |
|
570 |
# pylint: disable=duplicate-code
|
571 |
generation_config = GenerationConfig(
|
572 |
+
max_new_tokens=self.cfg.eval_max_new_tokens,
|
573 |
bos_token_id=tokenizer.bos_token_id,
|
574 |
eos_token_id=tokenizer.eos_token_id,
|
575 |
pad_token_id=tokenizer.pad_token_id,
|
src/axolotl/utils/config.py
CHANGED
@@ -56,7 +56,13 @@ def normalize_config(cfg):
|
|
56 |
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
57 |
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
58 |
cfg.eval_table_size = cfg.eval_table_size or 0
|
59 |
-
cfg.
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
choose_device(cfg)
|
61 |
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
62 |
if cfg.ddp:
|
@@ -550,6 +556,21 @@ def validate_config(cfg):
|
|
550 |
if cfg.fsdp and "bnb" in cfg.optimizer:
|
551 |
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
|
552 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
553 |
# TODO
|
554 |
# MPT 7b
|
555 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
|
56 |
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
57 |
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
58 |
cfg.eval_table_size = cfg.eval_table_size or 0
|
59 |
+
cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
|
60 |
+
cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
|
61 |
+
"sacrebleu",
|
62 |
+
"comet",
|
63 |
+
"ter",
|
64 |
+
"chrf",
|
65 |
+
]
|
66 |
choose_device(cfg)
|
67 |
cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
|
68 |
if cfg.ddp:
|
|
|
556 |
if cfg.fsdp and "bnb" in cfg.optimizer:
|
557 |
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
|
558 |
|
559 |
+
if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
|
560 |
+
raise ValueError(
|
561 |
+
"do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
|
562 |
+
)
|
563 |
+
|
564 |
+
if cfg.eval_causal_lm_metrics:
|
565 |
+
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
|
566 |
+
if not isinstance(cfg.eval_causal_lm_metrics, list):
|
567 |
+
raise ValueError("eval_causal_lm_metrics must be a list")
|
568 |
+
# only ["sacrebleu", "comet", "ter", "chrf"] supported
|
569 |
+
if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
|
570 |
+
raise ValueError(
|
571 |
+
f"eval_causal_lm_metrics must be one of {supported_metrics}"
|
572 |
+
)
|
573 |
+
|
574 |
# TODO
|
575 |
# MPT 7b
|
576 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|