Fix script
Browse files
src/run.sh
CHANGED
@@ -27,7 +27,7 @@ export LOGGING_STEPS=500
|
|
27 |
export EVAL_STEPS=2500
|
28 |
export SAVE_STEPS=2500
|
29 |
|
30 |
-
python src/
|
31 |
--output_dir="$OUTPUT_DIR" \
|
32 |
--train_file="$TRAIN_FILE" \
|
33 |
--validation_file="$VALIDATION_FILE" \
|
|
|
27 |
export EVAL_STEPS=2500
|
28 |
export SAVE_STEPS=2500
|
29 |
|
30 |
+
python src/run_recipe_nlg_flax.py \
|
31 |
--output_dir="$OUTPUT_DIR" \
|
32 |
--train_file="$TRAIN_FILE" \
|
33 |
--validation_file="$VALIDATION_FILE" \
|
src/{run_ed_recipe_nlg.py → run_recipe_nlg_flax.py}
RENAMED
@@ -779,7 +779,9 @@ def main():
|
|
779 |
# Save metrics
|
780 |
train_metric = unreplicate(train_metric)
|
781 |
train_time += time.time() - train_start
|
|
|
782 |
if has_tensorboard and jax.process_index() == 0:
|
|
|
783 |
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
784 |
|
785 |
epochs.write(
|
@@ -789,6 +791,7 @@ def main():
|
|
789 |
train_metrics = []
|
790 |
|
791 |
if cur_step % training_args.eval_steps == 0 and cur_step > 0 and training_args.do_eval:
|
|
|
792 |
eval_metrics = []
|
793 |
eval_preds = []
|
794 |
eval_labels = []
|
@@ -827,20 +830,27 @@ def main():
|
|
827 |
|
828 |
# Save metrics
|
829 |
if has_tensorboard and jax.process_index() == 0:
|
830 |
-
|
|
|
831 |
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
832 |
|
833 |
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
834 |
-
|
|
|
835 |
if jax.process_index() == 0:
|
836 |
-
|
837 |
-
params = jax.device_get(unreplicate(state.params))
|
838 |
model.save_pretrained(
|
839 |
training_args.output_dir,
|
840 |
params=params,
|
841 |
push_to_hub=training_args.push_to_hub,
|
842 |
commit_message=f"Saving weights and logs of step {cur_step}",
|
843 |
)
|
|
|
|
|
|
|
|
|
|
|
844 |
|
845 |
|
846 |
if __name__ == "__main__":
|
|
|
779 |
# Save metrics
|
780 |
train_metric = unreplicate(train_metric)
|
781 |
train_time += time.time() - train_start
|
782 |
+
|
783 |
if has_tensorboard and jax.process_index() == 0:
|
784 |
+
logger.info(f"*** Writing training summary after {cur_step} steps ***")
|
785 |
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
786 |
|
787 |
epochs.write(
|
|
|
791 |
train_metrics = []
|
792 |
|
793 |
if cur_step % training_args.eval_steps == 0 and cur_step > 0 and training_args.do_eval:
|
794 |
+
logger.info(f"*** Evaluation after {cur_step} steps ***")
|
795 |
eval_metrics = []
|
796 |
eval_preds = []
|
797 |
eval_labels = []
|
|
|
830 |
|
831 |
# Save metrics
|
832 |
if has_tensorboard and jax.process_index() == 0:
|
833 |
+
logger.info(f"*** Writing evaluation summary after {cur_step} steps ***")
|
834 |
+
# cur_step = epoch * (len(train_dataset) // train_batch_size)
|
835 |
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
836 |
|
837 |
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
838 |
+
logger.info(f"*** Saving checkpoints after {cur_step} steps ***")
|
839 |
+
# save checkpoint after each steps and push checkpoint to the hub
|
840 |
if jax.process_index() == 0:
|
841 |
+
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
842 |
+
# params = jax.device_get(unreplicate(state.params))
|
843 |
model.save_pretrained(
|
844 |
training_args.output_dir,
|
845 |
params=params,
|
846 |
push_to_hub=training_args.push_to_hub,
|
847 |
commit_message=f"Saving weights and logs of step {cur_step}",
|
848 |
)
|
849 |
+
tokenizer.save_pretrained(
|
850 |
+
training_args.output_dir,
|
851 |
+
push_to_hub=training_args.push_to_hub,
|
852 |
+
commit_message=f"Saving tokenizer step {cur_step}",
|
853 |
+
)
|
854 |
|
855 |
|
856 |
if __name__ == "__main__":
|