Update run_mlm_flax.py
run_mlm_flax.py (+13 -1)
@@ -687,7 +687,18 @@ def main():
     per_device_eval_batch_size = training_args.per_device_eval_batch_size
     eval_batch_size = per_device_eval_batch_size * local_device_count
 
+    # Calculate Global Batch Sizes
+    global_train_batch_size = train_batch_size * jax.process_count()
+    global_eval_batch_size = eval_batch_size * jax.process_count()
+
+    # Log Batch Sizes
+    logger.info(f"Per-process train batch size: {train_batch_size}")
+    logger.info(f"Global train batch size: {global_train_batch_size}")
+    logger.info(f"Per-process eval batch size: {per_device_eval_batch_size}")
+    logger.info(f"Global eval batch size: {global_eval_batch_size}")
+
     num_train_steps = (len(tokenized_datasets["train"]) // (train_batch_size * jax.process_count())) * num_epochs
+    logger.info(f"Number of training steps: {num_train_steps}")
 
     # Create learning rate schedule
     warmup_fn = optax.linear_schedule(
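Note on this hunk: in this script `train_batch_size` is the per-process batch (per-device batch * local device count, mirroring how `eval_batch_size` is computed two lines above), so multiplying by `jax.process_count()` gives the global batch consumed per optimizer step in a multi-host run. A minimal sketch of the arithmetic, with an illustrative topology (the numbers below are made up, not from this PR):

    # Hypothetical topology: 4 processes (hosts), 8 local devices each.
    process_count = 4
    local_device_count = 8
    per_device_train_batch_size = 16

    # Per-process batch: each local device handles one per-device batch per step.
    train_batch_size = per_device_train_batch_size * local_device_count  # 128

    # Global batch: all processes step in lockstep, so their batches add up.
    global_train_batch_size = train_batch_size * process_count  # 512

The same scaling applies to eval (note the "Per-process eval batch size" log line prints `per_device_eval_batch_size`; the per-process figure is `eval_batch_size`).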
@@ -817,7 +828,8 @@ def main():
 
         train_samples_idx = np.arange(num_train_samples)
         train_samples_idx = np.random.permutation(train_samples_idx)
-        # Split the training indices across processes
+        # Split the training indices across processes
+        train_samples_idx = np.array_split(train_samples_idx, jax.process_count())[jax.process_index()]
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size, drop_last=True)
 
         # Gather the indexes for creating the batch and do a training step
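The added `np.array_split(...)[jax.process_index()]` line gives each process a disjoint, near-equal slice of the shuffled epoch indices, so across hosts every sample is visited at most once per epoch. This only holds if every process computes the same permutation, i.e. numpy's RNG is seeded identically on all hosts. A self-contained sketch of the sharding behavior (toy sizes and a hypothetical 4-process setup; no JAX runtime needed):

    import numpy as np

    process_count = 4       # stand-in for jax.process_count()
    num_train_samples = 10  # deliberately not divisible by process_count

    rng = np.random.default_rng(0)  # identical seed on every host
    train_samples_idx = rng.permutation(np.arange(num_train_samples))

    # np.array_split (unlike np.split) tolerates uneven division:
    # here the shards have sizes 3, 3, 2, 2.
    shards = np.array_split(train_samples_idx, process_count)
    for process_index, shard in enumerate(shards):
        print(process_index, shard)  # each process keeps shards[jax.process_index()]

Because shard sizes can differ by one, `drop_last=True` in `generate_batch_splits` drops any trailing partial batch on each process rather than padding it.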