pere commited on
Commit
9868152
1 Parent(s): f247126

Update run_mlm_flax.py

Browse files
Files changed (1) hide show
  1. run_mlm_flax.py +1 -1
run_mlm_flax.py CHANGED
@@ -796,7 +796,7 @@ def main():
796
  logits = model(**batch, params=params, train=False)[0]
797
 
798
  # compute loss, ignore padded input tokens
799
- label_mask = jnp.where(labels > 0, 1.0, 0.0)
800
  loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
801
 
802
  # compute accuracy
 
796
  logits = model(**batch, params=params, train=False)[0]
797
 
798
  # compute loss, ignore padded input tokens
799
+ label_mask = jnp.where(labels!=-100, 1.0, 0.0)
800
  loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
801
 
802
  # compute accuracy