refactor to set eval_batch_size earlier if unset, so we can warn if mismatched (#662)
Browse files- README.md +1 -1
- src/axolotl/utils/config.py +7 -0
- src/axolotl/utils/trainer.py +1 -3
README.md
CHANGED
@@ -571,7 +571,7 @@ torch_compile_backend: # Optional[str]
|
|
571 |
# training hyperparameters
|
572 |
gradient_accumulation_steps: 1
|
573 |
micro_batch_size: 2
|
574 |
-
eval_batch_size:
|
575 |
num_epochs: 3
|
576 |
warmup_steps: 100
|
577 |
learning_rate: 0.00003
|
|
|
571 |
# training hyperparameters
|
572 |
gradient_accumulation_steps: 1
|
573 |
micro_batch_size: 2
|
574 |
+
eval_batch_size:
|
575 |
num_epochs: 3
|
576 |
warmup_steps: 100
|
577 |
learning_rate: 0.00003
|
src/axolotl/utils/config.py
CHANGED
@@ -49,6 +49,8 @@ def normalize_config(cfg):
|
|
49 |
cfg.batch_size = (
|
50 |
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
51 |
)
|
|
|
|
|
52 |
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
53 |
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
54 |
cfg.eval_table_size = cfg.eval_table_size or 0
|
@@ -157,6 +159,11 @@ def validate_config(cfg):
|
|
157 |
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
158 |
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
159 |
)
|
|
|
|
|
|
|
|
|
|
|
160 |
if cfg.load_4bit:
|
161 |
raise ValueError("cfg.load_4bit parameter has been deprecated")
|
162 |
|
|
|
49 |
cfg.batch_size = (
|
50 |
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
51 |
)
|
52 |
+
if cfg.eval_batch_size is None:
|
53 |
+
cfg.eval_batch_size = cfg.micro_batch_size
|
54 |
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
55 |
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
56 |
cfg.eval_table_size = cfg.eval_table_size or 0
|
|
|
159 |
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
160 |
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
161 |
)
|
162 |
+
if cfg.eval_batch_size != cfg.micro_batch_size:
|
163 |
+
LOG.warning(
|
164 |
+
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
165 |
+
)
|
166 |
+
|
167 |
if cfg.load_4bit:
|
168 |
raise ValueError("cfg.load_4bit parameter has been deprecated")
|
169 |
|
src/axolotl/utils/trainer.py
CHANGED
@@ -668,9 +668,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
668 |
max_steps=total_num_steps if cfg.max_steps else -1,
|
669 |
max_seq_length=cfg.sequence_len,
|
670 |
per_device_train_batch_size=cfg.micro_batch_size,
|
671 |
-
per_device_eval_batch_size=cfg.eval_batch_size
|
672 |
-
if cfg.eval_batch_size is not None
|
673 |
-
else cfg.micro_batch_size,
|
674 |
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
675 |
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
676 |
num_train_epochs=cfg.num_epochs,
|
|
|
668 |
max_steps=total_num_steps if cfg.max_steps else -1,
|
669 |
max_seq_length=cfg.sequence_len,
|
670 |
per_device_train_batch_size=cfg.micro_batch_size,
|
671 |
+
per_device_eval_batch_size=cfg.eval_batch_size,
|
|
|
|
|
672 |
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
673 |
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
674 |
num_train_epochs=cfg.num_epochs,
|