Update run_mlm_flax.py
run_mlm_flax.py (+8 -7)
@@ -679,11 +679,13 @@ def main():
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)

-    # Batch size across all devices
-    train_batch_size = training_args.per_device_train_batch_size * jax.device_count()
-
+    # Use local_device_count for per-process batch size
+    local_device_count = jax.local_device_count()
+
+    # Each process handles per_device_train_batch_size * local_device_count
+    train_batch_size = training_args.per_device_train_batch_size * local_device_count
     per_device_eval_batch_size = training_args.per_device_eval_batch_size
-    eval_batch_size = per_device_eval_batch_size * jax.device_count()
+    eval_batch_size = per_device_eval_batch_size * local_device_count

     num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs

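The first hunk swaps the global device count for the per-process one when sizing batches. A minimal sketch of the arithmetic this relies on, runnable on a single host; the variable names and batch size are illustrative, not from the script:

```python
import jax

# Devices visible across every process (host) in the job.
global_device_count = jax.device_count()
# Devices attached to this process only; on a multi-host pod this is
# smaller than the global count (e.g. 8 of 32 on a v3-32 pod).
local_device_count = jax.local_device_count()

per_device_train_batch_size = 16  # made-up value for illustration

# Each process assembles data only for its own devices, so the
# host-side batch it builds is sized by the local count ...
per_process_batch_size = per_device_train_batch_size * local_device_count

# ... while the effective global batch is spread over all devices.
# On the usual homogeneous setup, global = local * process_count.
global_batch_size = per_device_train_batch_size * global_device_count
assert global_batch_size == per_process_batch_size * jax.process_count()

print(f"{jax.process_count()} process(es): per-process batch "
      f"{per_process_batch_size}, global batch {global_batch_size}")
```

Sizing the host-side batch with `jax.device_count()` on a multi-host pod would make each process build a batch several times larger than its local devices can consume; the sharding in the second hunk then lets each process walk its own slice of the data at this per-process batch size.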
@@ -816,9 +818,8 @@ def main():
         #train_samples_idx = np.random.permutation(np.arange(num_train_samples))

         train_samples_idx = np.random.permutation(train_samples_idx)
-        # Split the training indices across processes
-
-        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
+        train_samples_idx = np.array_split(train_samples_idx, jax.process_count())[jax.process_index()]  # Split the training indices across processes
+        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size, drop_last=True)

         # Gather the indexes for creating the batch and do a training step
         for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
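The second hunk gives each process a disjoint slice of the shuffled indices. A small NumPy-only sketch of that idiom; `process_count` and the sample count are made up here, where the script uses `jax.process_count()` and `jax.process_index()`:

```python
import numpy as np

process_count = 4        # stand-in for jax.process_count()
num_train_samples = 10   # made-up dataset size

train_samples_idx = np.random.permutation(np.arange(num_train_samples))

# np.array_split tolerates sizes that don't divide evenly: the shards
# are disjoint, cover every index once, and differ in length by at most 1.
shards = np.array_split(train_samples_idx, process_count)
for process_index, shard in enumerate(shards):
    # A given process would keep shards[jax.process_index()] only.
    print(f"process {process_index} -> {shard}")
```

This stays consistent across hosts only if every process draws the same permutation (i.e. seeds its RNG identically); otherwise the per-process shards would overlap.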
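The call site now passes `drop_last=True` to `generate_batch_splits`. The helper's body is outside this diff, so the following is an assumed sketch of a version supporting that flag, not the script's actual implementation:

```python
import math

import numpy as np

def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last: bool = True):
    """Split shuffled sample indices into batches of batch_size."""
    num_samples = len(samples_idx)
    if drop_last:
        # Drop the remainder so the indices reshape into full batches.
        num_batches = num_samples // batch_size
        samples_idx = samples_idx[: num_batches * batch_size]
        return samples_idx.reshape((num_batches, batch_size))
    # Otherwise keep a smaller final batch.
    num_batches = math.ceil(num_samples / batch_size)
    return np.array_split(samples_idx, num_batches)

# 10 shuffled indices, batch size 4 -> two full batches; 2 indices dropped.
print(generate_batch_splits(np.random.permutation(10), 4))
```

Dropping the ragged final batch keeps every training step's arrays the same shape, which fixed-shape, pmap-style multi-device train steps require.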