pere committed on
Commit
cd20b10
1 Parent(s): e1c2a47

Saving weights and logs of step 5000

events.out.tfevents.1734082062.t1v-n-53cd541d-w-35.1086346.0.v2 CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:22d24708144113ab237002eb807b1d5c7a911ce761f98ce2daadc5f2d1a7c3ed
-size 63038
+oid sha256:b85b9a72fe6b897c2e4406605020d7c710cb9eebc969bde09e273fb974300238
+size 228898
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3f1adcfdad8a681e08bfe5535021067e1c876db067eda92f40f92ccf9c4c5e63
+oid sha256:8c4c381d4203e6b03a768f5cd7066942ed7e02f54c5534d09efd039be2c86d2e
 size 1421658229
run_mlm_flax.py CHANGED
@@ -760,7 +760,7 @@ def main():
     logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]

     # compute loss, ignore padded input tokens
-    label_mask = jnp.where(labels!=-100, 1.0, 0.0)
+    label_mask = jnp.where(labels > -100, 1.0, 0.0)
     loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

     # take average
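
For context on this hunk: Hugging Face MLM data collators label positions that should not contribute to the loss with -100, and real token ids are always non-negative, so `labels > -100` masks the same positions as the previous `labels != -100`. Below is a minimal standalone sketch of the masked loss, assuming the `onehot` helper from `flax.training.common_utils` that run_mlm_flax.py uses; the shapes and dummy values are illustrative and not taken from the commit.

import jax.numpy as jnp
import optax
from flax.training.common_utils import onehot

# Illustrative inputs (assumed): 4 positions, vocab size 10.
labels = jnp.array([5, 2, -100, 7])   # -100 marks positions to ignore
logits = jnp.zeros((4, 10))           # dummy model outputs

# 1.0 for real tokens, 0.0 for ignored (-100) positions.
label_mask = jnp.where(labels > -100, 1.0, 0.0)

# Per-position cross entropy, zeroed at ignored positions.
# onehot(-100, ...) produces an all-zero row, and the mask removes
# its contribution regardless.
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask

# "take average": mean over real tokens only, not over padding.
mean_loss = loss.sum() / label_mask.sum()

Both predicates are equivalent for this data, so the hunk changes the comparison form without changing which tokens are masked.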