Saving weights and logs of epoch 6
Browse files
- .run_translation_t5_flax.py.swp +0 -0
- config.json +3 -0
- events.out.tfevents.1625766202.t1v-n-55481057-w-0.41473.3.v2 +0 -0
- events.out.tfevents.1625766661.t1v-n-55481057-w-0.42918.3.v2 +0 -0
- events.out.tfevents.1625767718.t1v-n-55481057-w-0.44369.3.v2 +0 -0
- events.out.tfevents.1625767744.t1v-n-55481057-w-0.45667.3.v2 +0 -0
- events.out.tfevents.1625768139.t1v-n-55481057-w-0.47104.3.v2 +0 -0
- events.out.tfevents.1625768463.t1v-n-55481057-w-0.48556.3.v2 +0 -0
- events.out.tfevents.1625769058.t1v-n-55481057-w-0.50006.3.v2 +0 -0
- events.out.tfevents.1625769345.t1v-n-55481057-w-0.51489.3.v2 +0 -0
- events.out.tfevents.1625769791.t1v-n-55481057-w-0.52973.3.v2 +0 -0
- events.out.tfevents.1625770347.t1v-n-55481057-w-0.54460.3.v2 +0 -0
- events.out.tfevents.1625770589.t1v-n-55481057-w-0.55856.3.v2 +0 -0
- events.out.tfevents.1625770862.t1v-n-55481057-w-0.57252.3.v2 +0 -0
- events.out.tfevents.1625771104.t1v-n-55481057-w-0.58650.3.v2 +0 -0
- flax_model.msgpack +3 -0
- run_translation_t5_flax.py +18 -8
.run_translation_t5_flax.py.swp
ADDED
Binary file (57.3 kB). View file
|
|
config.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:67ab6b43f4bacd25ccb5e78e065aa2c118535865f9621645b9f0caad1249e47c
|
3 |
+
size 1360
|
events.out.tfevents.1625766202.t1v-n-55481057-w-0.41473.3.v2
ADDED
Binary file (40 Bytes). View file
|
|
events.out.tfevents.1625766661.t1v-n-55481057-w-0.42918.3.v2
ADDED
Binary file (40 Bytes). View file
|
|
events.out.tfevents.1625767718.t1v-n-55481057-w-0.44369.3.v2
ADDED
Binary file (40 Bytes). View file
|
|
events.out.tfevents.1625767744.t1v-n-55481057-w-0.45667.3.v2
ADDED
Binary file (40 Bytes). View file
|
|
events.out.tfevents.1625768139.t1v-n-55481057-w-0.47104.3.v2
ADDED
Binary file (40 Bytes). View file
|
|
events.out.tfevents.1625768463.t1v-n-55481057-w-0.48556.3.v2
ADDED
Binary file (40 Bytes). View file
|
|
events.out.tfevents.1625769058.t1v-n-55481057-w-0.50006.3.v2
ADDED
Binary file (40 Bytes). View file
|
|
events.out.tfevents.1625769345.t1v-n-55481057-w-0.51489.3.v2
ADDED
Binary file (32 kB). View file
|
|
events.out.tfevents.1625769791.t1v-n-55481057-w-0.52973.3.v2
ADDED
Binary file (32 kB). View file
|
|
events.out.tfevents.1625770347.t1v-n-55481057-w-0.54460.3.v2
ADDED
Binary file (31.9 kB). View file
|
|
events.out.tfevents.1625770589.t1v-n-55481057-w-0.55856.3.v2
ADDED
Binary file (40 Bytes). View file
|
|
events.out.tfevents.1625770862.t1v-n-55481057-w-0.57252.3.v2
ADDED
Binary file (31.9 kB). View file
|
|
events.out.tfevents.1625771104.t1v-n-55481057-w-0.58650.3.v2
ADDED
Binary file (176 kB). View file
|
|
flax_model.msgpack
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:69324e2320f6e0c2619bce081b9a703fb4f3dadd403c7b960875a5a8c61d1f39
|
3 |
+
size 241981002
|
run_translation_t5_flax.py
CHANGED
@@ -260,8 +260,10 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
|
260 |
for i, val in enumerate(vals):
|
261 |
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
262 |
|
|
|
263 |
for metric_name, value in eval_metrics.items():
|
264 |
-
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
|
|
265 |
|
266 |
|
267 |
def create_learning_rate_fn(
|
@@ -499,7 +501,7 @@ def main():
|
|
499 |
)
|
500 |
|
501 |
# Metric
|
502 |
-
metric = load_metric("rouge")
|
503 |
|
504 |
def postprocess_text(preds, labels):
|
505 |
preds = [pred.strip() for pred in preds]
|
@@ -519,14 +521,22 @@ def main():
|
|
519 |
#Probably not needed for bleu - pere
|
520 |
#decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
521 |
|
522 |
-
breakpoint()
|
523 |
-
result = metric.compute(predictions=decoded_preds, references=decoded_labels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
524 |
# Extract a few results from ROUGE
|
525 |
-
result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
|
526 |
|
527 |
-
prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
|
528 |
-
result["gen_len"] = np.mean(prediction_lens)
|
529 |
-
result = {k: round(v, 4) for k, v in result.items()}
|
530 |
return result
|
531 |
|
532 |
# Enable tensorboard only on the master node
|
|
|
260 |
for i, val in enumerate(vals):
|
261 |
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
262 |
|
263 |
+
#Pere - dropping all values that are not float
|
264 |
for metric_name, value in eval_metrics.items():
|
265 |
+
if isinstance(value,float):
|
266 |
+
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
267 |
|
268 |
|
269 |
def create_learning_rate_fn(
|
|
|
501 |
)
|
502 |
|
503 |
# Metric
|
504 |
+
metric = load_metric("sacrebleu")
|
505 |
|
506 |
def postprocess_text(preds, labels):
|
507 |
preds = [pred.strip() for pred in preds]
|
|
|
521 |
#Probably not needed for bleu - pere
|
522 |
#decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
|
523 |
|
524 |
+
#breakpoint()
|
525 |
+
#result = metric.compute(predictions=decoded_preds, references=decoded_labels)
|
526 |
+
decoded_labels_list = [[d] for d in decoded_labels]
|
527 |
+
result = metric.compute(predictions=decoded_preds, references=decoded_labels_list)
|
528 |
+
|
529 |
+
#Debug stuff - pere
|
530 |
+
print("Example translations")
|
531 |
+
for i in range(0,5):
|
532 |
+
print(f'{decoded_preds[i]} - {decoded_labels_list[i]}')
|
533 |
+
#breakpoint()
|
534 |
# Extract a few results from ROUGE
|
535 |
+
#result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
|
536 |
|
537 |
+
#prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
|
538 |
+
#result["gen_len"] = np.mean(prediction_lens)
|
539 |
+
#result = {k: round(v, 4) for k, v in result.items()}
|
540 |
return result
|
541 |
|
542 |
# Enable tensorboard only on the master node
|