pere commited on
Commit
e1c2a47
1 Parent(s): 79c53a1

Update run_mlm_flax.py

Browse files
Files changed (1) hide show
  1. run_mlm_flax.py +1 -1
run_mlm_flax.py CHANGED
@@ -760,7 +760,7 @@ def main():
760
  logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
761
 
762
  # compute loss, ignore padded input tokens
763
- label_mask = jnp.where(labels > -100, 1.0, 0.0)
764
  loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
765
 
766
  # take average
 
760
  logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
761
 
762
  # compute loss, ignore padded input tokens
763
+ label_mask = jnp.where(labels!=-100, 1.0, 0.0)
764
  loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
765
 
766
  # take average