Update run_mlm_flax.py
run_mlm_flax.py CHANGED (+20 -1)
@@ -760,7 +760,7 @@ def main():
             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
 
             # compute loss, ignore padded input tokens
-            label_mask = jnp.where(labels > 0, 1.0, 0.0)
+            label_mask = jnp.where(labels != -100, 1.0, 0.0)
             loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
 
             # take average
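For context, a minimal sketch (not part of this commit) of what the new mask does: positions whose label is -100 contribute zero to the per-token cross-entropy, and the average is taken only over the remaining positions. The shapes and values below are invented purely for illustration; `onehot` is the helper the script already imports from flax.

# Sketch only (not in the commit): how the -100 mask removes ignored positions
# from the loss. Values and shapes are invented for illustration.
import jax.numpy as jnp
import optax
from flax.training.common_utils import onehot

vocab_size = 5
logits = jnp.zeros((1, 4, vocab_size))    # (batch, seq_len, vocab), uniform predictions
labels = jnp.array([[2, -100, 3, -100]])  # -100 marks positions to ignore in the loss

label_mask = jnp.where(labels != -100, 1.0, 0.0)
per_token = optax.softmax_cross_entropy(logits, onehot(labels, vocab_size)) * label_mask

# take average over contributing tokens only
loss = per_token.sum() / label_mask.sum()
print(loss)  # ~log(5), computed from the two unmasked positions only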
@@ -816,6 +816,11 @@
     train_time = 0
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
     for epoch in epochs:
+        # Debug
+        testsample = [tokenized_datasets["train"][0]]  # Just one sample
+        testbatch = data_collator(testsample, pad_to_multiple_of=16)
+        print("Labels in a sample batch:", testbatch["labels"])
+
         # ======================== Training ================================
         train_start = time.time()
         train_metrics = []
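If the point of the debug print above is to sanity-check the collator output, a hypothetical variant (using the same `tokenized_datasets` and `data_collator` objects from the script) is to count how many label positions survive the -100 filter instead of dumping the whole array:

# Hypothetical variant of the debug print above: count the positions that
# actually contribute to the loss rather than printing the full label array.
import numpy as np

testsample = [tokenized_datasets["train"][0]]             # just one sample, as above
testbatch = data_collator(testsample, pad_to_multiple_of=16)
labels = np.asarray(testbatch["labels"])
print("Labels kept for the loss:", int((labels != -100).sum()), "of", labels.size)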
@@ -837,6 +842,9 @@
             samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples, pad_to_multiple_of=16)
 
+
+
+
             # Model forward
             model_inputs = shard(model_inputs.data)
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
@@ -845,6 +853,17 @@
             cur_step = epoch * (num_train_samples // (train_batch_size * jax.process_count())) + step
 
             if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+
+
+                # Debug
+                if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                    params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
+                    total_norm = 0
+                    for p in jax.tree_util.tree_leaves(params):
+                        total_norm += np.sum(np.square(p))
+                    total_norm = np.sqrt(total_norm)
+                    print("Parameter norm:", total_norm)
+
                 # Save metrics
                 train_metric = jax_utils.unreplicate(train_metric)
                 train_time += time.time() - train_start
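The parameter-norm debug loop above pulls every weight back to the host. A sketch (not part of the commit) of an equivalent device-side computation, assuming the same replicated `state` from the training loop and the optax/flax utilities the script already imports:

# Sketch (not in the commit): the same global L2 norm without copying every
# parameter to the host. state.params is replicated across devices by pmap,
# so unreplicate first, then let optax compute sqrt(sum of squared leaves).
import jax
import optax
from flax import jax_utils

params = jax_utils.unreplicate(state.params)
total_norm = optax.global_norm(params)
print("Parameter norm:", jax.device_get(total_norm))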