pere committed
Commit af7221f
Parent: 073e1d8

Update run_mlm_flax.py

Files changed (1)
  1. run_mlm_flax.py +8 -7
run_mlm_flax.py CHANGED
@@ -679,11 +679,13 @@ def main():
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
 
-    # Take into account all hosts and all devices for proper global batch size scaling
-    global_device_count = jax.device_count() * jax.process_count()
-    train_batch_size = training_args.per_device_train_batch_size * global_device_count
+    # Use local_device_count for per-process batch size
+    local_device_count = jax.local_device_count()
+
+    # Each process handles per_device_train_batch_size * local_device_count
+    train_batch_size = training_args.per_device_train_batch_size * local_device_count
     per_device_eval_batch_size = training_args.per_device_eval_batch_size
-    eval_batch_size = per_device_eval_batch_size * global_device_count
+    eval_batch_size = per_device_eval_batch_size * local_device_count
 
     num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
 
@@ -816,9 +818,8 @@ def main():
     #train_samples_idx = np.random.permutation(np.arange(num_train_samples))
 
     train_samples_idx = np.random.permutation(train_samples_idx)
-    # Split the training indices across processes
-    train_samples_idx = np.array_split(train_samples_idx, jax.process_count())[jax.process_index()]
-    train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
+    # Split the training indices across processes train_samples_idx = np.array_split(train_samples_idx, jax.process_count())[jax.process_index()]
+    train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size, drop_last=True)
 
     # Gather the indexes for creating the batch and do a training step
     for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
 
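For context on the first hunk: in JAX, jax.device_count() already reports the global device total across all processes, so the removed expression jax.device_count() * jax.process_count() counts each process's devices twice on multi-host setups. The new code instead scales the batch by jax.local_device_count(), the devices the current process feeds. A minimal sketch of how the three counts relate (the batch-size value below is hypothetical, not taken from this repo):

import jax

local_devices = jax.local_device_count()  # devices this process feeds
processes = jax.process_count()           # number of host processes
global_devices = jax.device_count()       # total across ALL processes

# On typical homogeneous hosts (e.g. TPU pod slices):
#   global_devices == local_devices * processes
# so multiplying jax.device_count() by jax.process_count()
# brings the process count in twice.

per_device_train_batch_size = 32  # hypothetical value, for illustration only

global_batch_size = per_device_train_batch_size * global_devices  # all hosts
local_batch_size = per_device_train_batch_size * local_devices    # this host

# Holds on homogeneous setups, and trivially on a single host.
assert global_batch_size == local_batch_size * processes

On a single host the two conventions coincide, since local_device_count() equals device_count() and process_count() is 1.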
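The second hunk passes drop_last=True to generate_batch_splits. Note that in the new version the sharding call np.array_split(...) now shares a line with its leading comment, so as written it is commented out and each process batches over the full index permutation. The helper below is a sketch of what a generate_batch_splits accepting a drop_last flag might look like, modeled on the upstream Flax MLM example in transformers; it is an assumption, not code from this commit:

import math
import numpy as np

def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last: bool = True):
    """Split an index array into batches of batch_size.

    With drop_last=True the trailing partial batch is discarded, so every
    batch divides evenly across the local devices.
    """
    num_samples = len(samples_idx)
    if drop_last:
        samples_to_remove = num_samples % batch_size
        if samples_to_remove != 0:
            samples_idx = samples_idx[:-samples_to_remove]
        sections_split = num_samples // batch_size
        return samples_idx.reshape((sections_split, batch_size))
    sections_split = math.ceil(num_samples / batch_size)
    return np.array_split(samples_idx, sections_split)

# Example: 10 shuffled indices, batch size 4 -> two full batches, 2 indices dropped.
train_samples_idx = np.random.permutation(np.arange(10))
train_batch_idx = generate_batch_splits(train_samples_idx, 4)
assert train_batch_idx.shape == (2, 4)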