Fix prediction metric
Browse files- src/run_ed_recipe_nlg.py +5 -5
src/run_ed_recipe_nlg.py
CHANGED
@@ -832,14 +832,14 @@ def main():
|
|
832 |
pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
|
833 |
|
834 |
# compute ROUGE metrics
|
835 |
-
|
836 |
if data_args.predict_with_generate:
|
837 |
-
|
838 |
-
pred_metrics.update(
|
839 |
-
|
840 |
|
841 |
# Print metrics
|
842 |
-
desc = f"Predict Loss: {pred_metrics['loss']} | {
|
843 |
logger.info(desc)
|
844 |
|
845 |
# save checkpoint after each epoch and push checkpoint to the hub
|
|
|
832 |
pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
|
833 |
|
834 |
# compute ROUGE metrics
|
835 |
+
mix_desc = ""
|
836 |
if data_args.predict_with_generate:
|
837 |
+
mix_metrics = compute_metrics(pred_generations, pred_labels)
|
838 |
+
pred_metrics.update(mix_metrics)
|
839 |
+
mix_desc = " ".join([f"Predict {key}: {value} |" for key, value in mix_metrics.items()])
|
840 |
|
841 |
# Print metrics
|
842 |
+
desc = f"Predict Loss: {pred_metrics['loss']} | {mix_desc})"
|
843 |
logger.info(desc)
|
844 |
|
845 |
# save checkpoint after each epoch and push checkpoint to the hub
|