Update run_mlm_flax.py
run_mlm_flax.py (changed): +27 −26
@@ -751,38 +751,39 @@ def main():
     state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
 
     def train_step(state, batch, dropout_rng):
+        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
 
+        def loss_fn(params):
+            labels = batch.pop("labels")
+            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
 
+            # compute loss, ignore padded input tokens
+            label_mask = jnp.where(labels != -100, 1.0, 0.0)
+            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
 
+            # per-device sums of the loss and of the label count (averaged after the psum below)
+            loss = loss.sum()
+            num_labels = label_mask.sum()
+
+            return loss, num_labels
+
+        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
+        (loss, num_labels), grad = grad_fn(state.params)
 
+        # Sum the number of labels across devices
+        num_labels = jax.lax.psum(num_labels, "batch")
+
+        # Sum the loss over devices, then divide by the total number of labels
+        loss = jax.lax.psum(loss, "batch") / num_labels
+
+        # true grad = total grad / total samples
+        grad = jax.lax.psum(grad, "batch")
+        grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
+        new_state = state.apply_gradients(grads=grad)
 
+        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
 
+        return new_state, metrics, new_dropout_rng
 
     # Create parallel version of the train step
     p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
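The point of this hunk is the normalization order: each device returns its summed loss and its unmasked-label count, both are summed across devices with `psum`, and only then does the division happen, so devices with different numbers of unmasked tokens contribute proportionally. A minimal, self-contained sketch of that pattern, assuming two local devices and purely hypothetical loss/label values (no model involved):

```python
import jax
import jax.numpy as jnp

# Hypothetical per-device sums for 2 devices. On a CPU-only machine you can
# simulate two devices with XLA_FLAGS=--xla_force_host_platform_device_count=2.
per_device_loss = jnp.array([6.0, 1.0])    # summed token loss on each device
per_device_labels = jnp.array([3.0, 1.0])  # number of unmasked labels on each device

def global_mean_loss(loss, num_labels):
    # Same pattern as train_step: psum both quantities, then divide once.
    total_labels = jax.lax.psum(num_labels, "batch")
    return jax.lax.psum(loss, "batch") / total_labels

print(jax.pmap(global_mean_loss, axis_name="batch")(per_device_loss, per_device_labels))
# [1.75 1.75]  -> (6 + 1) / (3 + 1), identical on every device

# Averaging the per-device means instead would give ((6/3) + (1/1)) / 2 = 1.5,
# which over-weights the device that saw fewer unmasked labels.
```

Separately, `donate_argnums=(0,)` on the `jax.pmap` call donates the input `state` buffers, so XLA can reuse them for the returned `new_state` instead of holding two copies of the parameters and optimizer state at once.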