Handle step-based logging, evaluation, and checkpointing of training state
Browse files- src/run.sh +9 -5
- src/run_ed_recipe_nlg.py +87 -102
src/run.sh
CHANGED
@@ -12,16 +12,19 @@ export VALIDATION_FILE=/to/../dev.csv
|
|
12 |
export TEST_FILE=/to/../test.csv
|
13 |
export TEXT_COLUMN=inputs
|
14 |
export TARGET_COLUMN=targets
|
15 |
-
export MAX_SOURCE_LENGTH=
|
16 |
export MAX_TARGET_LENGTH=1024
|
17 |
export SOURCE_PREFIX=ingredients
|
18 |
|
19 |
export PER_DEVICE_TRAIN_BATCH_SIZE=8
|
20 |
export PER_DEVICE_EVAL_BATCH_SIZE=8
|
21 |
export GRADIENT_ACCUMULATION_STEPS=2
|
22 |
-
export NUM_TRAIN_EPOCHS=
|
23 |
-
export LEARNING_RATE=
|
24 |
export WARMUP_STEPS=5000
|
|
|
|
|
|
|
25 |
|
26 |
python run_ed_recipe_nlg.py \
|
27 |
--output_dir="$OUTPUT_DIR" \
|
@@ -42,10 +45,11 @@ python run_ed_recipe_nlg.py \
|
|
42 |
--num_train_epochs=$NUM_TRAIN_EPOCHS \
|
43 |
--learning_rate=$LEARNING_RATE \
|
44 |
--warmup_steps=$WARMUP_STEPS \
|
45 |
-
--
|
|
|
|
|
46 |
--prediction_debug \
|
47 |
--do_train \
|
48 |
--do_eval \
|
49 |
-
--do_predict \
|
50 |
--overwrite_output_dir \
|
51 |
--predict_with_generate
|
|
|
12 |
export TEST_FILE=/to/../test.csv
|
13 |
export TEXT_COLUMN=inputs
|
14 |
export TARGET_COLUMN=targets
|
15 |
+
export MAX_SOURCE_LENGTH=256
|
16 |
export MAX_TARGET_LENGTH=1024
|
17 |
export SOURCE_PREFIX=ingredients
|
18 |
|
19 |
export PER_DEVICE_TRAIN_BATCH_SIZE=8
|
20 |
export PER_DEVICE_EVAL_BATCH_SIZE=8
|
21 |
export GRADIENT_ACCUMULATION_STEPS=2
|
22 |
+
export NUM_TRAIN_EPOCHS=5.0
|
23 |
+
export LEARNING_RATE=1e-4
|
24 |
export WARMUP_STEPS=5000
|
25 |
+
export LOGGING_STEPS=500
|
26 |
+
export EVAL_STEPS=2500
|
27 |
+
export SAVE_STEPS=2500
|
28 |
|
29 |
python run_ed_recipe_nlg.py \
|
30 |
--output_dir="$OUTPUT_DIR" \
|
|
|
45 |
--num_train_epochs=$NUM_TRAIN_EPOCHS \
|
46 |
--learning_rate=$LEARNING_RATE \
|
47 |
--warmup_steps=$WARMUP_STEPS \
|
48 |
+
--logging_steps=$LOGGING_STEPS \
|
49 |
+
--eval_steps=$EVAL_STEPS \
|
50 |
+
--save_steps=$SAVE_STEPS \
|
51 |
--prediction_debug \
|
52 |
--do_train \
|
53 |
--do_eval \
|
|
|
54 |
--overwrite_output_dir \
|
55 |
--predict_with_generate
|
src/run_ed_recipe_nlg.py
CHANGED
@@ -258,7 +258,20 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
|
|
258 |
yield batch
|
259 |
|
260 |
|
261 |
-
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
summary_writer.scalar("train_time", train_time, step)
|
263 |
|
264 |
train_metrics = get_metrics(train_metrics)
|
@@ -267,6 +280,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
|
267 |
for i, val in enumerate(vals):
|
268 |
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
269 |
|
|
|
|
|
270 |
for metric_name, value in eval_metrics.items():
|
271 |
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
272 |
|
@@ -553,7 +568,7 @@ def main():
|
|
553 |
result = {}
|
554 |
|
555 |
try:
|
556 |
-
result_blue = bleu.compute(predictions=decoded_preds, references=
|
557 |
result_blue = result_blue["score"]
|
558 |
except Exception as e:
|
559 |
logger.info(f'Error occurred during bleu {e}')
|
@@ -734,6 +749,7 @@ def main():
|
|
734 |
logger.info(f" Total optimization steps = {total_train_steps}")
|
735 |
|
736 |
train_time = 0
|
|
|
737 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
738 |
for epoch in epochs:
|
739 |
# ======================== Training ================================
|
@@ -741,115 +757,84 @@ def main():
|
|
741 |
|
742 |
# Create sampling rng
|
743 |
rng, input_rng = jax.random.split(rng)
|
744 |
-
train_metrics = []
|
745 |
|
746 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
747 |
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
|
748 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
749 |
# train
|
750 |
-
for
|
751 |
batch = next(train_loader)
|
752 |
state, train_metric = p_train_step(state, batch)
|
753 |
train_metrics.append(train_metric)
|
754 |
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
766 |
-
|
767 |
-
|
768 |
-
|
769 |
-
|
770 |
-
|
771 |
-
|
772 |
-
|
773 |
-
|
774 |
-
|
775 |
-
|
776 |
-
|
777 |
-
|
778 |
-
|
779 |
-
|
780 |
-
|
781 |
-
|
782 |
-
|
783 |
-
|
784 |
-
|
785 |
-
|
786 |
-
|
787 |
-
|
788 |
-
|
789 |
-
|
790 |
-
|
791 |
-
|
792 |
-
|
793 |
-
|
794 |
-
|
795 |
-
|
796 |
-
|
797 |
-
|
798 |
-
|
799 |
-
|
800 |
-
|
801 |
-
|
802 |
-
|
803 |
-
|
804 |
-
|
805 |
-
|
806 |
-
|
807 |
-
|
808 |
-
|
809 |
-
|
810 |
-
|
811 |
-
|
812 |
-
|
813 |
-
|
814 |
-
|
815 |
-
|
816 |
-
|
817 |
-
|
818 |
-
|
819 |
-
|
820 |
-
|
821 |
-
|
822 |
-
|
823 |
-
# generation
|
824 |
-
if data_args.predict_with_generate:
|
825 |
-
generated_ids = p_generate_step(state.params, batch)
|
826 |
-
pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
827 |
-
pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
|
828 |
-
|
829 |
-
# normalize prediction metrics
|
830 |
-
pred_metrics = get_metrics(pred_metrics)
|
831 |
-
pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
|
832 |
-
|
833 |
-
# compute ROUGE metrics
|
834 |
-
mix_desc = ""
|
835 |
-
if data_args.predict_with_generate:
|
836 |
-
mix_metrics = compute_metrics(pred_generations, pred_labels)
|
837 |
-
pred_metrics.update(mix_metrics)
|
838 |
-
mix_desc = " ".join([f"Predict {key}: {value} |" for key, value in mix_metrics.items()])
|
839 |
-
|
840 |
-
# Print metrics
|
841 |
-
desc = f"Predict Loss: {pred_metrics['loss']} | {mix_desc})"
|
842 |
-
logger.info(desc)
|
843 |
-
|
844 |
-
# save checkpoint after each epoch and push checkpoint to the hub
|
845 |
-
if jax.process_index() == 0:
|
846 |
-
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
847 |
-
model.save_pretrained(
|
848 |
-
training_args.output_dir,
|
849 |
-
params=params,
|
850 |
-
push_to_hub=training_args.push_to_hub,
|
851 |
-
commit_message=f"Saving weights and logs of epoch {epoch + 1}",
|
852 |
-
)
|
853 |
|
854 |
|
855 |
if __name__ == "__main__":
|
|
|
258 |
yield batch
|
259 |
|
260 |
|
261 |
+
# def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
262 |
+
# summary_writer.scalar("train_time", train_time, step)
|
263 |
+
#
|
264 |
+
# train_metrics = get_metrics(train_metrics)
|
265 |
+
# for key, vals in train_metrics.items():
|
266 |
+
# tag = f"train_{key}"
|
267 |
+
# for i, val in enumerate(vals):
|
268 |
+
# summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
269 |
+
#
|
270 |
+
# for metric_name, value in eval_metrics.items():
|
271 |
+
# summary_writer.scalar(f"eval_{metric_name}", value, step)
|
272 |
+
#
|
273 |
+
|
274 |
+
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
275 |
summary_writer.scalar("train_time", train_time, step)
|
276 |
|
277 |
train_metrics = get_metrics(train_metrics)
|
|
|
280 |
for i, val in enumerate(vals):
|
281 |
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
282 |
|
283 |
+
|
284 |
+
def write_eval_metric(summary_writer, eval_metrics, step):
|
285 |
for metric_name, value in eval_metrics.items():
|
286 |
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
287 |
|
|
|
568 |
result = {}
|
569 |
|
570 |
try:
|
571 |
+
result_blue = bleu.compute(predictions=decoded_preds, references=decoded_labels_bleu)
|
572 |
result_blue = result_blue["score"]
|
573 |
except Exception as e:
|
574 |
logger.info(f'Error occurred during bleu {e}')
|
|
|
749 |
logger.info(f" Total optimization steps = {total_train_steps}")
|
750 |
|
751 |
train_time = 0
|
752 |
+
train_metrics = []
|
753 |
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
754 |
for epoch in epochs:
|
755 |
# ======================== Training ================================
|
|
|
757 |
|
758 |
# Create sampling rng
|
759 |
rng, input_rng = jax.random.split(rng)
|
|
|
760 |
|
761 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
762 |
train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
|
763 |
steps_per_epoch = len(train_dataset) // train_batch_size
|
764 |
# train
|
765 |
+
for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
|
766 |
batch = next(train_loader)
|
767 |
state, train_metric = p_train_step(state, batch)
|
768 |
train_metrics.append(train_metric)
|
769 |
|
770 |
+
cur_step = epoch * (len(train_dataset) // train_batch_size) + step
|
771 |
+
|
772 |
+
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
773 |
+
# Save metrics
|
774 |
+
train_metric = unreplicate(train_metric)
|
775 |
+
train_time += time.time() - train_start
|
776 |
+
if has_tensorboard and jax.process_index() == 0:
|
777 |
+
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
778 |
+
|
779 |
+
epochs.write(
|
780 |
+
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
781 |
+
)
|
782 |
+
|
783 |
+
train_metrics = []
|
784 |
+
|
785 |
+
if cur_step % training_args.eval_steps == 0 and cur_step > 0 and training_args.do_eval:
|
786 |
+
eval_metrics = []
|
787 |
+
eval_preds = []
|
788 |
+
eval_labels = []
|
789 |
+
|
790 |
+
eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
|
791 |
+
eval_steps = len(eval_dataset) // eval_batch_size
|
792 |
+
for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
|
793 |
+
# Model forward
|
794 |
+
batch = next(eval_loader)
|
795 |
+
labels = batch["labels"]
|
796 |
+
|
797 |
+
metrics = p_eval_step(state.params, batch)
|
798 |
+
eval_metrics.append(metrics)
|
799 |
+
|
800 |
+
# generation
|
801 |
+
if data_args.predict_with_generate:
|
802 |
+
generated_ids = p_generate_step(state.params, batch)
|
803 |
+
eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
804 |
+
eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
|
805 |
+
|
806 |
+
# normalize eval metrics
|
807 |
+
eval_metrics = get_metrics(eval_metrics)
|
808 |
+
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
809 |
+
|
810 |
+
# compute MIX metrics
|
811 |
+
mix_desc = ""
|
812 |
+
if data_args.predict_with_generate:
|
813 |
+
mix_metrics = compute_metrics(eval_preds, eval_labels)
|
814 |
+
eval_metrics.update(mix_metrics)
|
815 |
+
mix_desc = " ".join([f"Eval {key}: {value} |" for key, value in mix_metrics.items()])
|
816 |
+
|
817 |
+
# Print metrics and update progress bar
|
818 |
+
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {mix_desc})"
|
819 |
+
epochs.write(desc)
|
820 |
+
epochs.desc = desc
|
821 |
+
|
822 |
+
# Save metrics
|
823 |
+
if has_tensorboard and jax.process_index() == 0:
|
824 |
+
cur_step = epoch * (len(train_dataset) // train_batch_size) + step
|
825 |
+
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
826 |
+
|
827 |
+
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
828 |
+
# save checkpoint after each epoch and push checkpoint to the hub
|
829 |
+
if jax.process_index() == 0:
|
830 |
+
# params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
831 |
+
params = jax.device_get(unreplicate(state.params))
|
832 |
+
model.save_pretrained(
|
833 |
+
training_args.output_dir,
|
834 |
+
params=params,
|
835 |
+
push_to_hub=training_args.push_to_hub,
|
836 |
+
commit_message=f"Saving weights and logs of step {cur_step}",
|
837 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
838 |
|
839 |
|
840 |
if __name__ == "__main__":
|