pere committed on
Commit
1e425ac
1 Parent(s): 572a0c2

Update run_mlm_flax.py

Files changed (1)
  1. run_mlm_flax.py +27 -26
run_mlm_flax.py CHANGED
@@ -751,38 +751,39 @@ def main():
     state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)

     def train_step(state, batch, dropout_rng):
-        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
-
-        def loss_fn(params):
-            labels = batch.pop("labels")
-            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-
-            # compute loss, ignore padded input tokens
-            label_mask = jnp.where(labels!=-100, 1.0, 0.0)
-            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
-
-            # take average on the per device loss
-            loss = loss.sum()
-            num_labels = label_mask.sum()
-
-            return loss, num_labels
-
-        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
-        (loss, num_labels), grad = grad_fn(state.params)
-        # Sum number of labels
-        num_labels = jax.lax.psum(num_labels, "batch")
-
-        # Sum loss over devices, but only AFTER dividing by the number of labels
-        loss = jax.lax.psum(loss, "batch") / num_labels
-
-        # true grad = total grad / total samples
-        grad = jax.lax.psum(grad, "batch")
-        grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
-        new_state = state.apply_gradients(grads=grad)
-
-        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
-
-        return new_state, metrics, new_dropout_rng
+        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
+
+        def loss_fn(params):
+            labels = batch.pop("labels")
+            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+
+            # compute loss, ignore padded input tokens
+            label_mask = jnp.where(labels != -100, 1.0, 0.0)
+            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
+
+            # take average on the per device loss
+            loss = loss.sum()
+            num_labels = label_mask.sum()
+
+            return loss, num_labels
+
+        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
+        (loss, num_labels), grad = grad_fn(state.params)
+
+        # Sum number of labels
+        num_labels = jax.lax.psum(num_labels, "batch")
+
+        # Sum loss over devices, but only AFTER dividing by the number of labels
+        loss = jax.lax.psum(loss, "batch") / num_labels
+
+        # true grad = total grad / total samples
+        grad = jax.lax.psum(grad, "batch")
+        grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
+        new_state = state.apply_gradients(grads=grad)
+
+        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
+
+        return new_state, metrics, new_dropout_rng

     # Create parallel version of the train step
     p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
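
For context, the updated train_step masks out positions labeled -100, sums the per-device loss and label count, psums both across the "batch" device axis, and only then divides, so every device ends up with the same global mean loss and label-count-normalized gradients. Below is a minimal, standalone sketch of that pattern under jax.pmap; the toy linear model, random data, and names such as per_device_step and NUM_CLASSES are illustrative assumptions and are not part of run_mlm_flax.py or this commit.

import jax
import jax.numpy as jnp
import optax

NUM_CLASSES = 8

def per_device_step(params, batch):
    # Runs once per device under pmap; "batch" names the device axis.
    def loss_fn(params):
        labels = batch["labels"]
        logits = batch["features"] @ params["w"]          # toy linear "model"
        label_mask = jnp.where(labels != -100, 1.0, 0.0)  # ignore padded / non-masked tokens
        # one_hot of -100 is an all-zero row, so masked positions contribute nothing
        loss = optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, NUM_CLASSES)) * label_mask
        # return per-device sums; averaging happens only after the cross-device psum
        return loss.sum(), label_mask.sum()

    (loss, num_labels), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)

    num_labels = jax.lax.psum(num_labels, "batch")   # global label count
    loss = jax.lax.psum(loss, "batch") / num_labels  # global mean loss
    grad = jax.lax.psum(grad, "batch")               # total grad ...
    grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)  # ... divided by total labels
    return loss, grad

p_step = jax.pmap(per_device_step, axis_name="batch")

n_dev = jax.local_device_count()
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = jax.device_put_replicated({"w": jnp.zeros((16, NUM_CLASSES))}, jax.local_devices())
labels = jax.random.randint(k1, (n_dev, 4), 0, NUM_CLASSES)
labels = jnp.where(jax.random.uniform(k2, (n_dev, 4)) < 0.25, -100, labels)  # fake -100 padding
batch = {"features": jax.random.normal(k3, (n_dev, 4, 16)), "labels": labels}

loss, grad = p_step(params, batch)
print(loss)  # the same global mean loss, replicated on every device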