pere committed
Commit 264fd01
Parent: 545f88c

Update run_mlm_flax.py

Files changed (1)
  1. run_mlm_flax.py +32 -34
run_mlm_flax.py CHANGED
@@ -750,41 +750,39 @@ def main():
     # Setup train state
     state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
 
-    # Define gradient update step fn
     def train_step(state, batch, dropout_rng):
-        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
-
-        def loss_fn(params):
-            labels = batch.pop("labels")
-
-            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
-
-            # compute loss, ignore padded input tokens
-            label_mask = jnp.where(labels != -100, 1.0, 0.0)
-            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
-
-            # take average
-            loss = loss.sum()
-            num_labels = label_mask.sum()
-
-            return loss, num_labels
-
-        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
-        (loss, num_labels), grad = grad_fn(state.params)
-        num_labels = jax.lax.psum(num_labels, "batch")
-
-        # true loss = total loss / total samples
-        loss = jax.lax.psum(loss, "batch")
-        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
-
-        # true grad = total grad / total samples
-        grad = jax.lax.psum(grad, "batch")
-        grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
-        new_state = state.apply_gradients(grads=grad)
-
-        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
-
-        return new_state, metrics, new_dropout_rng
+        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
+
+        def loss_fn(params):
+            labels = batch.pop("labels")
+            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+
+            # compute loss, ignore padded input tokens
+            label_mask = jnp.where(labels != -100, 1.0, 0.0)
+            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
+
+            # per-device sums of the loss and of the number of labels
+            loss = loss.sum()
+            num_labels = label_mask.sum()
+
+            return loss, num_labels
+
+        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
+        (loss, num_labels), grad = grad_fn(state.params)
+        # Sum the number of labels over devices
+        num_labels = jax.lax.psum(num_labels, "batch")
+
+        # Sum the loss over devices, then divide by the total number of labels
+        loss = jax.lax.psum(loss, "batch") / num_labels
+
+        # true grad = total grad / total samples
+        grad = jax.lax.psum(grad, "batch")
+        grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
+        new_state = state.apply_gradients(grads=grad)
+
+        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
+
+        return new_state, metrics, new_dropout_rng
 
     # Create parallel version of the train step
     p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
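
For context, the change in this hunk normalizes the masked-LM loss by the global label count after summing both the loss and the label count across devices with jax.lax.psum, instead of tree-mapping the division afterwards. The snippet below is a minimal, self-contained sketch of that aggregation pattern under pmap with the same "batch" axis name; global_mean_loss, the dummy shapes, and the toy padding mask are illustrative and not part of the repository's code.

import jax
import jax.numpy as jnp


def global_mean_loss(per_token_loss, label_mask):
    # Per-device sums of the masked loss and of the number of real (non-padding) labels.
    loss_sum = (per_token_loss * label_mask).sum()
    num_labels = label_mask.sum()

    # Aggregate both sums across devices first ...
    loss_sum = jax.lax.psum(loss_sum, "batch")
    num_labels = jax.lax.psum(num_labels, "batch")

    # ... and only then divide, so every label carries equal weight
    # no matter how the batch is sharded across devices.
    return loss_sum / num_labels


if __name__ == "__main__":
    n_dev = jax.local_device_count()
    per_token_loss = jnp.ones((n_dev, 4, 8))                    # dummy per-token losses
    label_mask = jnp.ones((n_dev, 4, 8)).at[..., -2:].set(0.0)  # pretend the last two tokens are padding

    p_mean = jax.pmap(global_mean_loss, axis_name="batch")
    print(p_mean(per_token_loss, label_mask))                   # same mean replicated on every device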