Saving weights and logs of step 5000
events.out.tfevents.1734082062.t1v-n-53cd541d-w-35.1086346.0.v2
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:b85b9a72fe6b897c2e4406605020d7c710cb9eebc969bde09e273fb974300238
+size 228898
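Both binary files in this commit (the TensorBoard event log above and flax_model.msgpack below) are tracked with Git LFS, so the repository stores only a small pointer file and the diff shows pointer-line changes rather than binary content. A pointer file has a fixed three-line layout; the sketch below uses placeholders, not the digests from this commit:

    version https://git-lfs.github.com/spec/v1
    oid sha256:<64-hex SHA-256 digest of the stored blob>
    size <size of the blob in bytes>

Updating a tracked file rewrites the oid line (and the size line when the byte count changes), which is what the two pointer diffs here show.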
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:8c4c381d4203e6b03a768f5cd7066942ed7e02f54c5534d09efd039be2c86d2e
 size 1421658229
run_mlm_flax.py
CHANGED
@@ -760,7 +760,7 @@ def main():
         logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
 
         # compute loss, ignore padded input tokens
-        label_mask = jnp.where(labels
+        label_mask = jnp.where(labels > -100, 1.0, 0.0)
         loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
 
         # take average
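For context on the run_mlm_flax.py change: Hugging Face data collators for masked language modeling mark positions that should not contribute to the loss with the label -100, so the condition labels > -100 keeps exactly the tokens that were masked out for prediction. The snippet below is a minimal standalone sketch of that masking pattern, not code from this repository; masked_mlm_loss and safe_labels are illustrative names, and jax.nn.one_hot stands in for the script's onehot helper:

    import jax
    import jax.numpy as jnp
    import optax

    def masked_mlm_loss(logits, labels):
        # logits: (batch, seq, vocab); labels: (batch, seq), with -100 at
        # every position the loss should ignore.
        label_mask = jnp.where(labels > -100, 1.0, 0.0)
        # Map -100 to a valid class id before one-hot encoding; the mask
        # zeroes those positions out afterwards anyway.
        safe_labels = jnp.where(labels > -100, labels, 0)
        per_token = optax.softmax_cross_entropy(
            logits, jax.nn.one_hot(safe_labels, logits.shape[-1])
        )
        # Average the cross-entropy over labeled tokens only.
        return (per_token * label_mask).sum() / label_mask.sum()

    logits = jnp.zeros((2, 8, 50265))            # vocab size is arbitrary here
    labels = jnp.full((2, 8), -100).at[0, 3].set(42)
    print(masked_mlm_loss(logits, labels))       # loss over the single labeled token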