pere committed on
Commit
c102460
1 Parent(s): af7221f

Update run_mlm_flax.py

Browse files
Files changed (1) hide show
  1. run_mlm_flax.py +4 -5
run_mlm_flax.py CHANGED
@@ -687,8 +687,8 @@ def main():
687
  per_device_eval_batch_size = training_args.per_device_eval_batch_size
688
  eval_batch_size = per_device_eval_batch_size * local_device_count
689
 
690
- num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
691
-
692
  # Create learning rate schedule
693
  warmup_fn = optax.linear_schedule(
694
  init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
@@ -814,9 +814,8 @@ def main():
814
 
815
  # Generate an epoch by shuffling sampling indices from the train dataset
816
  num_train_samples = len(tokenized_datasets["train"])
817
- # Avoid using jax.numpy here in case of TPU training
818
- #train_samples_idx = np.random.permutation(np.arange(num_train_samples))
819
-
820
  train_samples_idx = np.random.permutation(train_samples_idx)
821
  # Split the training indices across processes
  train_samples_idx = np.array_split(train_samples_idx, jax.process_count())[jax.process_index()]
822
  train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size, drop_last=True)
 
687
  per_device_eval_batch_size = training_args.per_device_eval_batch_size
688
  eval_batch_size = per_device_eval_batch_size * local_device_count
689
 
690
+ num_train_steps = (len(tokenized_datasets["train"]) // (train_batch_size * jax.process_count())) * num_epochs
691
+
692
  # Create learning rate schedule
693
  warmup_fn = optax.linear_schedule(
694
  init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
 
814
 
815
  # Generate an epoch by shuffling sampling indices from the train dataset
816
  num_train_samples = len(tokenized_datasets["train"])
817
+
818
+ train_samples_idx = np.arange(num_train_samples)
 
819
  train_samples_idx = np.random.permutation(train_samples_idx)
820
  # Split the training indices across processes
  train_samples_idx = np.array_split(train_samples_idx, jax.process_count())[jax.process_index()]
821
  train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size, drop_last=True)