Update run_mlm_flax.py
Browse files- run_mlm_flax.py +1 -1
run_mlm_flax.py
CHANGED
@@ -796,7 +796,7 @@ def main():
|
|
796 |
logits = model(**batch, params=params, train=False)[0]
|
797 |
|
798 |
# compute loss, ignore padded input tokens
|
799 |
-
label_mask = jnp.where(labels
|
800 |
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
801 |
|
802 |
# compute accuracy
|
|
|
796 |
logits = model(**batch, params=params, train=False)[0]
|
797 |
|
798 |
# compute loss, ignore padded input tokens
|
799 |
+
label_mask = jnp.where(labels!=-100, 1.0, 0.0)
|
800 |
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
801 |
|
802 |
# compute accuracy
|