m3hrdadfi committed on
Commit
2091b6a
1 Parent(s): a9618fd

Fix script

Browse files
src/run.sh CHANGED
@@ -27,7 +27,7 @@ export LOGGING_STEPS=500
27
  export EVAL_STEPS=2500
28
  export SAVE_STEPS=2500
29
 
30
- python src/run_ed_recipe_nlg.py \
31
  --output_dir="$OUTPUT_DIR" \
32
  --train_file="$TRAIN_FILE" \
33
  --validation_file="$VALIDATION_FILE" \
 
27
  export EVAL_STEPS=2500
28
  export SAVE_STEPS=2500
29
 
30
+ python src/run_recipe_nlg_flax.py \
31
  --output_dir="$OUTPUT_DIR" \
32
  --train_file="$TRAIN_FILE" \
33
  --validation_file="$VALIDATION_FILE" \
src/{run_ed_recipe_nlg.py → run_recipe_nlg_flax.py} RENAMED
@@ -779,7 +779,9 @@ def main():
779
  # Save metrics
780
  train_metric = unreplicate(train_metric)
781
  train_time += time.time() - train_start
 
782
  if has_tensorboard and jax.process_index() == 0:
 
783
  write_train_metric(summary_writer, train_metrics, train_time, cur_step)
784
 
785
  epochs.write(
@@ -789,6 +791,7 @@ def main():
789
  train_metrics = []
790
 
791
  if cur_step % training_args.eval_steps == 0 and cur_step > 0 and training_args.do_eval:
 
792
  eval_metrics = []
793
  eval_preds = []
794
  eval_labels = []
@@ -827,20 +830,27 @@ def main():
827
 
828
  # Save metrics
829
  if has_tensorboard and jax.process_index() == 0:
830
- cur_step = epoch * (len(train_dataset) // train_batch_size)
 
831
  write_eval_metric(summary_writer, eval_metrics, cur_step)
832
 
833
  if cur_step % training_args.save_steps == 0 and cur_step > 0:
834
- # save checkpoint after each epoch and push checkpoint to the hub
 
835
  if jax.process_index() == 0:
836
- # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
837
- params = jax.device_get(unreplicate(state.params))
838
  model.save_pretrained(
839
  training_args.output_dir,
840
  params=params,
841
  push_to_hub=training_args.push_to_hub,
842
  commit_message=f"Saving weights and logs of step {cur_step}",
843
  )
 
 
 
 
 
844
 
845
 
846
  if __name__ == "__main__":
 
779
  # Save metrics
780
  train_metric = unreplicate(train_metric)
781
  train_time += time.time() - train_start
782
+
783
  if has_tensorboard and jax.process_index() == 0:
784
+ logger.info(f"*** Writing training summary after {cur_step} steps ***")
785
  write_train_metric(summary_writer, train_metrics, train_time, cur_step)
786
 
787
  epochs.write(
 
791
  train_metrics = []
792
 
793
  if cur_step % training_args.eval_steps == 0 and cur_step > 0 and training_args.do_eval:
794
+ logger.info(f"*** Evaluation after {cur_step} steps ***")
795
  eval_metrics = []
796
  eval_preds = []
797
  eval_labels = []
 
830
 
831
  # Save metrics
832
  if has_tensorboard and jax.process_index() == 0:
833
+ logger.info(f"*** Writing evaluation summary after {cur_step} steps ***")
834
+ # cur_step = epoch * (len(train_dataset) // train_batch_size)
835
  write_eval_metric(summary_writer, eval_metrics, cur_step)
836
 
837
  if cur_step % training_args.save_steps == 0 and cur_step > 0:
838
+ logger.info(f"*** Saving checkpoints after {cur_step} steps ***")
839
+ # save checkpoint after each steps and push checkpoint to the hub
840
  if jax.process_index() == 0:
841
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
842
+ # params = jax.device_get(unreplicate(state.params))
843
  model.save_pretrained(
844
  training_args.output_dir,
845
  params=params,
846
  push_to_hub=training_args.push_to_hub,
847
  commit_message=f"Saving weights and logs of step {cur_step}",
848
  )
849
+ tokenizer.save_pretrained(
850
+ training_args.output_dir,
851
+ push_to_hub=training_args.push_to_hub,
852
+ commit_message=f"Saving tokenizer step {cur_step}",
853
+ )
854
 
855
 
856
  if __name__ == "__main__":