Saving weights and logs of step 300
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:457c5948252576d9d5252b28b79a754223d3dea5a24a77f5b2b7cb5189129499
 size 891548548
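flax_model.msgpack holds the serialized Flax weights; only its Git LFS pointer (new SHA-256 and byte size) appears in the diff. Below is a minimal sketch of pulling the updated weights, assuming a T5 architecture and a placeholder repository id (neither is stated in this commit):

# Hedged sketch: load the updated Flax weights from the Hub.
# "<user>/<repo>" is a placeholder and FlaxT5ForConditionalGeneration is an assumption
# based on the T5 pretraining script; the commit itself only updates the LFS pointer.
from transformers import FlaxT5ForConditionalGeneration

model = FlaxT5ForConditionalGeneration.from_pretrained("<user>/<repo>")  # fetches flax_model.msgpack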
run_t5.sh
CHANGED
@@ -16,9 +16,8 @@ mkdir -p "${MODEL_DIR}/runs"
 --preprocessing_num_workers="96" \
 --do_train --do_eval \
 --adafactor \
---dtype="bfloat16" \
 --max_seq_length="512" \
---gradient_accumulation_steps="
+--gradient_accumulation_steps="16" \
 --per_device_train_batch_size="32" \
 --per_device_eval_batch_size="32" \
 --learning_rate="5e-3" \
@@ -32,3 +31,7 @@ mkdir -p "${MODEL_DIR}/runs"
 #git add pytorch_model.bin
 #git commit -m "Update pytorch model after training"
 #git push origin main
+
+
+# --dtype="bfloat16" \
+# --resume_from_checkpoint="${MODEL_DIR}/ckpt-3300" \
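The net effect of this change is that gradient accumulation is switched on (16 steps) while the bfloat16 dtype and the checkpoint-resume flag are parked as comments at the end of the script. A rough sketch of the resulting effective batch size, assuming an 8-core TPU host (the device count is not stated in the commit; the other numbers come from run_t5.sh):

# Hedged sketch: effective batch size implied by the flags above.
# The device count of 8 (a TPU v3-8) is an assumption; 32 and 16 come from run_t5.sh.
import jax

per_device_batch_size = 32        # --per_device_train_batch_size
grad_accum_steps = 16             # --gradient_accumulation_steps
num_devices = jax.device_count()  # assumed to be 8 here

effective_batch_size = per_device_batch_size * num_devices * grad_accum_steps
print(effective_batch_size)       # 4096 under the assumed device count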
run_t5_mlm_flax_custom_dataset.py
CHANGED
@@ -722,6 +722,9 @@ if __name__ == "__main__":
 
     num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
 
+    steps_per_epoch = len(tokenized_datasets['train']) // train_batch_size
+    total_train_steps = steps_per_epoch * num_epochs
+
     # Create learning rate schedule
 
     # See https://arxiv.org/pdf/2104.07705.pdf for rationale of choosing the peak at 6% of training steps
@@ -775,6 +778,11 @@ if __name__ == "__main__":
     # Setup train state
     state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
 
+    if training_args.resume_from_checkpoint:
+        state, resume_step = restore_checkpoint(training_args.resume_from_checkpoint, state)
+    else:
+        resume_step = 0
+
     # Define gradient update step fn
     def train_step(state, batch, dropout_rng):
         dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
@@ -828,8 +836,7 @@ if __name__ == "__main__":
     # Replicate the train state on each device
     state = jax_utils.replicate(state)
 
-    steps_per_epoch = len(tokenized_datasets['train']) // train_batch_size
-    total_train_steps = steps_per_epoch * num_epochs
+
 
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len(datasets['train'])}")
@@ -855,6 +862,11 @@ if __name__ == "__main__":
 
         # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
+            cur_step = epoch * (num_train_samples // train_batch_size) + step
+            # skip to the step from which we are resuming
+            if cur_step < resume_step:
+                continue
+
             samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples)
 
@@ -863,7 +875,6 @@ if __name__ == "__main__":
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)
 
-            cur_step = epoch * (num_train_samples // train_batch_size) + step
 
             if cur_step % training_args.logging_steps * grad_accum_steps == 0 and cur_step > 0:
                 # Save metrics
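The new code calls a restore_checkpoint helper whose definition is not part of this diff. A minimal sketch of what such a helper could look like, assuming it wraps flax.training.checkpoints and reads the resume step off the restored TrainState:

# Hedged sketch of the restore_checkpoint helper used above; its definition is not
# shown in this commit, so the body below is an assumption built on flax.training.checkpoints.
from flax.training import checkpoints


def restore_checkpoint(ckpt_path, state):
    # Returns `state` unchanged if nothing is found at ckpt_path; the restored
    # TrainState carries the saved optimizer step, which becomes the resume point.
    restored = checkpoints.restore_checkpoint(ckpt_path, target=state)
    resume_step = int(restored.step)
    return restored, resume_step

Restoring before jax_utils.replicate keeps the step counter a plain host-side integer, which is what the cur_step < resume_step comparison in the training loop expects.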
runs/Jul10_07-37-20_t1v-n-0e7426e8-w-0/events.out.tfevents.1625902752.t1v-n-0e7426e8-w-0.18397.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1aa4fd14ba6d0007ac2b4c7ad5f7b03ab486b3899ece3eba1fefe852923f2366
+size 40
runs/Jul10_07-45-49_t1v-n-0e7426e8-w-0/events.out.tfevents.1625903173.t1v-n-0e7426e8-w-0.20563.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9086b97ea9ba59e96e4c66b26c205fe1207d0a94ab355127a1e4f8078d84a269
+size 45399
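The two new entries under runs/ are TensorBoard event files, again stored as three-line Git LFS pointers rather than the logs themselves. For reference, a sketch of how such files are typically produced in the Flax examples, assuming the script logs through flax.metrics.tensorboard (the writer setup is not part of this diff):

# Hedged sketch: writing event files under "${MODEL_DIR}/runs", assuming
# flax.metrics.tensorboard is the logging backend (not shown in this commit).
from flax.metrics import tensorboard

summary_writer = tensorboard.SummaryWriter("runs/Jul10_07-45-49_example")  # hypothetical run dir
summary_writer.scalar("train_loss", 2.31, step=300)  # illustrative values only
summary_writer.flush()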