log supervised token count (#448)
Browse files
- src/axolotl/utils/trainer.py +10 -0
src/axolotl/utils/trainer.py
CHANGED
@@ -401,6 +401,16 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
|
401 |
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
|
402 |
cfg.total_num_tokens = total_num_tokens
|
403 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
if cfg.sample_packing_eff_est:
|
405 |
total_num_steps = (
|
406 |
# match count to len est in dataloader
|
|
|
401 |
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
|
402 |
cfg.total_num_tokens = total_num_tokens
|
403 |
|
404 |
+
if not cfg.total_supervised_tokens:
|
405 |
+
total_supervised_tokens = (
|
406 |
+
train_dataset.data.column("labels")
|
407 |
+
.to_pandas()
|
408 |
+
.apply(lambda x: np.sum(np.array(x) != -100))
|
409 |
+
.sum()
|
410 |
+
)
|
411 |
+
LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`")
|
412 |
+
cfg.total_supervised_tokens = total_supervised_tokens
|
413 |
+
|
414 |
if cfg.sample_packing_eff_est:
|
415 |
total_num_steps = (
|
416 |
# match count to len est in dataloader
|