pere committed on
Commit 5224747 · verified · 1 Parent(s): cd20b10

Update run_mlm_flax.py

Files changed (1)
  1. run_mlm_flax.py +20 -1
run_mlm_flax.py CHANGED
@@ -760,7 +760,7 @@ def main():
             logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
 
             # compute loss, ignore padded input tokens
-            label_mask = jnp.where(labels > -100, 1.0, 0.0)
+            label_mask = jnp.where(labels != -100, 1.0, 0.0)
             loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
 
             # take average
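Note on the change: the collator marks positions that should not contribute to the loss with the sentinel -100, and real token ids are non-negative, so labels > -100 and labels != -100 select the same positions here; the != form simply states the intent directly. A minimal sketch of the masking with toy shapes and values (jax.nn.one_hot stands in for the script's onehot helper):

    import jax
    import jax.numpy as jnp
    import optax

    labels = jnp.array([-100, 5, -100, 2])            # -100 marks non-masked positions
    logits = jnp.ones((4, 8))                         # (seq_len, vocab_size), dummy values

    label_mask = jnp.where(labels != -100, 1.0, 0.0)  # 1.0 only where a real target exists
    one_hot = jax.nn.one_hot(labels, 8)               # out-of-range -100 yields an all-zero row
    loss = optax.softmax_cross_entropy(logits, one_hot) * label_mask
    mean_loss = loss.sum() / label_mask.sum()         # average over real targets only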
@@ -816,6 +816,11 @@ def main():
     train_time = 0
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
     for epoch in epochs:
+        # Debug
+        testsample = [tokenized_datasets["train"][0]]  # Just one sample
+        testbatch = data_collator(testsample, pad_to_multiple_of=16)
+        print("Labels in a sample batch:", testbatch["labels"])
+
         # ======================== Training ================================
         train_start = time.time()
         train_metrics = []
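The added debug block prints the labels of one collated sample at the start of each epoch. For a correctly configured MLM collator, the printed tensor should be -100 almost everywhere, with real token ids only at the masked positions (roughly 15% by default). A small helper along those lines, where summarize_labels is a hypothetical name rather than part of the script:

    import numpy as np

    def summarize_labels(labels: np.ndarray) -> None:
        # Count how many positions actually contribute to the MLM loss.
        masked = labels != -100
        print(f"masked positions: {masked.sum()} / {labels.size} "
              f"({100.0 * masked.mean():.1f}%)")

    # e.g. summarize_labels(np.asarray(testbatch["labels"]))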
@@ -837,6 +842,9 @@ def main():
             samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples, pad_to_multiple_of=16)
 
+
+
+
             # Model forward
             model_inputs = shard(model_inputs.data)
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
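This hunk only adds blank lines, but for context on the surrounding code: shard (from flax.training.common_utils) splits the leading batch axis across local devices so p_train_step can consume the batch under pmap. A sketch, assuming the host batch size divides evenly by the device count:

    import jax
    import numpy as np
    from flax.training.common_utils import shard

    batch = {"input_ids": np.zeros((32, 128), dtype=np.int32)}  # (host_batch, seq_len)
    sharded = shard(batch)
    # Leading axis is now split per local device: (n_devices, 32 // n_devices, 128)
    print(jax.local_device_count(), sharded["input_ids"].shape)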
@@ -845,6 +853,17 @@ def main():
             cur_step = epoch * (num_train_samples // (train_batch_size * jax.process_count())) + step
 
             if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+
+
+                # Debug
+                if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                    params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
+                    total_norm = 0
+                    for p in jax.tree_util.tree_leaves(params):
+                        total_norm += np.sum(np.square(p))
+                    total_norm = np.sqrt(total_norm)
+                    print("Parameter norm:", total_norm)
+
                 # Save metrics
                 train_metric = jax_utils.unreplicate(train_metric)
                 train_time += time.time() - train_start
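The new debug block pulls device 0's copy of the replicated parameters and computes their global L2 norm with an explicit sum-of-squares loop (the nested if repeats the enclosing condition, so it always fires when reached); a norm that grows without bound or turns NaN is a quick sign of divergence. The same diagnostic can be written with optax's built-in helper, assuming state is in scope as in the script:

    import jax
    import optax

    # Unreplicate: take device 0's copy of the pmapped parameters.
    params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
    # optax.global_norm computes the L2 norm over all leaves of the pytree.
    print("Parameter norm:", optax.global_norm(params))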
 