Update run_mlm_flax.py
Browse files- run_mlm_flax.py +4 -5
run_mlm_flax.py
CHANGED
@@ -687,8 +687,8 @@ def main():
|
|
687 |
per_device_eval_batch_size = training_args.per_device_eval_batch_size
|
688 |
eval_batch_size = per_device_eval_batch_size * local_device_count
|
689 |
|
690 |
-
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
|
691 |
-
|
692 |
# Create learning rate schedule
|
693 |
warmup_fn = optax.linear_schedule(
|
694 |
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
@@ -814,9 +814,8 @@ def main():
|
|
814 |
|
815 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
816 |
num_train_samples = len(tokenized_datasets["train"])
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
train_samples_idx = np.random.permutation(train_samples_idx)
|
821 |
# Split the training indices across processes
train_samples_idx = np.array_split(train_samples_idx, jax.process_count())[jax.process_index()]
|
822 |
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size, drop_last=True)
|
|
|
687 |
per_device_eval_batch_size = training_args.per_device_eval_batch_size
|
688 |
eval_batch_size = per_device_eval_batch_size * local_device_count
|
689 |
|
690 |
+
num_train_steps = (len(tokenized_datasets["train"]) // (train_batch_size * jax.process_count())) * num_epochs
|
691 |
+
|
692 |
# Create learning rate schedule
|
693 |
warmup_fn = optax.linear_schedule(
|
694 |
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
|
|
814 |
|
815 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
816 |
num_train_samples = len(tokenized_datasets["train"])
|
817 |
+
|
818 |
+
train_samples_idx = np.arange(num_train_samples)
|
|
|
819 |
train_samples_idx = np.random.permutation(train_samples_idx)
|
820 |
# Split the training indices across processes
train_samples_idx = np.array_split(train_samples_idx, jax.process_count())[jax.process_index()]
|
821 |
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size, drop_last=True)
|