bofenghuang committed on
Commit
81da7d2
1 Parent(s): 1e1921b

Fix `total_num_steps` (#1566)

Browse files

* Fix `total_num_steps`

* Fix total_num_steps

* lint

Files changed (1) hide show
  1. src/axolotl/utils/trainer.py +5 -15
src/axolotl/utils/trainer.py CHANGED
@@ -330,7 +330,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
330
  / cfg.sample_packing_eff_est
331
  / cfg.sequence_len
332
  // cfg.batch_size
333
- // int(os.environ.get("WORLD_SIZE", 1))
334
  )
335
  - 1
336
  )
@@ -359,18 +358,14 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
359
  train_dataset.remove_columns(["length"]),
360
  batch_sampler=sampler,
361
  )
362
- data_loader_len = len(data_loader) // cfg.batch_size
 
 
363
  actual_eff = sampler.efficiency()
364
  LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
365
  # FIXME: is there a bug here somewhere? the total num steps depends
366
  # on the agreed on value for sample_packing_eff_est
367
- total_num_steps = int(
368
- math.floor(
369
- data_loader_len
370
- * cfg.num_epochs
371
- / int(os.environ.get("WORLD_SIZE", 1))
372
- )
373
- )
374
 
375
  def calc_sample_packing_eff_est(estimates: List[float]):
376
  LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -391,12 +386,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
391
  )
392
  else:
393
  total_num_steps = int(
394
- math.ceil(
395
- len(train_dataset)
396
- * cfg.num_epochs
397
- / int(os.environ.get("WORLD_SIZE", 1))
398
- / cfg.batch_size
399
- )
400
  )
401
  LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
402
  return total_num_steps
 
330
  / cfg.sample_packing_eff_est
331
  / cfg.sequence_len
332
  // cfg.batch_size
 
333
  )
334
  - 1
335
  )
 
358
  train_dataset.remove_columns(["length"]),
359
  batch_sampler=sampler,
360
  )
361
+ data_loader_len = len(data_loader) // (
362
+ cfg.world_size * cfg.gradient_accumulation_steps
363
+ )
364
  actual_eff = sampler.efficiency()
365
  LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
366
  # FIXME: is there a bug here somewhere? the total num steps depends
367
  # on the agreed on value for sample_packing_eff_est
368
+ total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
 
 
 
 
 
 
369
 
370
  def calc_sample_packing_eff_est(estimates: List[float]):
371
  LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
 
386
  )
387
  else:
388
  total_num_steps = int(
389
+ math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
 
 
 
 
 
390
  )
391
  LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
392
  return total_num_steps