Update run_mlm_flax.py
Browse files- run_mlm_flax.py +1 -1
run_mlm_flax.py
CHANGED
@@ -760,7 +760,7 @@ def main():
|
|
760 |
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
761 |
|
762 |
# compute loss, ignore padded input tokens
|
763 |
-
label_mask = jnp.where(labels
|
764 |
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
765 |
|
766 |
# take average
|
|
|
760 |
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
761 |
|
762 |
# compute loss, ignore padded input tokens
|
763 |
+
label_mask = jnp.where(labels!=-100, 1.0, 0.0)
|
764 |
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
765 |
|
766 |
# take average
|