Early stopping metric (#537)

* set early stopping metric to check
* tweak how load_best_model_at_end gets set for early stopping
* add validation for early stopping patience
* remove negation
* save results to metrics in callback
* move early stopping callback after the benchmark evals
* broadcast metrics so early stopping works
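
Taken together, these commits let an evaluation or benchmark metric drive early stopping. Roughly, the config fields involved look like the sketch below (shown as a Python dict; the key names come from the diffs, the values are only illustrative, not defaults):

# Illustrative cfg values wired up by this PR (names from the diffs below;
# the numbers are examples, not recommended defaults).
cfg = {
    "val_set_size": 0.05,                  # needed so evaluation actually runs
    "eval_steps": 50,
    "save_steps": 100,                     # must be a multiple of eval_steps
    "early_stopping_patience": 3,          # evals without improvement before stopping
    "metric_for_best_model": "eval_loss",  # metric EarlyStoppingCallback checks
    "greater_is_better": False,            # direction for the metric above
    "load_best_model_at_end": True,
}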
src/axolotl/utils/callbacks.py
CHANGED
@@ -25,6 +25,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.distributed import (
     barrier,
+    broadcast_dict,
     gather_scalar_from_all_ranks,
     get_world_size,
     is_distributed,
@@ -271,6 +272,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
                 lambda: len(data_loader), get_world_size()
             )
 
+            results = {}
             if is_distributed() and not is_main_process():
                 dist.gather_object(local_bench_names, dst=0)
             else:
@@ -316,4 +318,8 @@ def bench_eval_callback_factory(trainer, tokenizer):
                 )["accuracy"]
             trainer.log(results)
 
+            results = broadcast_dict(results)
+            for key, val in results.items():
+                metrics[key] = val
+
     return BenchEvalCallback
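
This callback change is what makes a benchmark score usable for early stopping: the accuracy computed on rank 0 is broadcast to every rank and copied into the `metrics` dict passed to `on_evaluate`, the same dict that transformers' `EarlyStoppingCallback` later reads its metric from. A minimal standalone sketch of that merge-into-`metrics` pattern (not the axolotl implementation; `compute_results` is a hypothetical stand-in for the benchmark run plus broadcast):

# Minimal sketch: merge extra results into `metrics` in place so that
# callbacks which run later in the same evaluation can see them.
from transformers import TrainerCallback


class InjectResultsCallback(TrainerCallback):
    """Copies externally computed results into the evaluation metrics dict."""

    def __init__(self, compute_results):
        # `compute_results` is a hypothetical callable returning e.g.
        # {"eval_bench_accuracy": 0.42}; in axolotl this role is played by the
        # benchmark evaluation plus broadcast_dict.
        self.compute_results = compute_results

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        results = self.compute_results()
        if metrics is not None:
            for key, val in results.items():
                metrics[key] = val  # mutate in place; later callbacks keep a reference
        return control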
src/axolotl/utils/config.py
CHANGED
@@ -220,6 +220,15 @@ def validate_config(cfg):
             "sample_packing not compatible with xformers_attention. Use flash_attention"
         )
 
+    if cfg.early_stopping_patience:
+        if not cfg.save_steps or not cfg.eval_steps:
+            raise ValueError(
+                "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
+            )
+        if cfg.save_steps % cfg.eval_steps != 0:
+            raise ValueError(
+                "`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
+            )
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
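
The validation encodes the constraint that tracking a "best" checkpoint (and therefore early stopping against it) only works when every save coincides with an evaluation. A small standalone restatement of the rule, with example values that are illustrative rather than axolotl defaults:

# Standalone check expressing the same rule as the new validation.
def check_early_stopping_cfg(early_stopping_patience, save_steps, eval_steps):
    if not early_stopping_patience:
        return
    if not save_steps or not eval_steps:
        raise ValueError("early_stopping_patience requires save_steps and eval_steps")
    if save_steps % eval_steps != 0:
        raise ValueError("eval_steps must evenly divide save_steps")


check_early_stopping_cfg(early_stopping_patience=3, save_steps=100, eval_steps=50)  # ok
# check_early_stopping_cfg(3, save_steps=100, eval_steps=30)  # raises ValueError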
src/axolotl/utils/distributed.py
CHANGED
@@ -2,6 +2,7 @@
 utility helpers for distributed checks
 """
 import os
+import pickle  # nosec
 from contextlib import contextmanager
 
 import torch
@@ -93,3 +94,30 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
             gathered_values.append(float(tensor.item()))
         return gathered_values
     return None
+
+
+def broadcast_dict(vals: dict):
+    if not is_distributed():
+        return vals
+
+    if is_main_process():
+        data_byte = pickle.dumps(vals)
+        data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
+        data_size = torch.IntTensor([len(data_byte)]).to("cuda")
+    else:
+        data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
+        data_size = torch.IntTensor([0]).to("cuda")
+
+    dist.broadcast(data_size, 0)
+    if not is_main_process():
+        # resize
+        data_tensor = data_tensor.new_empty([data_size.item()])
+
+    dist.broadcast(data_tensor, 0)
+
+    if not is_main_process():
+        data_list = data_tensor.cpu().tolist()
+        data_byte = bytes(data_list[: data_size.item()])
+        vals = pickle.loads(data_byte)  # nosec
+
+    return vals
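
`broadcast_dict` pickles the dict on rank 0, broadcasts its byte length first, then broadcasts the raw bytes and unpickles them on every other rank, so a metrics dict computed only on the main process ends up identical everywhere. A hypothetical usage sketch, assuming a CUDA node and a process group launched with torchrun (the script name and metric key are made up):

# Run with: torchrun --nproc_per_node=2 demo.py   (demo.py is a made-up name)
import os

import torch
import torch.distributed as dist

from axolotl.utils.distributed import broadcast_dict

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Only rank 0 has real results; every other rank starts with an empty dict,
# mirroring the `results = {}` added to the bench eval callback.
results = {"bench_accuracy": 0.42} if dist.get_rank() == 0 else {}
results = broadcast_dict(results)
print(f"rank {dist.get_rank()}: {results}")  # identical dict on every rank

dist.destroy_process_group()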
src/axolotl/utils/trainer.py
CHANGED
@@ -576,6 +576,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
         training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
         if cfg.bench_dataset:
             training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
+    if cfg.metric_for_best_model:
+        training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
+    if cfg.greater_is_better:
+        training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
 
     # DDP Config
     if cfg.ddp_timeout:
@@ -601,11 +605,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
         output_dir=cfg.output_dir,
         save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
         load_best_model_at_end=(
-            cfg.load_best_model_at_end is not False
+            (cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
             and cfg.val_set_size > 0
            and cfg.save_steps
             and cfg.save_steps % cfg.eval_steps == 0
-            and cfg.load_in_8bit is not True
         )
         or False,
         ddp_find_unused_parameters=False if cfg.ddp else None,
@@ -637,13 +640,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
     if cfg.relora_steps:
         callbacks.append(ReLoRACallback(cfg))
 
-    # TODO on_save callback to sync checkpoints to GCP/AWS in background
-    if cfg.early_stopping_patience:
-        early_stop_cb = EarlyStoppingCallback(
-            cfg.early_stopping_patience,
-        )
-        callbacks.append(early_stop_cb)
-
     if cfg.local_rank == 0 and cfg.adapter in [
         "lora",
         "qlora",
@@ -710,4 +706,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
     if cfg.do_bench_eval:
         trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
 
+    # TODO on_save callback to sync checkpoints to GCP/AWS in background
+    if cfg.early_stopping_patience:
+        early_stop_cb = EarlyStoppingCallback(
+            cfg.early_stopping_patience,
+        )
+        trainer.add_callback(early_stop_cb)
+
     return trainer
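
On the trainer side, `metric_for_best_model` and `greater_is_better` are passed through to the training arguments, and the early stopping callback is now attached with `trainer.add_callback` after the benchmark-eval callback so it runs once the merged metrics are available. The equivalent wiring in plain transformers looks roughly like the sketch below (illustrative values, not axolotl defaults):

# Rough plain-transformers equivalent of the wiring this PR sets up.
from transformers import EarlyStoppingCallback, TrainingArguments

args = TrainingArguments(
    output_dir="./out",
    evaluation_strategy="steps",
    eval_steps=50,
    save_steps=100,                          # a multiple of eval_steps
    load_best_model_at_end=True,             # required by EarlyStoppingCallback
    metric_for_best_model="bench_accuracy",  # hypothetical metric name
    greater_is_better=True,
)

early_stop_cb = EarlyStoppingCallback(early_stopping_patience=3)
# trainer = Trainer(model=..., args=args, ...)
# trainer.add_callback(bench_eval_callback)  # produces the metric first
# trainer.add_callback(early_stop_cb)        # then checks it

Callback order matters because callbacks fire in the order they were added, and transformers' EarlyStoppingCallback checks at train start that load_best_model_at_end and metric_for_best_model are set, which is why the PR folds early_stopping_patience into the load_best_model_at_end condition.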