dat committed · Commit f291f93 · Parent(s): f6e0bf7
Saving weights and logs at step 1252
This view is limited to 50 files because it contains too many changes. See raw diff.
- Load data & train tokenizer.ipynb +0 -0
- checkpoint_60000 +3 -0
- events.out.tfevents.1626173264.t1v-n-f5c06ea1-w-0.340852.3.v2 +3 -0
- events.out.tfevents.1626174131.t1v-n-f5c06ea1-w-0.343920.3.v2 +3 -0
- events.out.tfevents.1626174670.t1v-n-f5c06ea1-w-0.346512.3.v2 +3 -0
- events.out.tfevents.1626175237.t1v-n-f5c06ea1-w-0.349243.3.v2 +3 -0
- events.out.tfevents.1626176074.t1v-n-f5c06ea1-w-0.351681.3.v2 +3 -0
- events.out.tfevents.1626180467.t1v-n-f5c06ea1-w-0.354027.3.v2 +3 -0
- events.out.tfevents.1626180750.t1v-n-f5c06ea1-w-0.355855.3.v2 +3 -0
- events.out.tfevents.1626181600.t1v-n-f5c06ea1-w-0.357816.3.v2 +3 -0
- events.out.tfevents.1626181889.t1v-n-f5c06ea1-w-0.360037.3.v2 +3 -0
- events.out.tfevents.1626182175.t1v-n-f5c06ea1-w-0.362298.3.v2 +3 -0
- events.out.tfevents.1626182874.t1v-n-f5c06ea1-w-0.365284.3.v2 +3 -0
- events.out.tfevents.1626184460.t1v-n-f5c06ea1-w-0.369028.3.v2 +3 -0
- events.out.tfevents.1626242600.t1v-n-f5c06ea1-w-0.491835.3.v2 +3 -0
- events.out.tfevents.1626285315.t1v-n-f5c06ea1-w-0.533662.3.v2 +3 -0
- events.out.tfevents.1626286793.t1v-n-f5c06ea1-w-0.547087.3.v2 +3 -0
- events.out.tfevents.1626287584.t1v-n-f5c06ea1-w-0.550207.3.v2 +3 -0
- events.out.tfevents.1626288936.t1v-n-f5c06ea1-w-0.553832.3.v2 +3 -0
- events.out.tfevents.1626290714.t1v-n-f5c06ea1-w-0.557554.3.v2 +3 -0
- events.out.tfevents.1626292080.t1v-n-f5c06ea1-w-0.560928.3.v2 +3 -0
- events.out.tfevents.1626292866.t1v-n-f5c06ea1-w-0.563390.3.v2 +3 -0
- events.out.tfevents.1626293250.t1v-n-f5c06ea1-w-0.565261.3.v2 +3 -0
- events.out.tfevents.1626294676.t1v-n-f5c06ea1-w-0.568447.3.v2 +3 -0
- events.out.tfevents.1626295212.t1v-n-f5c06ea1-w-0.570637.3.v2 +3 -0
- events.out.tfevents.1626296457.t1v-n-f5c06ea1-w-0.573688.3.v2 +3 -0
- events.out.tfevents.1626296630.t1v-n-f5c06ea1-w-0.575437.3.v2 +3 -0
- flax_model.msgpack +2 -2
- run.sh +11 -9
- run_mlm_flax.py +270 -218
- run_mlm_flax_no_accum.py +776 -0
- save_tokenized_data.py +484 -0
- train_tokenizer.py +43 -0
- wandb/debug-internal.log +1 -1
- wandb/debug.log +1 -1
- wandb/latest-run +1 -1
- wandb/run-20210713_010630-14xhiyhf/files/output.log +9 -0
- wandb/run-20210713_010630-14xhiyhf/logs/debug-internal.log +24 -0
- wandb/run-20210713_010630-14xhiyhf/logs/debug.log +2 -0
- wandb/run-20210713_010630-14xhiyhf/run-14xhiyhf.wandb +0 -0
- wandb/run-20210713_104745-1rl2j7or/files/config.yaml +304 -0
- wandb/run-20210713_104745-1rl2j7or/files/output.log +57 -0
- wandb/run-20210713_104745-1rl2j7or/files/requirements.txt +92 -0
- wandb/run-20210713_104745-1rl2j7or/files/wandb-metadata.json +44 -0
- wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json +1 -0
- wandb/run-20210713_104745-1rl2j7or/logs/debug-internal.log +181 -0
- wandb/run-20210713_104745-1rl2j7or/logs/debug.log +27 -0
- wandb/run-20210713_104745-1rl2j7or/run-1rl2j7or.wandb +0 -0
- wandb/run-20210713_110212-594z6oo0/files/config.yaml +307 -0
- wandb/run-20210713_110212-594z6oo0/files/output.log +39 -0
Load data & train tokenizer.ipynb
CHANGED
The diff for this file is too large to render.
See raw diff
checkpoint_60000
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73e6d7222b2cee297be0891db385dcce6e0cbff6ec3697c08118513955f8aaf7
+size 769729450
events.out.tfevents.1626173264.t1v-n-f5c06ea1-w-0.340852.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73fdfc3eb9d8111b1e3460227717a3942adfe9263bca08b7fd2bfab9af98d9a1
+size 38186
events.out.tfevents.1626174131.t1v-n-f5c06ea1-w-0.343920.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dfc6f0b5b354bd4d8d13834613ece71ac9d948186313bc3fde5e2e132a1c9cab
+size 40
events.out.tfevents.1626174670.t1v-n-f5c06ea1-w-0.346512.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f74cf77c0a672ad1201614ba6642a4f3a27b9cf021d0e88eb362c7f38ee86304
+size 40
events.out.tfevents.1626175237.t1v-n-f5c06ea1-w-0.349243.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:be5c2acf821fd2ce776ff5e434706cb933a0fa323f0bb1a82dadd832f1f589d4
+size 40
events.out.tfevents.1626176074.t1v-n-f5c06ea1-w-0.351681.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b085d5029d052defe00b26c54b6357e9d05cbc5ad38cdd2f12537ed0b90008d2
+size 441341
events.out.tfevents.1626180467.t1v-n-f5c06ea1-w-0.354027.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:973eec9b2b17e54f3ee35dc0c4b85a4a3ecf5488cb59f5619d7c635641bfe7b6
+size 40
events.out.tfevents.1626180750.t1v-n-f5c06ea1-w-0.355855.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:013fc500b7fdd46262ee2b2ed5a3624249adef426d0b134944080ccf90d363ed
+size 40
events.out.tfevents.1626181600.t1v-n-f5c06ea1-w-0.357816.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3d4a519b8f1c293258e292768822980b487ef0e02bbfe9d6a3132b8c2fdd791
+size 40
events.out.tfevents.1626181889.t1v-n-f5c06ea1-w-0.360037.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c1ed9142ba98f2f7197e2a44361331a8c112af5dba98d7fc9f0bcab6228ae8c
+size 40
events.out.tfevents.1626182175.t1v-n-f5c06ea1-w-0.362298.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29cc2c143c306c4619802094513459dbb71c4730d3cdfb879e7224923ddfe7ea
+size 40
events.out.tfevents.1626182874.t1v-n-f5c06ea1-w-0.365284.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24aa4302db5d02121389fc7f8944025588034aedd21f772c2b71224e3a0b0d13
+size 220634
events.out.tfevents.1626184460.t1v-n-f5c06ea1-w-0.369028.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6e5631bf443386a4e37d77053e55ba4517153d5f6d7f77b616258d9c78e6901f
+size 367772
events.out.tfevents.1626242600.t1v-n-f5c06ea1-w-0.491835.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f94f6c2d80b0e0d6247997634649101caefa3ad8ab4f408b529ad38f86c8770
+size 40
events.out.tfevents.1626285315.t1v-n-f5c06ea1-w-0.533662.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:29b681f16c441caf85381c9def58d19f4479a2460146d2cfb68991f8327f01fe
+size 40
events.out.tfevents.1626286793.t1v-n-f5c06ea1-w-0.547087.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53d63b11450875138751afac48c611f4da76fadc0affb0ec98896b35dbad9728
+size 40
events.out.tfevents.1626287584.t1v-n-f5c06ea1-w-0.550207.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:62cc6dc4bf215d99f8685629bf632f82d65fc7f1127d876ded332b31b5432064
+size 40
events.out.tfevents.1626288936.t1v-n-f5c06ea1-w-0.553832.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1fccf6070edac76c190b8bb8de4e37b889dd1b18835777203f9d16ac658aaf71
+size 40
events.out.tfevents.1626290714.t1v-n-f5c06ea1-w-0.557554.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d46028802a38f383ce27081e90ff848e3da863ac08c341f101eed1b20a39556c
+size 40
events.out.tfevents.1626292080.t1v-n-f5c06ea1-w-0.560928.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2e89d0090ae1228c609a140c2a20fbdfb208480a0dd16aced968756947a93f0
+size 147065
events.out.tfevents.1626292866.t1v-n-f5c06ea1-w-0.563390.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2b5607707732c41fb3bac9b56702cf2a006ba526d98638e0352ba54e809c6eff
+size 40
events.out.tfevents.1626293250.t1v-n-f5c06ea1-w-0.565261.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83bed69057844c7af14e165d87c9678d28135297ab5bd374d1e0d80ebd31966f
+size 221057
events.out.tfevents.1626294676.t1v-n-f5c06ea1-w-0.568447.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:050b6dc69ea5a9946fc01c76d67ea00913117399f1a37e0f24db39f39c52e76f
+size 73565
events.out.tfevents.1626295212.t1v-n-f5c06ea1-w-0.570637.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2818b40b384ff7f5a57fe1c4994ebbd02140f7221904f527cfc0a9a115334a79
+size 184532
events.out.tfevents.1626296457.t1v-n-f5c06ea1-w-0.573688.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:df3d8a6aa5b0177a3c337963bad77cc5cea9ed722032941dbac474d03b5a3261
+size 40
events.out.tfevents.1626296630.t1v-n-f5c06ea1-w-0.575437.3.v2
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:932b70a150d991f6939f853c7b54516d5309f2d6c19761fa96a50999bf2199e7
+size 147993
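Note: every file added above is tracked with Git LFS, so the diff only shows the three-line pointer (spec version, SHA-256 object id, size in bytes) rather than the binary payload. As an illustration, such a pointer can be read with a few lines of Python; the helper name and the example path below are hypothetical, not part of this commit:

    # Illustration only: parse a Git LFS pointer file such as checkpoint_60000.
    def parse_lfs_pointer(path):
        fields = {}
        with open(path) as f:
            for line in f:
                key, _, value = line.strip().partition(" ")
                fields[key] = value
        return fields

    # info = parse_lfs_pointer("checkpoint_60000")   # hypothetical local path
    # info["oid"] -> "sha256:73e6d7...", info["size"] -> "769729450"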
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:422812fccdda54c02543ac5e994b33b54e510e0474439fbe9360d5190787d38e
+size 510090043
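flax_model.msgpack is the Flax parameter tree serialized with MessagePack, so this commit simply swaps one LFS pointer (about 510 MB of weights) for another. A minimal sketch of how such a file is typically written and read with flax.serialization, assuming `model` is a Flax model such as the FlaxAutoModelForMaskedLM used by run_mlm_flax.py (this sketch is not code from the commit):

    # Sketch under the assumption that `model.params` is the parameter pytree.
    from flax.serialization import to_bytes, from_bytes

    def write_params(model, path="flax_model.msgpack"):
        with open(path, "wb") as f:
            f.write(to_bytes(model.params))  # serialize the pytree to MessagePack bytes

    def read_params(model, path="flax_model.msgpack"):
        with open(path, "rb") as f:
            # the existing pytree supplies the structure for deserialization
            return from_bytes(model.params, f.read())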
run.sh
CHANGED
@@ -1,6 +1,6 @@
 #!/usr/bin/env bash
 
-export TOKENIZERS_PARALLELISM=0
+#export TOKENIZERS_PARALLELISM=0
 
 python ./run_mlm_flax.py \
     --push_to_hub \
@@ -14,18 +14,20 @@ python ./run_mlm_flax.py \
     --overwrite_output_dir \
     --adam_beta1="0.9" \
     --adam_beta2="0.98" \
-    --logging_steps="
-    --eval_steps="
-    --num_train_epochs="
-    --preprocessing_num_workers="
-    --save_steps="
-    --learning_rate="
+    --logging_steps="250" \
+    --eval_steps="500" \
+    --num_train_epochs="3" \
+    --preprocessing_num_workers="96" \
+    --save_steps="1250" \
+    --learning_rate="1e-4" \
     --per_device_train_batch_size="2" \
     --per_device_eval_batch_size="2" \
     --save_total_limit="5"\
-    --
+    --max_eval_samples="500"\
+    --overwrite_cache False \
+    --gradient_accumulation_steps="4" \
+    #--resume_from_checkpoint="./"\
     #--adafactor \
     #--dtype="bfloat16" \
-    #--resume_from_checkpoint="./"\
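The new run.sh flags introduce gradient accumulation, which multiplies the number of examples consumed per optimizer update. A small sketch of the arithmetic; the device count of 8 is an assumption (one TPU v3-8 host), not something stated in this diff:

    # Effective batch size implied by the updated run.sh (device count assumed).
    per_device_train_batch_size = 2
    gradient_accumulation_steps = 4
    num_devices = 8  # assumption
    effective_batch_size = (per_device_train_batch_size
                            * num_devices
                            * gradient_accumulation_steps)
    print(effective_batch_size)  # 64 examples per update under these assumptions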
run_mlm_flax.py
CHANGED
@@ -20,20 +20,18 @@ text file or a dataset.
|
|
20 |
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
21 |
https://huggingface.co/models?filter=masked-lm
|
22 |
"""
|
23 |
-
import shutil
|
24 |
import logging
|
25 |
import os
|
26 |
import sys
|
27 |
import time
|
28 |
from dataclasses import dataclass, field
|
29 |
-
from ast import Str
|
30 |
|
31 |
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
32 |
from pathlib import Path
|
33 |
from typing import Dict, List, Optional, Tuple
|
34 |
|
35 |
import numpy as np
|
36 |
-
from datasets import load_dataset
|
37 |
from tqdm import tqdm
|
38 |
|
39 |
import flax
|
@@ -56,13 +54,12 @@ from transformers import (
|
|
56 |
is_tensorboard_available,
|
57 |
set_seed,
|
58 |
)
|
59 |
-
|
60 |
-
from flax.serialization import to_bytes, from_bytes
|
61 |
-
from importlib.util import find_spec
|
62 |
from flax.training import checkpoints
|
63 |
from flax.jax_utils import unreplicate
|
64 |
from flax.training.checkpoints import save_checkpoint, restore_checkpoint
|
65 |
-
import
|
|
|
66 |
|
67 |
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
68 |
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
@@ -104,8 +101,10 @@ class ModelArguments:
|
|
104 |
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
105 |
},
|
106 |
)
|
107 |
-
|
108 |
-
|
|
|
|
|
109 |
|
110 |
|
111 |
@dataclass
|
@@ -120,11 +119,6 @@ class DataTrainingArguments:
|
|
120 |
dataset_config_name: Optional[str] = field(
|
121 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
122 |
)
|
123 |
-
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
124 |
-
validation_file: Optional[str] = field(
|
125 |
-
default=None,
|
126 |
-
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
127 |
-
)
|
128 |
train_ref_file: Optional[str] = field(
|
129 |
default=None,
|
130 |
metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
|
@@ -136,6 +130,9 @@ class DataTrainingArguments:
|
|
136 |
overwrite_cache: bool = field(
|
137 |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
138 |
)
|
|
|
|
|
|
|
139 |
validation_split_percentage: Optional[int] = field(
|
140 |
default=5,
|
141 |
metadata={
|
@@ -167,6 +164,17 @@ class DataTrainingArguments:
|
|
167 |
default=False,
|
168 |
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
|
169 |
)
|
|
|
|
|
|
170 |
|
171 |
|
172 |
@flax.struct.dataclass
|
@@ -266,33 +274,73 @@ def write_eval_metric(summary_writer, eval_metrics, step):
|
|
266 |
for metric_name, value in eval_metrics.items():
|
267 |
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
268 |
|
269 |
-
def mb_item(x):
|
270 |
-
return x.item() if hasattr(x, "item") else x
|
271 |
-
|
272 |
-
#checkpoint functions
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
def rotate_checkpoints(ckpt_dir: str, save_total_limit: int):
|
279 |
-
"Removes older checkpoints so that `save_total_limit` checkpoints are kept"
|
280 |
-
# TODO: what to remove is decided using step number only, we might want to improve that
|
281 |
-
ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt-*")]
|
282 |
-
# sort checkpoints by step
|
283 |
-
ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split('-')[-1]))
|
284 |
-
ckpts_to_delete = ckpts_sorted[:-save_total_limit]
|
285 |
-
for ckpt in ckpts_to_delete:
|
286 |
-
logger.info(f"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})")
|
287 |
-
shutil.rmtree(ckpt)
|
288 |
|
|
|
|
|
|
|
289 |
|
290 |
-
|
291 |
-
class TrainState(train_state.TrainState):
|
292 |
-
grad_accum: jnp.ndarray
|
293 |
|
294 |
|
295 |
-
|
296 |
if __name__ == "__main__":
|
297 |
# See all possible arguments in src/transformers/training_args.py
|
298 |
# or by passing the --help flag to this script.
|
@@ -360,52 +408,70 @@ if __name__ == "__main__":
|
|
360 |
cache_dir=model_args.cache_dir,
|
361 |
)
|
362 |
else:
|
363 |
-
|
364 |
-
|
365 |
-
# data_files["train"] = data_args.train_file
|
366 |
-
#if data_args.validation_file is not None:
|
367 |
-
# data_files["validation"] = data_args.validation_file
|
368 |
-
#extension = data_args.train_file.split(".")[-1]
|
369 |
-
#if extension == "txt":
|
370 |
-
# extension = "text"
|
371 |
-
#datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
372 |
-
|
373 |
-
#data_dir = "/home/yeb"
|
374 |
-
# data_dir = "/home/yeb/Developer/data"
|
375 |
data_files = []
|
376 |
-
def
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
datasets = load_dataset('json', data_files={'train': train, 'validation': val})
|
400 |
-
datasets["train"] = datasets["train"].select(range(int(0.8*len(datasets["train"]))))
|
401 |
-
datasets["validation"] = datasets["validation"].select(range(int(0.8*len(datasets["validation"]))))
|
402 |
-
#datasets["train"] = datasets["train"].select(range(10000))
|
403 |
-
#datasets["validation"] = datasets["validation"].select(range(10000))
|
404 |
|
|
|
|
|
405 |
|
406 |
|
|
|
|
|
407 |
|
|
|
|
|
|
408 |
|
|
|
|
|
|
409 |
if model_args.config_name:
|
410 |
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
411 |
elif model_args.model_name_or_path:
|
@@ -430,90 +496,97 @@ if __name__ == "__main__":
|
|
430 |
|
431 |
# Preprocessing the datasets.
|
432 |
# First we tokenize all the texts.
|
433 |
-
if training_args.do_train:
|
434 |
-
column_names = datasets["train"].column_names
|
435 |
-
else:
|
436 |
-
column_names = datasets["validation"].column_names
|
437 |
-
text_column_name = "text" if "text" in column_names else column_names[0]
|
438 |
-
|
439 |
-
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
440 |
|
|
|
|
|
|
441 |
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
return tokenizer(
|
450 |
-
examples,
|
451 |
-
return_special_tokens_mask=True,
|
452 |
-
padding=padding,
|
453 |
-
truncation=True,
|
454 |
-
max_length=max_seq_length,
|
455 |
)
|
456 |
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
|
472 |
-
|
473 |
-
tokenized_datasets = datasets.map(
|
474 |
-
tokenize_function,
|
475 |
-
batched=True,
|
476 |
-
num_proc=data_args.preprocessing_num_workers,
|
477 |
-
remove_columns=column_names,
|
478 |
-
load_from_cache_file=not data_args.overwrite_cache,
|
479 |
-
)
|
480 |
-
|
481 |
-
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
482 |
-
# max_seq_length.
|
483 |
-
def group_texts(examples):
|
484 |
-
# Concatenate all texts.
|
485 |
-
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
|
486 |
-
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
487 |
-
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
488 |
-
# customize this part to your needs.
|
489 |
-
if total_length >= max_seq_length:
|
490 |
-
total_length = (total_length // max_seq_length) * max_seq_length
|
491 |
-
# Split by chunks of max_len.
|
492 |
-
result = {
|
493 |
-
k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
|
494 |
-
for k, t in concatenated_examples.items()
|
495 |
-
}
|
496 |
-
return result
|
497 |
-
|
498 |
-
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
|
499 |
-
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
|
500 |
-
# might be slower to preprocess.
|
501 |
-
#
|
502 |
-
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
503 |
-
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
504 |
-
lm_datasets = tokenized_datasets.map(
|
505 |
-
group_texts,
|
506 |
-
batched=True,
|
507 |
-
batch_size=100,
|
508 |
-
num_proc=data_args.preprocessing_num_workers,
|
509 |
-
load_from_cache_file=not data_args.overwrite_cache,
|
510 |
-
)
|
511 |
-
train_dataset = lm_datasets["train"]
|
512 |
-
eval_dataset = lm_datasets["validation"]
|
513 |
-
|
514 |
|
|
|
|
|
|
|
|
515 |
|
|
|
|
|
516 |
|
|
|
517 |
# Enable tensorboard only on the master node
|
518 |
has_tensorboard = is_tensorboard_available()
|
519 |
if has_tensorboard and jax.process_index() == 0:
|
@@ -531,7 +604,6 @@ if __name__ == "__main__":
|
|
531 |
"Unable to display metrics through TensorBoard because the package is not installed: "
|
532 |
"Please run pip install tensorboard to enable."
|
533 |
)
|
534 |
-
# enable wandb tracking
|
535 |
has_wandb = find_spec("wandb") is not None
|
536 |
if jax.process_index() == 0 and has_wandb and ("wandb" in training_args.report_to):
|
537 |
try:
|
@@ -547,7 +619,6 @@ if __name__ == "__main__":
|
|
547 |
except ImportError as e:
|
548 |
print(e)
|
549 |
has_wandb = False
|
550 |
-
|
551 |
# Data collator
|
552 |
# This one will take care of randomly masking the tokens.
|
553 |
data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
|
@@ -567,10 +638,10 @@ if __name__ == "__main__":
|
|
567 |
|
568 |
# Store some constant
|
569 |
num_epochs = int(training_args.num_train_epochs)
|
570 |
-
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
571 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
572 |
|
573 |
-
num_train_steps = len(
|
574 |
|
575 |
# Create learning rate schedule
|
576 |
warmup_fn = optax.linear_schedule(
|
@@ -605,6 +676,7 @@ if __name__ == "__main__":
|
|
605 |
learning_rate=linear_decay_lr_schedule_fn,
|
606 |
)
|
607 |
else:
|
|
|
608 |
optimizer = optax.adamw(
|
609 |
learning_rate=linear_decay_lr_schedule_fn,
|
610 |
b1=training_args.adam_beta1,
|
@@ -613,22 +685,26 @@ if __name__ == "__main__":
|
|
613 |
weight_decay=training_args.weight_decay,
|
614 |
mask=decay_mask_fn,
|
615 |
)
|
|
|
|
|
|
|
|
|
616 |
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
|
621 |
# Setup train state
|
622 |
-
|
623 |
-
|
624 |
-
state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer,grad_accum=jax.tree_map(jnp.zeros_like, model.params))
|
625 |
-
|
626 |
if training_args.resume_from_checkpoint:
|
627 |
-
state =
|
628 |
-
resume_step = mb_item(state.step
|
|
|
|
|
629 |
else:
|
630 |
resume_step = 0
|
631 |
-
|
632 |
|
633 |
# Define gradient update step fn
|
634 |
def train_step(state, batch, dropout_rng):
|
@@ -646,30 +722,17 @@ if __name__ == "__main__":
|
|
646 |
# take average
|
647 |
loss = loss.sum() / label_mask.sum()
|
648 |
|
649 |
-
return loss
|
650 |
|
651 |
grad_fn = jax.value_and_grad(loss_fn)
|
652 |
-
loss,
|
653 |
-
|
654 |
-
|
655 |
-
|
656 |
-
grads = jax.tree_map(lambda x: x / training_args.gradient_accumulation_steps, grad_accum)
|
657 |
-
grads = jax.lax.pmean(grad_accum, "batch")
|
658 |
-
new_state = state.apply_gradients(grads=grads,grad_accum=jax.tree_map(jnp.zeros_like, grads))
|
659 |
-
return new_state
|
660 |
-
|
661 |
-
new_state = jax.lax.cond(
|
662 |
-
state.step % training_args.gradient_accumulation_steps == 0,
|
663 |
-
lambda _: update_fn(),
|
664 |
-
lambda _: state.replace(grad_accum=grad_accum, step=state.step + 1),
|
665 |
-
None,
|
666 |
-
)
|
667 |
-
|
668 |
metrics = jax.lax.pmean(
|
669 |
-
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
|
670 |
)
|
671 |
|
672 |
-
#return new_state.replace(new_dropout_rng=new_dropout_rng), metrics
|
673 |
return new_state, metrics, new_dropout_rng
|
674 |
|
675 |
# Create parallel version of the train step
|
@@ -700,7 +763,10 @@ if __name__ == "__main__":
|
|
700 |
state = jax_utils.replicate(state)
|
701 |
|
702 |
train_time = 0
|
703 |
-
|
|
|
|
|
|
|
704 |
for epoch in epochs:
|
705 |
# ======================== Training ================================
|
706 |
train_start = time.time()
|
@@ -708,54 +774,53 @@ if __name__ == "__main__":
|
|
708 |
|
709 |
# Create sampling rng
|
710 |
rng, input_rng = jax.random.split(rng)
|
711 |
-
steps_per_epoch = len(train_dataset) // train_batch_size
|
712 |
|
713 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
714 |
-
num_train_samples = len(
|
715 |
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
|
716 |
-
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size
|
717 |
|
718 |
# Gather the indexes for creating the batch and do a training step
|
719 |
-
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1,initial=resume_step)):
|
720 |
-
samples = [
|
721 |
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
722 |
-
|
723 |
|
724 |
# Model forward
|
725 |
model_inputs = shard(model_inputs.data)
|
726 |
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
727 |
train_metrics.append(train_metric)
|
728 |
|
729 |
-
cur_step = epoch * (num_train_samples // train_batch_size) + step
|
730 |
if cur_step < resume_step:
|
731 |
continue
|
732 |
|
733 |
-
if
|
734 |
# Save metrics
|
735 |
train_metric = jax_utils.unreplicate(train_metric)
|
736 |
train_time += time.time() - train_start
|
737 |
if has_tensorboard and jax.process_index() == 0:
|
738 |
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
|
|
739 |
if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
|
740 |
# TODO: add accumulation of metrics
|
741 |
_metrics = {k if k=="learning_rate" else f"train_{k}":mb_item(v.mean()) for k, v in train_metric.items()}
|
742 |
wandb.log({"training_step":cur_step, **_metrics}, commit=True)
|
743 |
-
|
744 |
epochs.write(
|
745 |
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
746 |
)
|
747 |
|
748 |
train_metrics = []
|
749 |
|
750 |
-
if cur_step %
|
751 |
# ======================== Evaluating ==============================
|
752 |
-
num_eval_samples = len(
|
753 |
eval_samples_idx = jnp.arange(num_eval_samples)
|
754 |
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
755 |
|
756 |
eval_metrics = []
|
757 |
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
758 |
-
samples = [
|
759 |
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
760 |
|
761 |
# Model forward
|
@@ -775,30 +840,17 @@ if __name__ == "__main__":
|
|
775 |
# Save metrics
|
776 |
if has_tensorboard and jax.process_index() == 0:
|
777 |
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
778 |
-
|
779 |
if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
|
780 |
_metrics = {f"eval_{k}":mb_item(v) for k, v in eval_metrics.items()}
|
781 |
wandb.log({"eval_step":cur_step, **_metrics})
|
782 |
|
783 |
-
if
|
784 |
# save checkpoint after each epoch and push checkpoint to the hub
|
785 |
if jax.process_index() == 0:
|
786 |
-
|
787 |
-
|
788 |
-
training_args.output_dir,
|
789 |
-
params=params,
|
790 |
-
push_to_hub=training_args.push_to_hub,
|
791 |
-
commit_message=f"Saving weights and logs of step {cur_step}",
|
792 |
-
)
|
793 |
-
save_checkpoint(training_args.output_dir, jax_utils.unreplicate(state), cur_step, keep=training_args.save_total_limit, overwrite=True)
|
794 |
if training_args.save_total_limit is not None:
|
795 |
rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
|
796 |
-
|
797 |
if jax.process_index() == 0:
|
798 |
-
|
799 |
-
model.save_pretrained(
|
800 |
-
training_args.output_dir,
|
801 |
-
params=params,
|
802 |
-
push_to_hub=training_args.push_to_hub,
|
803 |
-
commit_message=f"Saving weights and logs of step {cur_step}",
|
804 |
-
)
|
|
|
20 |
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
21 |
https://huggingface.co/models?filter=masked-lm
|
22 |
"""
|
|
|
23 |
import logging
|
24 |
import os
|
25 |
import sys
|
26 |
import time
|
27 |
from dataclasses import dataclass, field
|
|
|
28 |
|
29 |
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
30 |
from pathlib import Path
|
31 |
from typing import Dict, List, Optional, Tuple
|
32 |
|
33 |
import numpy as np
|
34 |
+
from datasets import load_dataset, DatasetDict
|
35 |
from tqdm import tqdm
|
36 |
|
37 |
import flax
|
|
|
54 |
is_tensorboard_available,
|
55 |
set_seed,
|
56 |
)
|
57 |
+
import json
|
|
|
|
|
58 |
from flax.training import checkpoints
|
59 |
from flax.jax_utils import unreplicate
|
60 |
from flax.training.checkpoints import save_checkpoint, restore_checkpoint
|
61 |
+
from importlib.util import find_spec
|
62 |
+
|
63 |
|
64 |
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
65 |
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
|
|
101 |
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
102 |
},
|
103 |
)
|
104 |
+
save_optimizer: Optional[bool] = field(
|
105 |
+
default=True,
|
106 |
+
metadata={"help": "Whether to store full train state including optimizer."},
|
107 |
+
)
|
108 |
|
109 |
|
110 |
@dataclass
|
|
|
119 |
dataset_config_name: Optional[str] = field(
|
120 |
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
121 |
)
|
|
|
|
|
|
|
|
|
|
|
122 |
train_ref_file: Optional[str] = field(
|
123 |
default=None,
|
124 |
metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
|
|
|
130 |
overwrite_cache: bool = field(
|
131 |
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
132 |
)
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
validation_split_percentage: Optional[int] = field(
|
137 |
default=5,
|
138 |
metadata={
|
|
|
164 |
default=False,
|
165 |
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
|
166 |
)
|
167 |
+
max_eval_samples: Optional[int] = field(
|
168 |
+
default=None,
|
169 |
+
metadata={
|
170 |
+
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
171 |
+
"value if set."
|
172 |
+
},
|
173 |
+
)
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
|
178 |
|
179 |
|
180 |
@flax.struct.dataclass
|
|
|
274 |
for metric_name, value in eval_metrics.items():
|
275 |
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
276 |
|
|
|
|
|
|
277 |
|
278 |
+
def _zeros_tree_like(inp_tree):
|
279 |
+
return jax.tree_map(jnp.zeros_like, inp_tree)
|
280 |
+
|
281 |
+
def fake_update(state):
|
282 |
+
fake_updates = _zeros_tree_like(state.params)
|
283 |
+
_, new_inner_opt_state = state.tx.inner_opt.update(fake_updates, state.opt_state.inner_opt_state, state.params)
|
284 |
+
opt_state = state.opt_state
|
285 |
+
new_opt_state = optax.MultiStepsState(mini_step=opt_state.mini_step,
|
286 |
+
gradient_step=opt_state.gradient_step,
|
287 |
+
inner_opt_state=new_inner_opt_state,
|
288 |
+
acc_grads=opt_state.acc_grads)
|
289 |
+
return state.replace(opt_state=new_opt_state)
|
290 |
+
|
291 |
+
def reinstantiate_states(opt_state):
|
292 |
+
new_state = []
|
293 |
+
for state in opt_state:
|
294 |
+
cls = getattr(optax, type(state).__name__)
|
295 |
+
new_state.append(cls(**{k:getattr(state, k) for k in state._fields}))
|
296 |
+
return new_state
|
297 |
+
|
298 |
+
def restore_model_checkpoint(save_dir, state):
|
299 |
+
logger.info(f"RESTORING CHECKPOINT FROM {save_dir}...")
|
300 |
+
with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
|
301 |
+
params = from_bytes(state.params, f.read())
|
302 |
+
|
303 |
+
with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
|
304 |
+
opt_state = from_bytes(state.opt_state, f.read())
|
305 |
+
|
306 |
+
with open(os.path.join(save_dir, "training_state.json"), "r") as f:
|
307 |
+
training_state = json.load(f)
|
308 |
+
step = training_state["step"]
|
309 |
+
|
310 |
+
logger.info("checkpoint restored")
|
311 |
+
# reinstantiate inner opt state to avoid type conflict
|
312 |
+
if hasattr(opt_state, "inner_opt_state"):
|
313 |
+
print("restoring state of multisteps optimizer")
|
314 |
+
inner_opt_state = reinstantiate_states(opt_state.inner_opt_state)
|
315 |
+
ms_state_dict = {k:getattr(state.opt_state, k) for k in state.opt_state._fields}
|
316 |
+
ms_state_dict["inner_opt_state"] = inner_opt_state
|
317 |
+
opt_state = optax.MultiStepsState(**ms_state_dict)
|
318 |
+
|
319 |
+
return state.replace(step=step, params=params, opt_state=opt_state)
|
320 |
+
|
321 |
+
def save_model_checkpoint(model, save_dir, state, with_opt:bool=True, push_to_hub:bool=False):
|
322 |
+
"""
|
323 |
+
If `push_to_hub` is True, will save to `save_dir`. Otherwise will save to `save_dir/ckpt-{step}`.
|
324 |
+
"""
|
325 |
+
state = jax_utils.unreplicate(state)
|
326 |
+
logger.info(f"SAVING CHECKPOINT IN {save_dir}...")
|
327 |
+
if not push_to_hub:
|
328 |
+
save_dir = f"{save_dir}/ckpt-{mb_item(state.step)-1}"
|
329 |
+
model.save_pretrained(
|
330 |
+
save_dir,
|
331 |
+
params=state.params,
|
332 |
+
push_to_hub=push_to_hub,
|
333 |
+
commit_message=f"Saving weights and logs at step {mb_item(state.step)-1}",
|
334 |
+
)
|
335 |
+
if with_opt:
|
336 |
+
with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
|
337 |
+
f.write(to_bytes(state.opt_state))
|
338 |
+
with open(os.path.join(save_dir, "training_state.json"), "w") as f:
|
339 |
+
json.dump({"step": state.step.item()}, f)
|
340 |
+
logger.info("checkpoint saved")
|
341 |
|
|
|
|
|
|
|
342 |
|
343 |
|
|
|
344 |
if __name__ == "__main__":
|
345 |
# See all possible arguments in src/transformers/training_args.py
|
346 |
# or by passing the --help flag to this script.
|
|
|
408 |
cache_dir=model_args.cache_dir,
|
409 |
)
|
410 |
else:
|
411 |
+
import glob
|
412 |
+
import random
|
|
|
|
|
|
413 |
data_files = []
|
414 |
+
def add_jsonlines_dir(path, filespec):
|
415 |
+
global data_files
|
416 |
+
data_files += glob.glob(f"{path}/{filespec}")
|
417 |
+
data_files = list(set(data_files))
|
418 |
+
print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")
|
419 |
+
add_jsonlines_dir(f"/data/c4_cleaned2", "*.gz")
|
420 |
+
add_jsonlines_dir(f"/data/nrc_uniq_cleaned_20210223", "*.gz")
|
421 |
+
add_jsonlines_dir(f"/data/nu_uniq_cleaned_20210225", "*.gz")
|
422 |
+
random.Random(42).shuffle(data_files)
|
423 |
+
total = len(data_files)
|
424 |
+
print(total)
|
425 |
+
perc = 0.05
|
426 |
+
val_size = int(perc * total)
|
427 |
+
train_size = total - val_size
|
428 |
+
train = data_files[:train_size]
|
429 |
+
val = data_files[train_size:]
|
430 |
+
print(f"Got {len(train)} training files and {perc * 100} % {len(val)} validation files")
|
431 |
+
assert list(set(train) & set(val)) == [], "Train overlaps with test"
|
432 |
+
load_grouped = True
|
433 |
+
if not load_grouped:
|
434 |
+
datasets = load_dataset('json', data_files={'train': train, 'validation': val})
|
435 |
+
|
436 |
+
#from datasets import Dataset
|
|
|
|
437 |
|
438 |
+
#dataset = Dataset.from_file("/home/dat/.cache/huggingface/datasets/json/default-9add402b38836560/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723/json-train.arrow")
|
439 |
+
#dataset = Dataset.from_file("/home/dat/.cache/huggingface/datasets/json/default-9add402b38836560/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723/json-validation.arrow")
|
440 |
|
441 |
|
442 |
+
def mb_item(x):
|
443 |
+
return x.item() if hasattr(x, "item") else x
|
444 |
|
445 |
+
def save_model_checkpoint(model, save_dir, state, with_opt:bool=True, push_to_hub:bool=False):
|
446 |
+
"""
|
447 |
+
If `push_to_hub` is True, will save to `save_dir`. Otherwise will save to `save_dir/ckpt-{step}`.
|
448 |
+
"""
|
449 |
+
state = jax_utils.unreplicate(state)
|
450 |
+
logger.info(f"SAVING CHECKPOINT IN {save_dir}...")
|
451 |
+
if not push_to_hub:
|
452 |
+
save_dir = f"{save_dir}/ckpt-{mb_item(state.step)-1}"
|
453 |
+
model.save_pretrained(
|
454 |
+
save_dir,
|
455 |
+
params=state.params,
|
456 |
+
push_to_hub=push_to_hub,
|
457 |
+
commit_message=f"Saving weights and logs at step {mb_item(state.step)-1}",
|
458 |
+
)
|
459 |
+
if with_opt:
|
460 |
+
with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
|
461 |
+
f.write(to_bytes(state.opt_state))
|
462 |
+
with open(os.path.join(save_dir, "training_state.json"), "w") as f:
|
463 |
+
json.dump({"step": state.step.item()}, f)
|
464 |
+
logger.info("checkpoint saved")
|
465 |
|
466 |
+
|
467 |
+
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
468 |
+
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
469 |
+
|
470 |
+
# Load pretrained model and tokenizer
|
471 |
+
|
472 |
+
# Distributed training:
|
473 |
+
# The .from_pretrained methods guarantee that only one local process can concurrently
|
474 |
+
# download model & vocab.
|
475 |
if model_args.config_name:
|
476 |
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
477 |
elif model_args.model_name_or_path:
|
|
|
496 |
|
497 |
# Preprocessing the datasets.
|
498 |
# First we tokenize all the texts.
|
|
|
|
|
499 |
|
500 |
+
if load_grouped:
|
501 |
+
logger.info("Loading tokenized and grouped dataset")
|
502 |
+
tokenized_datasets = DatasetDict.load_from_disk("/data/tokenized_data")
|
503 |
+
logger.info("Setting max validation examples to ")
|
504 |
+
print(f"Number of validation examples {data_args.max_eval_samples}")
|
505 |
+
tokenized_datasets["train"]= tokenized_datasets["train"].select(range(20000))
|
506 |
+
if data_args.max_eval_samples is not None:
|
507 |
+
tokenized_datasets["validation"] = tokenized_datasets["validation"].select(range(data_args.max_eval_samples))
|
508 |
+
else:
|
509 |
+
if training_args.do_train:
|
510 |
+
column_names = datasets["train"].column_names
|
511 |
+
else:
|
512 |
+
column_names = datasets["validation"].column_names
|
513 |
+
text_column_name = "text" if "text" in column_names else column_names[0]
|
514 |
+
|
515 |
+
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
516 |
+
|
517 |
+
if data_args.line_by_line:
|
518 |
+
# When using line_by_line, we just tokenize each nonempty line.
|
519 |
+
padding = "max_length" if data_args.pad_to_max_length else False
|
520 |
+
|
521 |
+
def tokenize_function(examples):
|
522 |
+
# Remove empty lines
|
523 |
+
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
|
524 |
+
return tokenizer(
|
525 |
+
examples,
|
526 |
+
return_special_tokens_mask=True,
|
527 |
+
padding=padding,
|
528 |
+
truncation=True,
|
529 |
+
max_length=max_seq_length,
|
530 |
+
)
|
531 |
|
532 |
+
tokenized_datasets = datasets.map(
|
533 |
+
tokenize_function,
|
534 |
+
input_columns=[text_column_name],
|
535 |
+
batched=True,
|
536 |
+
num_proc=data_args.preprocessing_num_workers,
|
537 |
+
remove_columns=column_names,
|
538 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
|
|
|
|
|
539 |
)
|
540 |
|
541 |
+
else:
|
542 |
+
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
|
543 |
+
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
|
544 |
+
# efficient when it receives the `special_tokens_mask`.
|
545 |
+
def tokenize_function(examples):
|
546 |
+
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
|
547 |
+
|
548 |
+
tokenized_datasets = datasets.map(
|
549 |
+
tokenize_function,
|
550 |
+
batched=True,
|
551 |
+
num_proc=data_args.preprocessing_num_workers,
|
552 |
+
remove_columns=column_names,
|
553 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
554 |
+
)
|
|
|
|
|
|
555 |
|
556 |
+
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
557 |
+
# max_seq_length.
|
558 |
+
def group_texts(examples):
|
559 |
+
# Concatenate all texts.
|
560 |
+
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
|
561 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
562 |
+
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
563 |
+
# customize this part to your needs.
|
564 |
+
if total_length >= max_seq_length:
|
565 |
+
total_length = (total_length // max_seq_length) * max_seq_length
|
566 |
+
# Split by chunks of max_len.
|
567 |
+
result = {
|
568 |
+
k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
|
569 |
+
for k, t in concatenated_examples.items()
|
570 |
+
}
|
571 |
+
return result
|
572 |
+
|
573 |
+
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
|
574 |
+
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
|
575 |
+
# might be slower to preprocess.
|
576 |
+
#
|
577 |
+
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
578 |
+
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
579 |
+
tokenized_datasets = tokenized_datasets.map(
|
580 |
+
group_texts,
|
581 |
+
batched=True,
|
582 |
+
num_proc=data_args.preprocessing_num_workers,
|
583 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
584 |
+
)
|
585 |
|
586 |
+
#tokenized_datasets.save_to_disk("/data/tokenized_data")
|
587 |
+
#print ("tokenized_datasets saved to disk")
|
588 |
|
589 |
+
|
590 |
# Enable tensorboard only on the master node
|
591 |
has_tensorboard = is_tensorboard_available()
|
592 |
if has_tensorboard and jax.process_index() == 0:
|
|
|
604 |
"Unable to display metrics through TensorBoard because the package is not installed: "
|
605 |
"Please run pip install tensorboard to enable."
|
606 |
)
|
|
|
607 |
has_wandb = find_spec("wandb") is not None
|
608 |
if jax.process_index() == 0 and has_wandb and ("wandb" in training_args.report_to):
|
609 |
try:
|
|
|
619 |
except ImportError as e:
|
620 |
print(e)
|
621 |
has_wandb = False
|
|
|
622 |
# Data collator
|
623 |
# This one will take care of randomly masking the tokens.
|
624 |
data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
|
|
|
638 |
|
639 |
# Store some constant
|
640 |
num_epochs = int(training_args.num_train_epochs)
|
641 |
+
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * training_args.gradient_accumulation_steps
|
642 |
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
643 |
|
644 |
+
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
|
645 |
|
646 |
# Create learning rate schedule
|
647 |
warmup_fn = optax.linear_schedule(
|
|
|
676 |
learning_rate=linear_decay_lr_schedule_fn,
|
677 |
)
|
678 |
else:
|
679 |
+
from optax import clip_by_global_norm
|
680 |
optimizer = optax.adamw(
|
681 |
learning_rate=linear_decay_lr_schedule_fn,
|
682 |
b1=training_args.adam_beta1,
|
|
|
685 |
weight_decay=training_args.weight_decay,
|
686 |
mask=decay_mask_fn,
|
687 |
)
|
688 |
+
optimizer = optax.chain(
|
689 |
+
optax.clip_by_global_norm(1.),
|
690 |
+
optimizer
|
691 |
+
)
|
692 |
|
693 |
+
if training_args.gradient_accumulation_steps > 1:
|
694 |
+
optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
|
695 |
+
grad_accum_steps = training_args.gradient_accumulation_steps
|
696 |
|
697 |
# Setup train state
|
698 |
+
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
|
699 |
+
|
|
|
|
|
700 |
if training_args.resume_from_checkpoint:
|
701 |
+
state = restore_model_checkpoint(training_args.resume_from_checkpoint, state)
|
702 |
+
resume_step = mb_item(state.step)
|
703 |
+
if training_args.adafactor:
|
704 |
+
state = fake_update(state)
|
705 |
else:
|
706 |
resume_step = 0
|
707 |
+
|
708 |
|
709 |
# Define gradient update step fn
|
710 |
def train_step(state, batch, dropout_rng):
|
|
|
722 |
# take average
|
723 |
loss = loss.sum() / label_mask.sum()
|
724 |
|
725 |
+
return loss
|
726 |
|
727 |
grad_fn = jax.value_and_grad(loss_fn)
|
728 |
+
loss, grad = grad_fn(state.params)
|
729 |
+
grad = jax.lax.pmean(grad, "batch")
|
730 |
+
new_state = state.apply_gradients(grads=grad)
|
731 |
+
|
|
|
|
|
|
|
732 |
metrics = jax.lax.pmean(
|
733 |
+
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step // grad_accum_steps)}, axis_name="batch"
|
734 |
)
|
735 |
|
|
|
736 |
return new_state, metrics, new_dropout_rng
|
737 |
|
738 |
# Create parallel version of the train step
|
|
|
763 |
state = jax_utils.replicate(state)
|
764 |
|
765 |
train_time = 0
|
766 |
+
steps_per_epoch = len(tokenized_datasets["train"]) // train_batch_size
|
767 |
+
resume_epoch = resume_step // (steps_per_epoch * grad_accum_steps)
|
768 |
+
epochs = tqdm(range(num_epochs), desc=f"Epoch ... ({resume_epoch+1}/{num_epochs})", position=0)
|
769 |
+
logger.info(f"Skipping to epoch {resume_epoch} step {resume_step // grad_accum_steps}")
|
770 |
for epoch in epochs:
|
771 |
# ======================== Training ================================
|
772 |
train_start = time.time()
|
|
|
774 |
|
775 |
# Create sampling rng
|
776 |
rng, input_rng = jax.random.split(rng)
|
|
|
777 |
|
778 |
# Generate an epoch by shuffling sampling indices from the train dataset
|
779 |
+
num_train_samples = len(tokenized_datasets["train"])
|
780 |
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
|
781 |
+
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size // grad_accum_steps)
|
782 |
|
783 |
# Gather the indexes for creating the batch and do a training step
|
784 |
+
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1,initial=resume_step // grad_accum_steps)):
|
785 |
+
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
|
786 |
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
|
|
787 |
|
788 |
# Model forward
|
789 |
model_inputs = shard(model_inputs.data)
|
790 |
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
791 |
train_metrics.append(train_metric)
|
792 |
|
793 |
+
cur_step = epoch * (num_train_samples // train_batch_size * grad_accum_steps) + step
|
794 |
if cur_step < resume_step:
|
795 |
continue
|
796 |
|
797 |
+
if cur_step % training_args.logging_steps * grad_accum_steps == 0 and cur_step > 0:
|
798 |
# Save metrics
|
799 |
train_metric = jax_utils.unreplicate(train_metric)
|
800 |
train_time += time.time() - train_start
|
801 |
if has_tensorboard and jax.process_index() == 0:
|
802 |
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
803 |
+
|
804 |
if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
|
805 |
# TODO: add accumulation of metrics
|
806 |
_metrics = {k if k=="learning_rate" else f"train_{k}":mb_item(v.mean()) for k, v in train_metric.items()}
|
807 |
wandb.log({"training_step":cur_step, **_metrics}, commit=True)
|
808 |
+
|
809 |
epochs.write(
|
810 |
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
811 |
)
|
812 |
|
813 |
train_metrics = []
|
814 |
|
815 |
+
if cur_step % training_args.eval_steps * grad_accum_steps == 0 and cur_step > 0:
|
816 |
# ======================== Evaluating ==============================
|
817 |
+
num_eval_samples = len(tokenized_datasets["validation"])
|
818 |
eval_samples_idx = jnp.arange(num_eval_samples)
|
819 |
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
820 |
|
821 |
eval_metrics = []
|
822 |
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
823 |
+
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
824 |
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
825 |
|
826 |
# Model forward
|
|
|
840 |
# Save metrics
|
841 |
if has_tensorboard and jax.process_index() == 0:
|
842 |
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
|
|
843 |
if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
|
844 |
_metrics = {f"eval_{k}":mb_item(v) for k, v in eval_metrics.items()}
|
845 |
wandb.log({"eval_step":cur_step, **_metrics})
|
846 |
|
847 |
+
if cur_step % training_args.save_steps == 0 * grad_accum_steps and cur_step > 0:
|
848 |
# save checkpoint after each epoch and push checkpoint to the hub
|
849 |
if jax.process_index() == 0:
|
850 |
+
save_model_checkpoint(model, training_args.output_dir, state, with_opt=model_args.save_optimizer,
|
851 |
+
push_to_hub=training_args.push_to_hub)
|
|
|
|
|
852 |
if training_args.save_total_limit is not None:
|
853 |
rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
|
854 |
+
|
855 |
if jax.process_index() == 0:
|
856 |
+
save_model_checkpoint(model, training_args.output_dir, state, with_opt=model_args.save_optimizer, push_to_hub=training_args.push_to_hub)
|
|
|
|
|
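The core change in run_mlm_flax.py above is dropping the hand-rolled grad_accum field on TrainState and instead wrapping the clipped AdamW optimizer in optax.MultiSteps, which accumulates gradients and only applies an update every gradient_accumulation_steps micro-batches; saving and restoring the full train state is handled by the new save_model_checkpoint / restore_model_checkpoint helpers. A minimal, self-contained sketch of the MultiSteps pattern (the hyperparameter values here are placeholders, not the script's):

    # Sketch of the optax.MultiSteps wrapping adopted in this commit (placeholder values).
    import jax.numpy as jnp
    import optax

    inner = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(learning_rate=1e-4, b1=0.9, b2=0.98),
    )
    tx = optax.MultiSteps(inner, every_k_schedule=4)  # apply an update every 4 micro-batches

    params = {"w": jnp.zeros((3,))}
    opt_state = tx.init(params)
    grads = {"w": jnp.ones((3,))}
    updates, opt_state = tx.update(grads, opt_state, params)  # zero update until the 4th call
    params = optax.apply_updates(params, updates)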
run_mlm_flax_no_accum.py
ADDED
@@ -0,0 +1,776 @@
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2021 The HuggingFace Team All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
|
18 |
+
text file or a dataset.
|
19 |
+
|
20 |
+
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
21 |
+
https://huggingface.co/models?filter=masked-lm
|
22 |
+
"""
|
23 |
+
import logging
|
24 |
+
import os
|
25 |
+
import sys
|
26 |
+
import time
|
27 |
+
from dataclasses import dataclass, field
|
28 |
+
|
29 |
+
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
30 |
+
from pathlib import Path
|
31 |
+
from typing import Dict, List, Optional, Tuple
|
32 |
+
|
33 |
+
import numpy as np
|
34 |
+
from datasets import load_dataset, DatasetDict
|
35 |
+
from tqdm import tqdm
|
36 |
+
|
37 |
+
import flax
|
38 |
+
import jax
|
39 |
+
import jax.numpy as jnp
|
40 |
+
import optax
|
41 |
+
from flax import jax_utils, traverse_util
|
42 |
+
from flax.training import train_state
|
43 |
+
from flax.training.common_utils import get_metrics, onehot, shard
|
44 |
+
from transformers import (
|
45 |
+
CONFIG_MAPPING,
|
46 |
+
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
47 |
+
AutoConfig,
|
48 |
+
AutoTokenizer,
|
49 |
+
FlaxAutoModelForMaskedLM,
|
50 |
+
HfArgumentParser,
|
51 |
+
PreTrainedTokenizerBase,
|
52 |
+
TensorType,
|
53 |
+
TrainingArguments,
|
54 |
+
is_tensorboard_available,
|
55 |
+
set_seed,
|
56 |
+
)
|
57 |
+
import json
|
58 |
+
from flax.training import checkpoints
|
59 |
+
from flax.jax_utils import unreplicate
|
60 |
+
from flax.training.checkpoints import save_checkpoint, restore_checkpoint
|
61 |
+
from importlib.util import find_spec
|
62 |
+
|
63 |
+
|
64 |
+
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
65 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
66 |
+
|
67 |
+
|
68 |
+
@dataclass
|
69 |
+
class ModelArguments:
|
70 |
+
"""
|
71 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
72 |
+
"""
|
73 |
+
|
74 |
+
model_name_or_path: Optional[str] = field(
|
75 |
+
default=None,
|
76 |
+
metadata={
|
77 |
+
"help": "The model checkpoint for weights initialization."
|
78 |
+
"Don't set if you want to train a model from scratch."
|
79 |
+
},
|
80 |
+
)
|
81 |
+
model_type: Optional[str] = field(
|
82 |
+
default=None,
|
83 |
+
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
84 |
+
)
|
85 |
+
config_name: Optional[str] = field(
|
86 |
+
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
87 |
+
)
|
88 |
+
tokenizer_name: Optional[str] = field(
|
89 |
+
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
90 |
+
)
|
91 |
+
cache_dir: Optional[str] = field(
|
92 |
+
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
93 |
+
)
|
94 |
+
use_fast_tokenizer: bool = field(
|
95 |
+
default=True,
|
96 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
97 |
+
)
|
98 |
+
dtype: Optional[str] = field(
|
99 |
+
default="float32",
|
100 |
+
metadata={
|
101 |
+
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
102 |
+
},
|
103 |
+
)
|
104 |
+
|
105 |
+
|
106 |
+
@dataclass
|
107 |
+
class DataTrainingArguments:
|
108 |
+
"""
|
109 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
110 |
+
"""
|
111 |
+
|
112 |
+
dataset_name: Optional[str] = field(
|
113 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
114 |
+
)
|
115 |
+
dataset_config_name: Optional[str] = field(
|
116 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
117 |
+
)
|
118 |
+
train_ref_file: Optional[str] = field(
|
119 |
+
default=None,
|
120 |
+
metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
|
121 |
+
)
|
122 |
+
validation_ref_file: Optional[str] = field(
|
123 |
+
default=None,
|
124 |
+
metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
|
125 |
+
)
|
126 |
+
overwrite_cache: bool = field(
|
127 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
validation_split_percentage: Optional[int] = field(
|
133 |
+
default=5,
|
134 |
+
metadata={
|
135 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
136 |
+
},
|
137 |
+
)
|
138 |
+
max_seq_length: Optional[int] = field(
|
139 |
+
default=None,
|
140 |
+
metadata={
|
141 |
+
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
142 |
+
"than this will be truncated. Default to the max input length of the model."
|
143 |
+
},
|
144 |
+
)
|
145 |
+
preprocessing_num_workers: Optional[int] = field(
|
146 |
+
default=None,
|
147 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
148 |
+
)
|
149 |
+
mlm_probability: float = field(
|
150 |
+
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
|
151 |
+
)
|
152 |
+
pad_to_max_length: bool = field(
|
153 |
+
default=False,
|
154 |
+
metadata={
|
155 |
+
"help": "Whether to pad all samples to `max_seq_length`. "
|
156 |
+
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
157 |
+
},
|
158 |
+
)
|
159 |
+
line_by_line: bool = field(
|
160 |
+
default=False,
|
161 |
+
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
|
162 |
+
)
|
163 |
+
max_eval_samples: Optional[int] = field(
|
164 |
+
default=None,
|
165 |
+
metadata={
|
166 |
+
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
167 |
+
"value if set."
|
168 |
+
},
|
169 |
+
)
|
170 |
+
|
171 |
+
|
172 |
+
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
@flax.struct.dataclass
|
177 |
+
class FlaxDataCollatorForLanguageModeling:
|
178 |
+
"""
|
179 |
+
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
|
180 |
+
are not all of the same length.
|
181 |
+
|
182 |
+
Args:
|
183 |
+
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
184 |
+
The tokenizer used for encoding the data.
|
185 |
+
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
|
186 |
+
The probability with which to (randomly) mask tokens in the input.
|
187 |
+
|
188 |
+
.. note::
|
189 |
+
|
190 |
+
For best performance, this data collator should be used with a dataset having items that are dictionaries or
|
191 |
+
BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
|
192 |
+
:class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
|
193 |
+
argument :obj:`return_special_tokens_mask=True`.
|
194 |
+
"""
|
195 |
+
|
196 |
+
tokenizer: PreTrainedTokenizerBase
|
197 |
+
mlm_probability: float = 0.15
|
198 |
+
|
199 |
+
def __post_init__(self):
|
200 |
+
if self.tokenizer.mask_token is None:
|
201 |
+
raise ValueError(
|
202 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
203 |
+
"You should pass `mlm=False` to train on causal language modeling instead."
|
204 |
+
)
|
205 |
+
|
206 |
+
def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
|
207 |
+
# Handle dict or lists with proper padding and conversion to tensor.
|
208 |
+
batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
|
209 |
+
|
210 |
+
# If special token mask has been preprocessed, pop it from the dict.
|
211 |
+
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
212 |
+
|
213 |
+
batch["input_ids"], batch["labels"] = self.mask_tokens(
|
214 |
+
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
215 |
+
)
|
216 |
+
return batch
|
217 |
+
|
218 |
+
def mask_tokens(
|
219 |
+
self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
|
220 |
+
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
221 |
+
"""
|
222 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
223 |
+
"""
|
224 |
+
labels = inputs.copy()
|
225 |
+
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
226 |
+
probability_matrix = np.full(labels.shape, self.mlm_probability)
|
227 |
+
special_tokens_mask = special_tokens_mask.astype("bool")
|
228 |
+
|
229 |
+
probability_matrix[special_tokens_mask] = 0.0
|
230 |
+
masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
|
231 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
232 |
+
|
233 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
234 |
+
indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
|
235 |
+
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
236 |
+
|
237 |
+
# 10% of the time, we replace masked input tokens with random word
|
238 |
+
indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
|
239 |
+
indices_random &= masked_indices & ~indices_replaced
|
240 |
+
|
241 |
+
random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
|
242 |
+
inputs[indices_random] = random_words[indices_random]
|
243 |
+
|
244 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
245 |
+
return inputs, labels
|
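# Net effect per token (with the default mlm_probability of 0.15): roughly 15% of
# non-special tokens are selected; of those, 80% are replaced by the mask token,
# 10% by a random vocabulary id, and 10% are left unchanged, while every selected
# position keeps its original id as the label so the loss is still computed on it.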
246 |
+
|
247 |
+
|
248 |
+
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
249 |
+
num_samples = len(samples_idx)
|
250 |
+
samples_to_remove = num_samples % batch_size
|
251 |
+
|
252 |
+
if samples_to_remove != 0:
|
253 |
+
samples_idx = samples_idx[:-samples_to_remove]
|
254 |
+
sections_split = num_samples // batch_size
|
255 |
+
batch_idx = np.split(samples_idx, sections_split)
|
256 |
+
return batch_idx
|
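# Example: with 10 sample indices and batch_size=4, the trailing 2 indices are
# dropped and two full batches of 4 are returned, so each batch shards evenly
# across devices.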
257 |
+
|
258 |
+
|
259 |
+
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
260 |
+
summary_writer.scalar("train_time", train_time, step)
|
261 |
+
|
262 |
+
train_metrics = get_metrics(train_metrics)
|
263 |
+
for key, vals in train_metrics.items():
|
264 |
+
tag = f"train_{key}"
|
265 |
+
for i, val in enumerate(vals):
|
266 |
+
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
267 |
+
|
268 |
+
|
269 |
+
def write_eval_metric(summary_writer, eval_metrics, step):
|
270 |
+
for metric_name, value in eval_metrics.items():
|
271 |
+
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
272 |
+
|
273 |
+
def rotate_checkpoints(ckpt_dir: str, save_total_limit: int):
|
274 |
+
"Removes older checkpoints so that `save_total_limit` checkpoints are kept"
|
275 |
+
# TODO: what to remove is decided using step number only, we might want to improve that
|
276 |
+
ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt-*")]
|
277 |
+
# sort checkpoints by step
|
278 |
+
ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split('-')[-1]))
|
279 |
+
ckpts_to_delete = ckpts_sorted[:-save_total_limit]
|
280 |
+
for ckpt in ckpts_to_delete:
|
281 |
+
logger.info(f"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})")
|
282 |
+
shutil.rmtree(ckpt)
|
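# Note: this helper assumes checkpoint directories named "ckpt-<step>" and keeps
# only the `save_total_limit` most recent ones, sorted by step number.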
283 |
+
|
284 |
+
|
285 |
+
if __name__ == "__main__":
|
286 |
+
# See all possible arguments in src/transformers/training_args.py
|
287 |
+
# or by passing the --help flag to this script.
|
288 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
289 |
+
|
290 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
291 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
292 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
293 |
+
# let's parse it to get our arguments.
|
294 |
+
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
295 |
+
else:
|
296 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
297 |
+
|
298 |
+
if (
|
299 |
+
os.path.exists(training_args.output_dir)
|
300 |
+
and os.listdir(training_args.output_dir)
|
301 |
+
and training_args.do_train
|
302 |
+
and not training_args.overwrite_output_dir
|
303 |
+
):
|
304 |
+
raise ValueError(
|
305 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
306 |
+
"Use --overwrite_output_dir to overcome."
|
307 |
+
)
|
308 |
+
|
309 |
+
# Setup logging
|
310 |
+
logging.basicConfig(
|
311 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
312 |
+
level="NOTSET",
|
313 |
+
datefmt="[%X]",
|
314 |
+
)
|
315 |
+
|
316 |
+
# Log on each process the small summary:
|
317 |
+
logger = logging.getLogger(__name__)
|
318 |
+
|
319 |
+
# Set the verbosity to info of the Transformers logger (on main process only):
|
320 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
321 |
+
|
322 |
+
# Set seed before initializing model.
|
323 |
+
set_seed(training_args.seed)
|
324 |
+
|
325 |
+
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
326 |
+
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
327 |
+
# (the dataset will be downloaded automatically from the datasets Hub).
|
328 |
+
#
|
329 |
+
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
330 |
+
# 'text' is found. You can easily tweak this behavior (see below).
|
331 |
+
#
|
332 |
+
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
333 |
+
# download the dataset.
|
334 |
+
if data_args.dataset_name is not None:
|
335 |
+
# Downloading and loading a dataset from the hub.
|
336 |
+
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
|
337 |
+
|
338 |
+
if "validation" not in datasets.keys():
|
339 |
+
datasets["validation"] = load_dataset(
|
340 |
+
data_args.dataset_name,
|
341 |
+
data_args.dataset_config_name,
|
342 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
343 |
+
cache_dir=model_args.cache_dir,
|
344 |
+
)
|
345 |
+
datasets["train"] = load_dataset(
|
346 |
+
data_args.dataset_name,
|
347 |
+
data_args.dataset_config_name,
|
348 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
349 |
+
cache_dir=model_args.cache_dir,
|
350 |
+
)
|
351 |
+
else:
|
352 |
+
import glob
|
353 |
+
import random
|
354 |
+
data_files = []
|
355 |
+
def add_jsonlines_dir(path, filespec):
|
356 |
+
global data_files
|
357 |
+
data_files += glob.glob(f"{path}/{filespec}")
|
358 |
+
data_files = list(set(data_files))
|
359 |
+
print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")
|
360 |
+
add_jsonlines_dir(f"/data/c4_cleaned2", "*.gz")
|
361 |
+
add_jsonlines_dir(f"/data/nrc_uniq_cleaned_20210223", "*.gz")
|
362 |
+
add_jsonlines_dir(f"/data/nu_uniq_cleaned_20210225", "*.gz")
|
363 |
+
random.Random(42).shuffle(data_files)
|
364 |
+
total = len(data_files)
|
365 |
+
print(total)
|
366 |
+
perc = 0.05
|
367 |
+
val_size = int(perc * total)
|
368 |
+
train_size = total - val_size
|
369 |
+
train = data_files[:train_size]
|
370 |
+
val = data_files[train_size:]
|
371 |
+
print(f"Got {len(train)} training files and {perc * 100} % {len(val)} validation files")
|
372 |
+
assert list(set(train) & set(val)) == [], "Train overlaps with test"
|
373 |
+
load_grouped = True
|
374 |
+
if not load_grouped:
|
375 |
+
datasets = load_dataset('json', data_files={'train': train, 'validation': val})
|
376 |
+
|
377 |
+
#from datasets import Dataset
|
378 |
+
|
379 |
+
#dataset = Dataset.from_file("/home/dat/.cache/huggingface/datasets/json/default-9add402b38836560/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723/json-train.arrow")
|
380 |
+
#dataset = Dataset.from_file("/home/dat/.cache/huggingface/datasets/json/default-9add402b38836560/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723/json-validation.arrow")
|
381 |
+
|
382 |
+
|
383 |
+
def mb_item(x):
|
384 |
+
return x.item() if hasattr(x, "item") else x
|
385 |
+
|
386 |
+
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
387 |
+
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
388 |
+
|
389 |
+
# Load pretrained model and tokenizer
|
390 |
+
|
391 |
+
# Distributed training:
|
392 |
+
# The .from_pretrained methods guarantee that only one local process can concurrently
|
393 |
+
# download model & vocab.
|
394 |
+
if model_args.config_name:
|
395 |
+
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
396 |
+
elif model_args.model_name_or_path:
|
397 |
+
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
398 |
+
else:
|
399 |
+
config = CONFIG_MAPPING[model_args.model_type]()
|
400 |
+
logger.warning("You are instantiating a new config instance from scratch.")
|
401 |
+
|
402 |
+
if model_args.tokenizer_name:
|
403 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
404 |
+
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
405 |
+
)
|
406 |
+
elif model_args.model_name_or_path:
|
407 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
408 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
409 |
+
)
|
410 |
+
else:
|
411 |
+
raise ValueError(
|
412 |
+
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
413 |
+
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
414 |
+
)
|
415 |
+
|
416 |
+
# Preprocessing the datasets.
|
417 |
+
# First we tokenize all the texts.
|
418 |
+
|
419 |
+
if load_grouped:
|
420 |
+
logger.info("Loading tokenized and grouped dataset")
|
421 |
+
tokenized_datasets = DatasetDict.load_from_disk("/data/tokenized_data")
|
422 |
+
logger.info("Setting max validation examples to ")
|
423 |
+
print(f"Number of validation examples {data_args.max_eval_samples}")
|
424 |
+
tokenized_datasets["train"]= tokenized_datasets["train"].select(range(20000))
|
425 |
+
if data_args.max_eval_samples is not None:
|
426 |
+
tokenized_datasets["validation"] = tokenized_datasets["validation"].select(range(data_args.max_eval_samples))
|
427 |
+
else:
|
428 |
+
if training_args.do_train:
|
429 |
+
column_names = datasets["train"].column_names
|
430 |
+
else:
|
431 |
+
column_names = datasets["validation"].column_names
|
432 |
+
text_column_name = "text" if "text" in column_names else column_names[0]
|
433 |
+
|
434 |
+
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
435 |
+
|
436 |
+
if data_args.line_by_line:
|
437 |
+
# When using line_by_line, we just tokenize each nonempty line.
|
438 |
+
padding = "max_length" if data_args.pad_to_max_length else False
|
439 |
+
|
440 |
+
def tokenize_function(examples):
|
441 |
+
# Remove empty lines
|
442 |
+
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
|
443 |
+
return tokenizer(
|
444 |
+
examples,
|
445 |
+
return_special_tokens_mask=True,
|
446 |
+
padding=padding,
|
447 |
+
truncation=True,
|
448 |
+
max_length=max_seq_length,
|
449 |
+
)
|
450 |
+
|
451 |
+
tokenized_datasets = datasets.map(
|
452 |
+
tokenize_function,
|
453 |
+
input_columns=[text_column_name],
|
454 |
+
batched=True,
|
455 |
+
num_proc=data_args.preprocessing_num_workers,
|
456 |
+
remove_columns=column_names,
|
457 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
458 |
+
)
|
459 |
+
|
460 |
+
else:
|
461 |
+
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
|
462 |
+
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
|
463 |
+
# efficient when it receives the `special_tokens_mask`.
|
464 |
+
def tokenize_function(examples):
|
465 |
+
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
|
466 |
+
|
467 |
+
tokenized_datasets = datasets.map(
|
468 |
+
tokenize_function,
|
469 |
+
batched=True,
|
470 |
+
num_proc=data_args.preprocessing_num_workers,
|
471 |
+
remove_columns=column_names,
|
472 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
473 |
+
)
|
474 |
+
|
475 |
+
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
476 |
+
# max_seq_length.
|
477 |
+
def group_texts(examples):
|
478 |
+
# Concatenate all texts.
|
479 |
+
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
|
480 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
481 |
+
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
482 |
+
# customize this part to your needs.
|
483 |
+
if total_length >= max_seq_length:
|
484 |
+
total_length = (total_length // max_seq_length) * max_seq_length
|
485 |
+
# Split by chunks of max_len.
|
486 |
+
result = {
|
487 |
+
k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
|
488 |
+
for k, t in concatenated_examples.items()
|
489 |
+
}
|
490 |
+
return result
|
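# Example: a batch that concatenates to 2,050 tokens with max_seq_length=512
# yields four chunks of 512 tokens and silently drops the last 2 tokens.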
491 |
+
|
492 |
+
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
|
493 |
+
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
|
494 |
+
# might be slower to preprocess.
|
495 |
+
#
|
496 |
+
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
497 |
+
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
498 |
+
tokenized_datasets = tokenized_datasets.map(
|
499 |
+
group_texts,
|
500 |
+
batched=True,
|
501 |
+
num_proc=data_args.preprocessing_num_workers,
|
502 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
503 |
+
)
|
504 |
+
|
505 |
+
#tokenized_datasets.save_to_disk("/data/tokenized_data")
|
506 |
+
#print ("tokenized_datasets saved to disk")
|
507 |
+
|
508 |
+
|
509 |
+
# Enable tensorboard only on the master node
|
510 |
+
has_tensorboard = is_tensorboard_available()
|
511 |
+
if has_tensorboard and jax.process_index() == 0:
|
512 |
+
try:
|
513 |
+
from flax.metrics.tensorboard import SummaryWriter
|
514 |
+
|
515 |
+
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
516 |
+
except ImportError as ie:
|
517 |
+
has_tensorboard = False
|
518 |
+
logger.warning(
|
519 |
+
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
|
520 |
+
)
|
521 |
+
else:
|
522 |
+
logger.warning(
|
523 |
+
"Unable to display metrics through TensorBoard because the package is not installed: "
|
524 |
+
"Please run pip install tensorboard to enable."
|
525 |
+
)
|
526 |
+
has_wandb = find_spec("wandb") is not None
|
527 |
+
if jax.process_index() == 0 and has_wandb and ("wandb" in training_args.report_to):
|
528 |
+
try:
|
529 |
+
import wandb
|
530 |
+
wandb.init(
|
531 |
+
entity="wandb",
|
532 |
+
project="hf-flax-pino-roberta",
|
533 |
+
sync_tensorboard=True
|
534 |
+
)
|
535 |
+
wandb.config.update(training_args)
|
536 |
+
wandb.config.update(model_args)
|
537 |
+
wandb.config.update(data_args)
|
538 |
+
except ImportError as e:
|
539 |
+
print(e)
|
540 |
+
has_wandb = False
|
541 |
+
# Data collator
|
542 |
+
# This one will take care of randomly masking the tokens.
|
543 |
+
data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
|
544 |
+
|
545 |
+
# Initialize our training
|
546 |
+
rng = jax.random.PRNGKey(training_args.seed)
|
547 |
+
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
548 |
+
|
549 |
+
if model_args.model_name_or_path:
|
550 |
+
model = FlaxAutoModelForMaskedLM.from_pretrained(
|
551 |
+
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
552 |
+
)
|
553 |
+
else:
|
554 |
+
model = FlaxAutoModelForMaskedLM.from_config(
|
555 |
+
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
556 |
+
)
|
557 |
+
|
558 |
+
# Store some constant
|
559 |
+
num_epochs = int(training_args.num_train_epochs)
|
560 |
+
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
561 |
+
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
562 |
+
|
563 |
+
num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
|
564 |
+
|
565 |
+
# Create learning rate schedule
|
566 |
+
warmup_fn = optax.linear_schedule(
|
567 |
+
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
568 |
+
)
|
569 |
+
decay_fn = optax.linear_schedule(
|
570 |
+
init_value=training_args.learning_rate,
|
571 |
+
end_value=0,
|
572 |
+
transition_steps=num_train_steps - training_args.warmup_steps,
|
573 |
+
)
|
574 |
+
linear_decay_lr_schedule_fn = optax.join_schedules(
|
575 |
+
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
|
576 |
+
)
|
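# The joined schedule ramps linearly from 0 to `learning_rate` over `warmup_steps`,
# then decays linearly back to 0 over the remaining `num_train_steps - warmup_steps` steps.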
577 |
+
|
578 |
+
# We use Optax's "masking" functionality to not apply weight decay
|
579 |
+
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
580 |
+
# mask boolean with the same structure as the parameters.
|
581 |
+
# The mask is True for parameters that should be decayed.
|
582 |
+
# Note that this mask is specifically adapted for FlaxBERT-like models.
|
583 |
+
# For other models, one should correct the layer norm parameter naming
|
584 |
+
# accordingly.
|
585 |
+
def decay_mask_fn(params):
|
586 |
+
flat_params = traverse_util.flatten_dict(params)
|
587 |
+
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
|
588 |
+
return traverse_util.unflatten_dict(flat_mask)
|
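# Example: a flattened path ending in "kernel" maps to True (weight decay applied),
# while paths ending in "bias" or ("LayerNorm", "scale") map to False (no decay).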
589 |
+
|
590 |
+
# create adam optimizer
|
591 |
+
if training_args.adafactor:
|
592 |
+
# We use the default parameters here to initialize adafactor,
|
593 |
+
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
594 |
+
optimizer = optax.adafactor(
|
595 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
596 |
+
)
|
597 |
+
else:
|
598 |
+
optimizer = optax.adamw(
|
599 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
600 |
+
b1=training_args.adam_beta1,
|
601 |
+
b2=training_args.adam_beta2,
|
602 |
+
eps=training_args.adam_epsilon,
|
603 |
+
weight_decay=training_args.weight_decay,
|
604 |
+
mask=decay_mask_fn,
|
605 |
+
)
|
606 |
+
optimizer = optax.chain(
|
607 |
+
optax.clip_grad_by_global_norm(1.),
|
608 |
+
optimizer
|
609 |
+
)
|
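# Gradients are clipped to a global norm of 1.0 before the optimizer update is applied.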
610 |
+
|
611 |
+
# Setup train state
|
612 |
+
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
|
613 |
+
|
614 |
+
if training_args.resume_from_checkpoint:
|
615 |
+
state = restore_checkpoint(training_args.resume_from_checkpoint, state)
|
616 |
+
resume_step = mb_item(state.step.item())
|
617 |
+
else:
|
618 |
+
resume_step = 0
|
619 |
+
|
620 |
+
|
621 |
+
# Define gradient update step fn
|
622 |
+
def train_step(state, batch, dropout_rng):
|
623 |
+
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
624 |
+
|
625 |
+
def loss_fn(params):
|
626 |
+
labels = batch.pop("labels")
|
627 |
+
|
628 |
+
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
629 |
+
|
630 |
+
# compute loss, ignore padded input tokens
|
631 |
+
label_mask = jnp.where(labels > 0, 1.0, 0.0)
|
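# Unselected positions were set to -100 by the collator, so `labels > 0` restricts
# the loss to masked positions (it also excludes any label that happens to be id 0).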
632 |
+
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
633 |
+
|
634 |
+
# take average
|
635 |
+
loss = loss.sum() / label_mask.sum()
|
636 |
+
|
637 |
+
return loss
|
638 |
+
|
639 |
+
grad_fn = jax.value_and_grad(loss_fn)
|
640 |
+
loss, grad = grad_fn(state.params)
|
641 |
+
grad = jax.lax.pmean(grad, "batch")
|
642 |
+
new_state = state.apply_gradients(grads=grad)
|
643 |
+
|
644 |
+
metrics = jax.lax.pmean(
|
645 |
+
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
|
646 |
+
)
|
647 |
+
|
648 |
+
return new_state, metrics, new_dropout_rng
|
649 |
+
|
650 |
+
# Create parallel version of the train step
|
651 |
+
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
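# pmap runs train_step on every local device and averages gradients/metrics over
# the "batch" axis; donate_argnums=(0,) lets XLA reuse the old state's buffers
# for the updated state.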
652 |
+
|
653 |
+
# Define eval fn
|
654 |
+
def eval_step(params, batch):
|
655 |
+
labels = batch.pop("labels")
|
656 |
+
|
657 |
+
logits = model(**batch, params=params, train=False)[0]
|
658 |
+
|
659 |
+
# compute loss, ignore padded input tokens
|
660 |
+
label_mask = jnp.where(labels > 0, 1.0, 0.0)
|
661 |
+
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
662 |
+
|
663 |
+
# compute accuracy
|
664 |
+
accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
|
665 |
+
|
666 |
+
# summarize metrics
|
667 |
+
metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
|
668 |
+
metrics = jax.lax.psum(metrics, axis_name="batch")
|
669 |
+
|
670 |
+
return metrics
|
671 |
+
|
672 |
+
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
|
673 |
+
|
674 |
+
# Replicate the train state on each device
|
675 |
+
state = jax_utils.replicate(state)
|
676 |
+
|
677 |
+
train_time = 0
|
678 |
+
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
679 |
+
for epoch in epochs:
|
680 |
+
# ======================== Training ================================
|
681 |
+
train_start = time.time()
|
682 |
+
train_metrics = []
|
683 |
+
|
684 |
+
# Create sampling rng
|
685 |
+
rng, input_rng = jax.random.split(rng)
|
686 |
+
|
687 |
+
# Generate an epoch by shuffling sampling indices from the train dataset
|
688 |
+
num_train_samples = len(tokenized_datasets["train"])
|
689 |
+
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
|
690 |
+
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
691 |
+
|
692 |
+
# Gather the indexes for creating the batch and do a training step
|
693 |
+
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1, initial=resume_step)):
|
694 |
+
samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
|
695 |
+
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
696 |
+
|
697 |
+
# Model forward
|
698 |
+
model_inputs = shard(model_inputs.data)
|
699 |
+
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
700 |
+
train_metrics.append(train_metric)
|
701 |
+
|
702 |
+
cur_step = epoch * (num_train_samples // train_batch_size) + step
|
703 |
+
if cur_step < resume_step:
|
704 |
+
continue
|
705 |
+
|
706 |
+
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
707 |
+
# Save metrics
|
708 |
+
train_metric = jax_utils.unreplicate(train_metric)
|
709 |
+
train_time += time.time() - train_start
|
710 |
+
if has_tensorboard and jax.process_index() == 0:
|
711 |
+
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
712 |
+
|
713 |
+
if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
|
714 |
+
# TODO: add accumulation of metrics
|
715 |
+
_metrics = {k if k=="learning_rate" else f"train_{k}":mb_item(v.mean()) for k, v in train_metric.items()}
|
716 |
+
wandb.log({"training_step":cur_step, **_metrics}, commit=True)
|
717 |
+
|
718 |
+
epochs.write(
|
719 |
+
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
720 |
+
)
|
721 |
+
|
722 |
+
train_metrics = []
|
723 |
+
|
724 |
+
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
|
725 |
+
# ======================== Evaluating ==============================
|
726 |
+
num_eval_samples = len(tokenized_datasets["validation"])
|
727 |
+
eval_samples_idx = jnp.arange(num_eval_samples)
|
728 |
+
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
729 |
+
|
730 |
+
eval_metrics = []
|
731 |
+
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
732 |
+
samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
|
733 |
+
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
734 |
+
|
735 |
+
# Model forward
|
736 |
+
model_inputs = shard(model_inputs.data)
|
737 |
+
metrics = p_eval_step(state.params, model_inputs)
|
738 |
+
eval_metrics.append(metrics)
|
739 |
+
|
740 |
+
# normalize eval metrics
|
741 |
+
eval_metrics = get_metrics(eval_metrics)
|
742 |
+
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
|
743 |
+
eval_normalizer = eval_metrics.pop("normalizer")
|
744 |
+
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
745 |
+
|
746 |
+
# Update progress bar
|
747 |
+
epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
748 |
+
|
749 |
+
# Save metrics
|
750 |
+
if has_tensorboard and jax.process_index() == 0:
|
751 |
+
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
752 |
+
if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
|
753 |
+
_metrics = {f"eval_{k}":mb_item(v) for k, v in eval_metrics.items()}
|
754 |
+
wandb.log({"eval_step":cur_step, **_metrics})
|
755 |
+
|
756 |
+
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
757 |
+
# save checkpoint after each epoch and push checkpoint to the hub
|
758 |
+
if jax.process_index() == 0:
|
759 |
+
save_checkpoint(training_args.output_dir, jax_utils.unreplicate(state), cur_step, keep=training_args.save_total_limit, overwrite=True)
|
760 |
+
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
761 |
+
model.save_pretrained(
|
762 |
+
training_args.output_dir,
|
763 |
+
params=params,
|
764 |
+
push_to_hub=training_args.push_to_hub,
|
765 |
+
commit_message=f"Saving weights and logs of step {cur_step}",
|
766 |
+
)
|
767 |
+
if training_args.save_total_limit is not None:
|
768 |
+
rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
|
769 |
+
if jax.process_index() == 0:
|
770 |
+
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
771 |
+
model.save_pretrained(
|
772 |
+
training_args.output_dir,
|
773 |
+
params=params,
|
774 |
+
push_to_hub=training_args.push_to_hub,
|
775 |
+
commit_message=f"Saving weights and logs of step {cur_step}",
|
776 |
+
)
|
save_tokenized_data.py
ADDED
@@ -0,0 +1,484 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2021 The HuggingFace Team All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
|
18 |
+
text file or a dataset.
|
19 |
+
|
20 |
+
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
21 |
+
https://huggingface.co/models?filter=masked-lm
|
22 |
+
"""
|
23 |
+
import logging
|
24 |
+
import os
|
25 |
+
import sys
|
26 |
+
import time
|
27 |
+
from dataclasses import dataclass, field
|
28 |
+
|
29 |
+
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
30 |
+
from pathlib import Path
|
31 |
+
from typing import Dict, List, Optional, Tuple
|
32 |
+
|
33 |
+
import numpy as np
|
34 |
+
from datasets import load_dataset
|
35 |
+
from tqdm import tqdm
|
36 |
+
|
37 |
+
import flax
|
38 |
+
import jax
|
39 |
+
import jax.numpy as jnp
|
40 |
+
import optax
|
41 |
+
from flax import jax_utils, traverse_util
|
42 |
+
from flax.training import train_state
|
43 |
+
from flax.training.common_utils import get_metrics, onehot, shard
|
44 |
+
from transformers import (
|
45 |
+
CONFIG_MAPPING,
|
46 |
+
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
47 |
+
AutoConfig,
|
48 |
+
AutoTokenizer,
|
49 |
+
FlaxAutoModelForMaskedLM,
|
50 |
+
HfArgumentParser,
|
51 |
+
PreTrainedTokenizerBase,
|
52 |
+
TensorType,
|
53 |
+
TrainingArguments,
|
54 |
+
is_tensorboard_available,
|
55 |
+
set_seed,
|
56 |
+
)
|
57 |
+
import json
|
58 |
+
from flax.training import checkpoints
|
59 |
+
from flax.jax_utils import unreplicate
|
60 |
+
from flax.training.checkpoints import save_checkpoint, restore_checkpoint
|
61 |
+
from importlib.util import find_spec
|
62 |
+
|
63 |
+
|
64 |
+
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
65 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
66 |
+
|
67 |
+
|
68 |
+
@dataclass
|
69 |
+
class ModelArguments:
|
70 |
+
"""
|
71 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
72 |
+
"""
|
73 |
+
|
74 |
+
model_name_or_path: Optional[str] = field(
|
75 |
+
default=None,
|
76 |
+
metadata={
|
77 |
+
"help": "The model checkpoint for weights initialization."
|
78 |
+
"Don't set if you want to train a model from scratch."
|
79 |
+
},
|
80 |
+
)
|
81 |
+
model_type: Optional[str] = field(
|
82 |
+
default=None,
|
83 |
+
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
84 |
+
)
|
85 |
+
config_name: Optional[str] = field(
|
86 |
+
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
87 |
+
)
|
88 |
+
tokenizer_name: Optional[str] = field(
|
89 |
+
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
90 |
+
)
|
91 |
+
cache_dir: Optional[str] = field(
|
92 |
+
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
93 |
+
)
|
94 |
+
use_fast_tokenizer: bool = field(
|
95 |
+
default=True,
|
96 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
97 |
+
)
|
98 |
+
dtype: Optional[str] = field(
|
99 |
+
default="float32",
|
100 |
+
metadata={
|
101 |
+
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
102 |
+
},
|
103 |
+
)
|
104 |
+
|
105 |
+
|
106 |
+
@dataclass
|
107 |
+
class DataTrainingArguments:
|
108 |
+
"""
|
109 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
110 |
+
"""
|
111 |
+
|
112 |
+
dataset_name: Optional[str] = field(
|
113 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
114 |
+
)
|
115 |
+
dataset_config_name: Optional[str] = field(
|
116 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
117 |
+
)
|
118 |
+
train_ref_file: Optional[str] = field(
|
119 |
+
default=None,
|
120 |
+
metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
|
121 |
+
)
|
122 |
+
validation_ref_file: Optional[str] = field(
|
123 |
+
default=None,
|
124 |
+
metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
|
125 |
+
)
|
126 |
+
overwrite_cache: bool = field(
|
127 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
validation_split_percentage: Optional[int] = field(
|
133 |
+
default=5,
|
134 |
+
metadata={
|
135 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
136 |
+
},
|
137 |
+
)
|
138 |
+
max_seq_length: Optional[int] = field(
|
139 |
+
default=None,
|
140 |
+
metadata={
|
141 |
+
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
142 |
+
"than this will be truncated. Default to the max input length of the model."
|
143 |
+
},
|
144 |
+
)
|
145 |
+
preprocessing_num_workers: Optional[int] = field(
|
146 |
+
default=None,
|
147 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
148 |
+
)
|
149 |
+
mlm_probability: float = field(
|
150 |
+
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
|
151 |
+
)
|
152 |
+
pad_to_max_length: bool = field(
|
153 |
+
default=False,
|
154 |
+
metadata={
|
155 |
+
"help": "Whether to pad all samples to `max_seq_length`. "
|
156 |
+
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
157 |
+
},
|
158 |
+
)
|
159 |
+
line_by_line: bool = field(
|
160 |
+
default=False,
|
161 |
+
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
|
162 |
+
)
|
163 |
+
max_eval_samples: Optional[int] = field(
|
164 |
+
default=None,
|
165 |
+
metadata={
|
166 |
+
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
167 |
+
"value if set."
|
168 |
+
},
|
169 |
+
)
|
170 |
+
|
171 |
+
|
172 |
+
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
@flax.struct.dataclass
|
177 |
+
class FlaxDataCollatorForLanguageModeling:
|
178 |
+
"""
|
179 |
+
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
|
180 |
+
are not all of the same length.
|
181 |
+
|
182 |
+
Args:
|
183 |
+
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
184 |
+
The tokenizer used for encoding the data.
|
185 |
+
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
|
186 |
+
The probability with which to (randomly) mask tokens in the input.
|
187 |
+
|
188 |
+
.. note::
|
189 |
+
|
190 |
+
For best performance, this data collator should be used with a dataset having items that are dictionaries or
|
191 |
+
BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
|
192 |
+
:class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
|
193 |
+
argument :obj:`return_special_tokens_mask=True`.
|
194 |
+
"""
|
195 |
+
|
196 |
+
tokenizer: PreTrainedTokenizerBase
|
197 |
+
mlm_probability: float = 0.15
|
198 |
+
|
199 |
+
def __post_init__(self):
|
200 |
+
if self.tokenizer.mask_token is None:
|
201 |
+
raise ValueError(
|
202 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
203 |
+
"You should pass `mlm=False` to train on causal language modeling instead."
|
204 |
+
)
|
205 |
+
|
206 |
+
def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
|
207 |
+
# Handle dict or lists with proper padding and conversion to tensor.
|
208 |
+
batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
|
209 |
+
|
210 |
+
# If special token mask has been preprocessed, pop it from the dict.
|
211 |
+
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
212 |
+
|
213 |
+
batch["input_ids"], batch["labels"] = self.mask_tokens(
|
214 |
+
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
215 |
+
)
|
216 |
+
return batch
|
217 |
+
|
218 |
+
def mask_tokens(
|
219 |
+
self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
|
220 |
+
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
221 |
+
"""
|
222 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
223 |
+
"""
|
224 |
+
labels = inputs.copy()
|
225 |
+
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
226 |
+
probability_matrix = np.full(labels.shape, self.mlm_probability)
|
227 |
+
special_tokens_mask = special_tokens_mask.astype("bool")
|
228 |
+
|
229 |
+
probability_matrix[special_tokens_mask] = 0.0
|
230 |
+
masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
|
231 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
232 |
+
|
233 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
234 |
+
indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
|
235 |
+
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
236 |
+
|
237 |
+
# 10% of the time, we replace masked input tokens with random word
|
238 |
+
indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
|
239 |
+
indices_random &= masked_indices & ~indices_replaced
|
240 |
+
|
241 |
+
random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
|
242 |
+
inputs[indices_random] = random_words[indices_random]
|
243 |
+
|
244 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
245 |
+
return inputs, labels
|
246 |
+
|
247 |
+
|
248 |
+
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
249 |
+
num_samples = len(samples_idx)
|
250 |
+
samples_to_remove = num_samples % batch_size
|
251 |
+
|
252 |
+
if samples_to_remove != 0:
|
253 |
+
samples_idx = samples_idx[:-samples_to_remove]
|
254 |
+
sections_split = num_samples // batch_size
|
255 |
+
batch_idx = np.split(samples_idx, sections_split)
|
256 |
+
return batch_idx
|
257 |
+
|
258 |
+
|
259 |
+
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
260 |
+
summary_writer.scalar("train_time", train_time, step)
|
261 |
+
|
262 |
+
train_metrics = get_metrics(train_metrics)
|
263 |
+
for key, vals in train_metrics.items():
|
264 |
+
tag = f"train_{key}"
|
265 |
+
for i, val in enumerate(vals):
|
266 |
+
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
267 |
+
|
268 |
+
|
269 |
+
def write_eval_metric(summary_writer, eval_metrics, step):
|
270 |
+
for metric_name, value in eval_metrics.items():
|
271 |
+
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
272 |
+
|
273 |
+
|
274 |
+
if __name__ == "__main__":
|
275 |
+
# See all possible arguments in src/transformers/training_args.py
|
276 |
+
# or by passing the --help flag to this script.
|
277 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
278 |
+
|
279 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
280 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
281 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
282 |
+
# let's parse it to get our arguments.
|
283 |
+
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
284 |
+
else:
|
285 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
286 |
+
|
287 |
+
if (
|
288 |
+
os.path.exists(training_args.output_dir)
|
289 |
+
and os.listdir(training_args.output_dir)
|
290 |
+
and training_args.do_train
|
291 |
+
and not training_args.overwrite_output_dir
|
292 |
+
):
|
293 |
+
raise ValueError(
|
294 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
295 |
+
"Use --overwrite_output_dir to overcome."
|
296 |
+
)
|
297 |
+
|
298 |
+
# Setup logging
|
299 |
+
logging.basicConfig(
|
300 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
301 |
+
level="NOTSET",
|
302 |
+
datefmt="[%X]",
|
303 |
+
)
|
304 |
+
|
305 |
+
# Log on each process the small summary:
|
306 |
+
logger = logging.getLogger(__name__)
|
307 |
+
|
308 |
+
# Set the verbosity to info of the Transformers logger (on main process only):
|
309 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
310 |
+
|
311 |
+
# Set seed before initializing model.
|
312 |
+
set_seed(training_args.seed)
|
313 |
+
|
314 |
+
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
315 |
+
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
316 |
+
# (the dataset will be downloaded automatically from the datasets Hub).
|
317 |
+
#
|
318 |
+
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
319 |
+
# 'text' is found. You can easily tweak this behavior (see below).
|
320 |
+
#
|
321 |
+
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
322 |
+
# download the dataset.
|
323 |
+
if data_args.dataset_name is not None:
|
324 |
+
# Downloading and loading a dataset from the hub.
|
325 |
+
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
|
326 |
+
|
327 |
+
if "validation" not in datasets.keys():
|
328 |
+
datasets["validation"] = load_dataset(
|
329 |
+
data_args.dataset_name,
|
330 |
+
data_args.dataset_config_name,
|
331 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
332 |
+
cache_dir=model_args.cache_dir,
|
333 |
+
)
|
334 |
+
datasets["train"] = load_dataset(
|
335 |
+
data_args.dataset_name,
|
336 |
+
data_args.dataset_config_name,
|
337 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
338 |
+
cache_dir=model_args.cache_dir,
|
339 |
+
)
|
340 |
+
else:
|
341 |
+
import glob
|
342 |
+
import random
|
343 |
+
data_files = []
|
344 |
+
def add_jsonlines_dir(path, filespec):
|
345 |
+
global data_files
|
346 |
+
data_files += glob.glob(f"{path}/{filespec}")
|
347 |
+
data_files = list(set(data_files))
|
348 |
+
print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")
|
349 |
+
#add_jsonlines_dir(f"/data/c4_cleaned2", "*.gz")
|
350 |
+
#add_jsonlines_dir(f"/data/nrc_uniq_cleaned_20210223", "*.gz")
|
351 |
+
add_jsonlines_dir(f"/data/nu_uniq_cleaned_20210225", "*.gz")
|
352 |
+
random.Random(42).shuffle(data_files)
|
353 |
+
total = len(data_files)
|
354 |
+
print(total)
|
355 |
+
perc = 0.05
|
356 |
+
val_size = int(perc * total)
|
357 |
+
train_size = total - val_size
|
358 |
+
train = data_files[5:8]
|
359 |
+
val = data_files[1:3]
|
360 |
+
print(f"Got {len(train)} training files and {perc * 100} % {len(val)} validation files")
|
361 |
+
assert list(set(train) & set(val)) == [], "Train overlaps with test"
|
362 |
+
datasets = load_dataset('json', data_files={'train': train, 'validation': val},cache_dir="/home/dat/.cache/huggingface/datasets/json/default-9add402b38836560/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723")
|
363 |
+
|
364 |
+
#from datasets import Dataset
|
365 |
+
|
366 |
+
#dataset = Dataset.from_file("/home/dat/.cache/huggingface/datasets/json/default-9add402b38836560/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723/json-train.arrow")
|
367 |
+
#dataset = Dataset.from_file("/home/dat/.cache/huggingface/datasets/json/default-9add402b38836560/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723/json-validation.arrow")
|
368 |
+
|
369 |
+
|
370 |
+
def mb_item(x):
|
371 |
+
return x.item() if hasattr(x, "item") else x
|
372 |
+
|
373 |
+
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
374 |
+
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
375 |
+
|
376 |
+
# Load pretrained model and tokenizer
|
377 |
+
|
378 |
+
# Distributed training:
|
379 |
+
# The .from_pretrained methods guarantee that only one local process can concurrently
|
380 |
+
# download model & vocab.
|
381 |
+
if model_args.config_name:
|
382 |
+
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
383 |
+
elif model_args.model_name_or_path:
|
384 |
+
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
385 |
+
else:
|
386 |
+
config = CONFIG_MAPPING[model_args.model_type]()
|
387 |
+
logger.warning("You are instantiating a new config instance from scratch.")
|
388 |
+
|
389 |
+
if model_args.tokenizer_name:
|
390 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
391 |
+
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
392 |
+
)
|
393 |
+
elif model_args.model_name_or_path:
|
394 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
395 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
396 |
+
)
|
397 |
+
else:
|
398 |
+
raise ValueError(
|
399 |
+
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
400 |
+
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
401 |
+
)
|
402 |
+
|
403 |
+
# Preprocessing the datasets.
|
404 |
+
# First we tokenize all the texts.
|
405 |
+
if training_args.do_train:
|
406 |
+
column_names = datasets["train"].column_names
|
407 |
+
else:
|
408 |
+
column_names = datasets["validation"].column_names
|
409 |
+
text_column_name = "text" if "text" in column_names else column_names[0]
|
410 |
+
|
411 |
+
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
412 |
+
|
413 |
+
if data_args.line_by_line:
|
414 |
+
# When using line_by_line, we just tokenize each nonempty line.
|
415 |
+
padding = "max_length" if data_args.pad_to_max_length else False
|
416 |
+
|
417 |
+
def tokenize_function(examples):
|
418 |
+
# Remove empty lines
|
419 |
+
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
|
420 |
+
return tokenizer(
|
421 |
+
examples,
|
422 |
+
return_special_tokens_mask=True,
|
423 |
+
padding=padding,
|
424 |
+
truncation=True,
|
425 |
+
max_length=max_seq_length,
|
426 |
+
)
|
427 |
+
|
428 |
+
tokenized_datasets = datasets.map(
|
429 |
+
tokenize_function,
|
430 |
+
input_columns=[text_column_name],
|
431 |
+
batched=True,
|
432 |
+
num_proc=data_args.preprocessing_num_workers,
|
433 |
+
remove_columns=column_names,
|
434 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
435 |
+
)
|
436 |
+
tokenized_datasets.save_to_disk("/data/tokenized_data")
|
437 |
+
print ("save data")
|
438 |
+
else:
|
439 |
+
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
|
440 |
+
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
|
441 |
+
# efficient when it receives the `special_tokens_mask`.
|
442 |
+
def tokenize_function(examples):
|
443 |
+
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
|
444 |
+
|
445 |
+
tokenized_datasets = datasets.map(
|
446 |
+
tokenize_function,
|
447 |
+
batched=True,
|
448 |
+
num_proc=data_args.preprocessing_num_workers,
|
449 |
+
remove_columns=column_names,
|
450 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
451 |
+
)
|
452 |
+
|
453 |
+
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
454 |
+
# max_seq_length.
|
455 |
+
def group_texts(examples):
|
456 |
+
# Concatenate all texts.
|
457 |
+
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
|
458 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
459 |
+
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
460 |
+
# customize this part to your needs.
|
461 |
+
if total_length >= max_seq_length:
|
462 |
+
total_length = (total_length // max_seq_length) * max_seq_length
|
463 |
+
# Split by chunks of max_len.
|
464 |
+
result = {
|
465 |
+
k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
|
466 |
+
for k, t in concatenated_examples.items()
|
467 |
+
}
|
468 |
+
return result
|
469 |
+
|
470 |
+
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
|
471 |
+
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
|
472 |
+
# might be slower to preprocess.
|
473 |
+
#
|
474 |
+
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
475 |
+
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
476 |
+
tokenized_datasets = tokenized_datasets.map(
|
477 |
+
group_texts,
|
478 |
+
batched=True,
|
479 |
+
num_proc=data_args.preprocessing_num_workers,
|
480 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
481 |
+
)
|
482 |
+
|
483 |
+
tokenized_datasets.save_to_disk("/data/tokenized_data")
|
484 |
+
print ("save data")
|
train_tokenizer.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import random
|
3 |
+
from tokenizers import ByteLevelBPETokenizer
|
4 |
+
from datasets import load_dataset
|
5 |
+
|
6 |
+
data_files = []
|
7 |
+
def add_jsonlines_dir(path, filespec):
|
8 |
+
global data_files
|
9 |
+
data_files += glob.glob(f"{path}/{filespec}")
|
10 |
+
data_files = list(set(data_files))
|
11 |
+
print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")
|
12 |
+
add_jsonlines_dir(f"/data/c4_cleaned2", "*.gz")
|
13 |
+
add_jsonlines_dir(f"/data/nrc_uniq_cleaned_20210223", "*.gz")
|
14 |
+
add_jsonlines_dir(f"/data/nu_uniq_cleaned_20210225", "*.gz")
|
15 |
+
random.Random(42).shuffle(data_files)
|
16 |
+
total = len(data_files)
|
17 |
+
print(total)
|
18 |
+
perc = 0.05
|
19 |
+
val_size = int(perc * total)
|
20 |
+
train_size = total - val_size
|
21 |
+
train = data_files[:train_size]
|
22 |
+
val = data_files[train_size:]
|
23 |
+
print(f"Got {len(train)} training files and {perc * 100} % {len(val)} validation files")
|
24 |
+
assert list(set(train) & set(val)) == [], "Train overlaps with test"
|
25 |
+
datasets = load_dataset('json', data_files={'train': train, 'validation': val})
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
tokenizer = ByteLevelBPETokenizer()
|
30 |
+
|
31 |
+
def batch_iterator(batch_size=1000):
|
32 |
+
for i in range(0, len(datasets), batch_size):
|
33 |
+
yield datasets["train"][i: i + batch_size]["text"]
|
34 |
+
|
35 |
+
tokenizer.train_from_iterator(batch_iterator(), vocab_size=50358, min_frequency=2, special_tokens=[
|
36 |
+
"<s>",
|
37 |
+
"<pad>",
|
38 |
+
"</s>",
|
39 |
+
"<unk>",
|
40 |
+
"<mask>",
|
41 |
+
])
|
42 |
+
|
43 |
+
tokenizer.save("tokenizer.json")
|
wandb/debug-internal.log
CHANGED
@@ -1 +1 @@
|
|
1 |
-
run-
|
|
|
1 |
+
run-20210714_210351-1msvb4w4/logs/debug-internal.log
|
wandb/debug.log
CHANGED
@@ -1 +1 @@
|
|
1 |
-
run-
|
|
|
1 |
+
run-20210714_210351-1msvb4w4/logs/debug.log
|
wandb/latest-run
CHANGED
@@ -1 +1 @@
|
|
1 |
-
run-
|
|
|
1 |
+
run-20210714_210351-1msvb4w4
|
wandb/run-20210713_010630-14xhiyhf/files/output.log
CHANGED
@@ -16222,3 +16222,12 @@ Training...: 64%|████████████▊ | 59500/92767 [9
|
|
16222 |
|
16223 |
Training...: 65%|████████████▉ | 60000/92767 [9:35:07<5:11:39, 1.75it/s]
|
16224 |
git-lfs/2.9.2 (GitHub; linux amd64; go 1.13.5)92767 [9:35:07<5:11:39, 1.75it/s]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16222 |
|
16223 |
Training...: 65%|████████████▉ | 60000/92767 [9:35:07<5:11:39, 1.75it/s]
|
16224 |
git-lfs/2.9.2 (GitHub; linux amd64; go 1.13.5)92767 [9:35:07<5:11:39, 1.75it/s]
|
16225 |
+
[10:43:30] - DEBUG - huggingface_hub.repository - [Repository] is a valid git repo
|
16226 |
+
[10:44:08] - INFO - huggingface_hub.repository - Uploading LFS objects: 100% (3/3), 1.0 GB | 43 MB/s, done.
|
16227 |
+
[10:44:09] - INFO - absl - Saving checkpoint at step: 60000
|
16228 |
+
tcmalloc: large alloc 1363968000 bytes == 0x2ed6e2000 @ 0x7f170bb8c680 0x7f170bbacbdd 0x7f143fe0e20d 0x7f143fe1c340 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe1be87 0x7f143fe17bd3 0x7f143fe181fe 0x504d56 0x56acb6 0x568d9a 0x5f5b33 0x56bc9b 0x5f5956 0x56aadf 0x5f5956 0x56fb87 0x568d9a 0x5f5b33 0x56bc9b 0x568d9a 0x68cdc7
|
16229 |
+
[10:44:13] - INFO - absl - Saved checkpoint at checkpoint_60000
|
16230 |
+
|
16231 |
+
|
16232 |
+
|
16233 |
+
|
wandb/run-20210713_010630-14xhiyhf/logs/debug-internal.log
CHANGED
@@ -22396,3 +22396,27 @@
|
|
22396 |
2021-07-13 10:43:28,960 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/wandb-summary.json
|
22397 |
2021-07-13 10:43:29,961 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
|
22398 |
2021-07-13 10:43:31,962 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22396 |
2021-07-13 10:43:28,960 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/wandb-summary.json
|
22397 |
2021-07-13 10:43:29,961 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
|
22398 |
2021-07-13 10:43:31,962 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
|
22399 |
+
2021-07-13 10:43:36,601 DEBUG HandlerThread:332390 [handler.py:handle_request():124] handle_request: stop_status
|
22400 |
+
2021-07-13 10:43:36,601 DEBUG SenderThread:332390 [sender.py:send_request():193] send_request: stop_status
|
22401 |
+
2021-07-13 10:43:51,734 DEBUG HandlerThread:332390 [handler.py:handle_request():124] handle_request: stop_status
|
22402 |
+
2021-07-13 10:43:51,734 DEBUG SenderThread:332390 [sender.py:send_request():193] send_request: stop_status
|
22403 |
+
2021-07-13 10:43:55,447 DEBUG SenderThread:332390 [sender.py:send():179] send: stats
|
22404 |
+
2021-07-13 10:44:06,865 DEBUG HandlerThread:332390 [handler.py:handle_request():124] handle_request: stop_status
|
22405 |
+
2021-07-13 10:44:06,866 DEBUG SenderThread:332390 [sender.py:send_request():193] send_request: stop_status
|
22406 |
+
2021-07-13 10:44:09,977 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
|
22407 |
+
2021-07-13 10:44:14,979 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
|
22408 |
+
2021-07-13 10:44:16,979 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
|
22409 |
+
2021-07-13 10:44:18,980 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
|
22410 |
+
2021-07-13 10:44:20,981 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
|
22411 |
+
2021-07-13 10:44:22,005 DEBUG HandlerThread:332390 [handler.py:handle_request():124] handle_request: stop_status
|
22412 |
+
2021-07-13 10:44:22,005 DEBUG SenderThread:332390 [sender.py:send_request():193] send_request: stop_status
|
22413 |
+
2021-07-13 10:44:22,982 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
|
22414 |
+
2021-07-13 10:44:23,482 WARNING MainThread:332390 [internal.py:wandb_internal():147] Internal process interrupt: 1
|
22415 |
+
2021-07-13 10:44:24,702 WARNING MainThread:332390 [internal.py:wandb_internal():147] Internal process interrupt: 2
|
22416 |
+
2021-07-13 10:44:24,703 ERROR MainThread:332390 [internal.py:wandb_internal():150] Internal process interrupted.
|
22417 |
+
2021-07-13 10:44:24,982 INFO Thread-8 :332390 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/files/output.log
|
22418 |
+
2021-07-13 10:44:25,021 INFO SenderThread:332390 [sender.py:finish():945] shutting down sender
|
22419 |
+
2021-07-13 10:44:25,022 INFO SenderThread:332390 [dir_watcher.py:finish():282] shutting down directory watcher
|
22420 |
+
2021-07-13 10:44:25,022 INFO WriterThread:332390 [datastore.py:close():288] close: /home/dat/pino-roberta-base/wandb/run-20210713_010630-14xhiyhf/run-14xhiyhf.wandb
|
22421 |
+
2021-07-13 10:44:25,022 INFO HandlerThread:332390 [handler.py:finish():638] shutting down handler
|
22422 |
+
2021-07-13 10:44:25,103 INFO MainThread:332390 [internal.py:handle_exit():78] Internal process exited
|
wandb/run-20210713_010630-14xhiyhf/logs/debug.log
CHANGED
@@ -23,3 +23,5 @@ config: {}
|
|
23 |
2021-07-13 01:06:32,711 INFO MainThread:330819 [wandb_run.py:_config_callback():872] config_cb None None {'output_dir': './', 'overwrite_output_dir': True, 'do_train': False, 'do_eval': False, 'do_predict': False, 'evaluation_strategy': 'IntervalStrategy.NO', 'prediction_loss_only': False, 'per_device_train_batch_size': 2, 'per_device_eval_batch_size': 2, 'per_gpu_train_batch_size': None, 'per_gpu_eval_batch_size': None, 'gradient_accumulation_steps': 1, 'eval_accumulation_steps': None, 'learning_rate': 5e-05, 'weight_decay': 0.0095, 'adam_beta1': 0.9, 'adam_beta2': 0.98, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'num_train_epochs': 5.0, 'max_steps': -1, 'lr_scheduler_type': 'SchedulerType.LINEAR', 'warmup_ratio': 0.0, 'warmup_steps': 5000, 'log_level': -1, 'log_level_replica': -1, 'log_on_each_node': True, 'logging_dir': './runs/Jul13_01-05-41_t1v-n-f5c06ea1-w-0', 'logging_strategy': 'IntervalStrategy.STEPS', 'logging_first_step': False, 'logging_steps': 500, 'save_strategy': 'IntervalStrategy.STEPS', 'save_steps': 20000, 'save_total_limit': 5, 'save_on_each_node': False, 'no_cuda': False, 'seed': 42, 'fp16': False, 'fp16_opt_level': 'O1', 'fp16_backend': 'auto', 'fp16_full_eval': False, 'local_rank': -1, 'tpu_num_cores': None, 'tpu_metrics_debug': False, 'debug': [], 'dataloader_drop_last': False, 'eval_steps': 92768, 'dataloader_num_workers': 0, 'past_index': -1, 'run_name': './', 'disable_tqdm': False, 'remove_unused_columns': True, 'label_names': None, 'load_best_model_at_end': False, 'metric_for_best_model': None, 'greater_is_better': None, 'ignore_data_skip': False, 'sharded_ddp': [], 'deepspeed': None, 'label_smoothing_factor': 0.0, 'adafactor': False, 'group_by_length': False, 'length_column_name': 'length', 'report_to': ['tensorboard', 'wandb'], 'ddp_find_unused_parameters': None, 'dataloader_pin_memory': True, 'skip_memory_metrics': True, 'use_legacy_prediction_loop': False, 'push_to_hub': True, 'resume_from_checkpoint': None, 'push_to_hub_model_id': '', 'push_to_hub_organization': None, 'push_to_hub_token': None, 'mp_parameters': ''}
|
24 |
2021-07-13 01:06:32,712 INFO MainThread:330819 [wandb_run.py:_config_callback():872] config_cb None None {'model_name_or_path': None, 'model_type': 'big_bird', 'config_name': './', 'tokenizer_name': './', 'cache_dir': None, 'use_fast_tokenizer': True, 'dtype': 'bfloat16'}
|
25 |
2021-07-13 01:06:32,714 INFO MainThread:330819 [wandb_run.py:_config_callback():872] config_cb None None {'dataset_name': None, 'dataset_config_name': None, 'train_file': None, 'validation_file': None, 'train_ref_file': None, 'validation_ref_file': None, 'overwrite_cache': False, 'validation_split_percentage': 5, 'max_seq_length': 4096, 'preprocessing_num_workers': 64, 'mlm_probability': 0.15, 'pad_to_max_length': False, 'line_by_line': False}
|
|
|
|
|
|
23 |
2021-07-13 01:06:32,711 INFO MainThread:330819 [wandb_run.py:_config_callback():872] config_cb None None {'output_dir': './', 'overwrite_output_dir': True, 'do_train': False, 'do_eval': False, 'do_predict': False, 'evaluation_strategy': 'IntervalStrategy.NO', 'prediction_loss_only': False, 'per_device_train_batch_size': 2, 'per_device_eval_batch_size': 2, 'per_gpu_train_batch_size': None, 'per_gpu_eval_batch_size': None, 'gradient_accumulation_steps': 1, 'eval_accumulation_steps': None, 'learning_rate': 5e-05, 'weight_decay': 0.0095, 'adam_beta1': 0.9, 'adam_beta2': 0.98, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'num_train_epochs': 5.0, 'max_steps': -1, 'lr_scheduler_type': 'SchedulerType.LINEAR', 'warmup_ratio': 0.0, 'warmup_steps': 5000, 'log_level': -1, 'log_level_replica': -1, 'log_on_each_node': True, 'logging_dir': './runs/Jul13_01-05-41_t1v-n-f5c06ea1-w-0', 'logging_strategy': 'IntervalStrategy.STEPS', 'logging_first_step': False, 'logging_steps': 500, 'save_strategy': 'IntervalStrategy.STEPS', 'save_steps': 20000, 'save_total_limit': 5, 'save_on_each_node': False, 'no_cuda': False, 'seed': 42, 'fp16': False, 'fp16_opt_level': 'O1', 'fp16_backend': 'auto', 'fp16_full_eval': False, 'local_rank': -1, 'tpu_num_cores': None, 'tpu_metrics_debug': False, 'debug': [], 'dataloader_drop_last': False, 'eval_steps': 92768, 'dataloader_num_workers': 0, 'past_index': -1, 'run_name': './', 'disable_tqdm': False, 'remove_unused_columns': True, 'label_names': None, 'load_best_model_at_end': False, 'metric_for_best_model': None, 'greater_is_better': None, 'ignore_data_skip': False, 'sharded_ddp': [], 'deepspeed': None, 'label_smoothing_factor': 0.0, 'adafactor': False, 'group_by_length': False, 'length_column_name': 'length', 'report_to': ['tensorboard', 'wandb'], 'ddp_find_unused_parameters': None, 'dataloader_pin_memory': True, 'skip_memory_metrics': True, 'use_legacy_prediction_loop': False, 'push_to_hub': True, 'resume_from_checkpoint': None, 'push_to_hub_model_id': '', 'push_to_hub_organization': None, 'push_to_hub_token': None, 'mp_parameters': ''}
|
24 |
2021-07-13 01:06:32,712 INFO MainThread:330819 [wandb_run.py:_config_callback():872] config_cb None None {'model_name_or_path': None, 'model_type': 'big_bird', 'config_name': './', 'tokenizer_name': './', 'cache_dir': None, 'use_fast_tokenizer': True, 'dtype': 'bfloat16'}
|
25 |
2021-07-13 01:06:32,714 INFO MainThread:330819 [wandb_run.py:_config_callback():872] config_cb None None {'dataset_name': None, 'dataset_config_name': None, 'train_file': None, 'validation_file': None, 'train_ref_file': None, 'validation_ref_file': None, 'overwrite_cache': False, 'validation_split_percentage': 5, 'max_seq_length': 4096, 'preprocessing_num_workers': 64, 'mlm_probability': 0.15, 'pad_to_max_length': False, 'line_by_line': False}
|
26 |
+
2021-07-13 10:44:23,634 INFO MainThread:330819 [wandb_run.py:_atexit_cleanup():1593] got exitcode: 255
|
27 |
+
2021-07-13 10:44:23,634 INFO MainThread:330819 [wandb_run.py:_restore():1565] restore
|
wandb/run-20210713_010630-14xhiyhf/run-14xhiyhf.wandb
CHANGED
Binary files a/wandb/run-20210713_010630-14xhiyhf/run-14xhiyhf.wandb and b/wandb/run-20210713_010630-14xhiyhf/run-14xhiyhf.wandb differ
|
|
wandb/run-20210713_104745-1rl2j7or/files/config.yaml
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
wandb_version: 1
|
2 |
+
|
3 |
+
_wandb:
|
4 |
+
desc: null
|
5 |
+
value:
|
6 |
+
cli_version: 0.10.33
|
7 |
+
framework: huggingface
|
8 |
+
huggingface_version: 4.9.0.dev0
|
9 |
+
is_jupyter_run: false
|
10 |
+
is_kaggle_kernel: false
|
11 |
+
python_version: 3.8.10
|
12 |
+
t:
|
13 |
+
1:
|
14 |
+
- 3
|
15 |
+
- 11
|
16 |
+
4: 3.8.10
|
17 |
+
5: 0.10.33
|
18 |
+
6: 4.9.0.dev0
|
19 |
+
8:
|
20 |
+
- 5
|
21 |
+
adafactor:
|
22 |
+
desc: null
|
23 |
+
value: false
|
24 |
+
adam_beta1:
|
25 |
+
desc: null
|
26 |
+
value: 0.9
|
27 |
+
adam_beta2:
|
28 |
+
desc: null
|
29 |
+
value: 0.98
|
30 |
+
adam_epsilon:
|
31 |
+
desc: null
|
32 |
+
value: 1.0e-08
|
33 |
+
cache_dir:
|
34 |
+
desc: null
|
35 |
+
value: null
|
36 |
+
config_name:
|
37 |
+
desc: null
|
38 |
+
value: ./
|
39 |
+
dataloader_drop_last:
|
40 |
+
desc: null
|
41 |
+
value: false
|
42 |
+
dataloader_num_workers:
|
43 |
+
desc: null
|
44 |
+
value: 0
|
45 |
+
dataloader_pin_memory:
|
46 |
+
desc: null
|
47 |
+
value: true
|
48 |
+
dataset_config_name:
|
49 |
+
desc: null
|
50 |
+
value: null
|
51 |
+
dataset_name:
|
52 |
+
desc: null
|
53 |
+
value: null
|
54 |
+
ddp_find_unused_parameters:
|
55 |
+
desc: null
|
56 |
+
value: null
|
57 |
+
debug:
|
58 |
+
desc: null
|
59 |
+
value: []
|
60 |
+
deepspeed:
|
61 |
+
desc: null
|
62 |
+
value: null
|
63 |
+
disable_tqdm:
|
64 |
+
desc: null
|
65 |
+
value: false
|
66 |
+
do_eval:
|
67 |
+
desc: null
|
68 |
+
value: false
|
69 |
+
do_predict:
|
70 |
+
desc: null
|
71 |
+
value: false
|
72 |
+
do_train:
|
73 |
+
desc: null
|
74 |
+
value: false
|
75 |
+
dtype:
|
76 |
+
desc: null
|
77 |
+
value: float32
|
78 |
+
eval_accumulation_steps:
|
79 |
+
desc: null
|
80 |
+
value: null
|
81 |
+
eval_steps:
|
82 |
+
desc: null
|
83 |
+
value: 100001
|
84 |
+
evaluation_strategy:
|
85 |
+
desc: null
|
86 |
+
value: IntervalStrategy.NO
|
87 |
+
fp16:
|
88 |
+
desc: null
|
89 |
+
value: false
|
90 |
+
fp16_backend:
|
91 |
+
desc: null
|
92 |
+
value: auto
|
93 |
+
fp16_full_eval:
|
94 |
+
desc: null
|
95 |
+
value: false
|
96 |
+
fp16_opt_level:
|
97 |
+
desc: null
|
98 |
+
value: O1
|
99 |
+
gradient_accumulation_steps:
|
100 |
+
desc: null
|
101 |
+
value: 2
|
102 |
+
greater_is_better:
|
103 |
+
desc: null
|
104 |
+
value: null
|
105 |
+
group_by_length:
|
106 |
+
desc: null
|
107 |
+
value: false
|
108 |
+
ignore_data_skip:
|
109 |
+
desc: null
|
110 |
+
value: false
|
111 |
+
label_names:
|
112 |
+
desc: null
|
113 |
+
value: null
|
114 |
+
label_smoothing_factor:
|
115 |
+
desc: null
|
116 |
+
value: 0.0
|
117 |
+
learning_rate:
|
118 |
+
desc: null
|
119 |
+
value: 5.0e-05
|
120 |
+
length_column_name:
|
121 |
+
desc: null
|
122 |
+
value: length
|
123 |
+
line_by_line:
|
124 |
+
desc: null
|
125 |
+
value: false
|
126 |
+
load_best_model_at_end:
|
127 |
+
desc: null
|
128 |
+
value: false
|
129 |
+
local_rank:
|
130 |
+
desc: null
|
131 |
+
value: -1
|
132 |
+
log_level:
|
133 |
+
desc: null
|
134 |
+
value: -1
|
135 |
+
log_level_replica:
|
136 |
+
desc: null
|
137 |
+
value: -1
|
138 |
+
log_on_each_node:
|
139 |
+
desc: null
|
140 |
+
value: true
|
141 |
+
logging_dir:
|
142 |
+
desc: null
|
143 |
+
value: ./runs/Jul13_10-47-16_t1v-n-f5c06ea1-w-0
|
144 |
+
logging_first_step:
|
145 |
+
desc: null
|
146 |
+
value: false
|
147 |
+
logging_steps:
|
148 |
+
desc: null
|
149 |
+
value: 50
|
150 |
+
logging_strategy:
|
151 |
+
desc: null
|
152 |
+
value: IntervalStrategy.STEPS
|
153 |
+
lr_scheduler_type:
|
154 |
+
desc: null
|
155 |
+
value: SchedulerType.LINEAR
|
156 |
+
max_grad_norm:
|
157 |
+
desc: null
|
158 |
+
value: 1.0
|
159 |
+
max_seq_length:
|
160 |
+
desc: null
|
161 |
+
value: 4096
|
162 |
+
max_steps:
|
163 |
+
desc: null
|
164 |
+
value: -1
|
165 |
+
metric_for_best_model:
|
166 |
+
desc: null
|
167 |
+
value: null
|
168 |
+
mlm_probability:
|
169 |
+
desc: null
|
170 |
+
value: 0.15
|
171 |
+
model_name_or_path:
|
172 |
+
desc: null
|
173 |
+
value: null
|
174 |
+
model_type:
|
175 |
+
desc: null
|
176 |
+
value: big_bird
|
177 |
+
mp_parameters:
|
178 |
+
desc: null
|
179 |
+
value: ''
|
180 |
+
no_cuda:
|
181 |
+
desc: null
|
182 |
+
value: false
|
183 |
+
num_train_epochs:
|
184 |
+
desc: null
|
185 |
+
value: 5.0
|
186 |
+
output_dir:
|
187 |
+
desc: null
|
188 |
+
value: ./
|
189 |
+
overwrite_cache:
|
190 |
+
desc: null
|
191 |
+
value: false
|
192 |
+
overwrite_output_dir:
|
193 |
+
desc: null
|
194 |
+
value: true
|
195 |
+
pad_to_max_length:
|
196 |
+
desc: null
|
197 |
+
value: false
|
198 |
+
past_index:
|
199 |
+
desc: null
|
200 |
+
value: -1
|
201 |
+
per_device_eval_batch_size:
|
202 |
+
desc: null
|
203 |
+
value: 2
|
204 |
+
per_device_train_batch_size:
|
205 |
+
desc: null
|
206 |
+
value: 2
|
207 |
+
per_gpu_eval_batch_size:
|
208 |
+
desc: null
|
209 |
+
value: null
|
210 |
+
per_gpu_train_batch_size:
|
211 |
+
desc: null
|
212 |
+
value: null
|
213 |
+
prediction_loss_only:
|
214 |
+
desc: null
|
215 |
+
value: false
|
216 |
+
preprocessing_num_workers:
|
217 |
+
desc: null
|
218 |
+
value: 64
|
219 |
+
push_to_hub:
|
220 |
+
desc: null
|
221 |
+
value: true
|
222 |
+
push_to_hub_model_id:
|
223 |
+
desc: null
|
224 |
+
value: ''
|
225 |
+
push_to_hub_organization:
|
226 |
+
desc: null
|
227 |
+
value: null
|
228 |
+
push_to_hub_token:
|
229 |
+
desc: null
|
230 |
+
value: null
|
231 |
+
remove_unused_columns:
|
232 |
+
desc: null
|
233 |
+
value: true
|
234 |
+
report_to:
|
235 |
+
desc: null
|
236 |
+
value:
|
237 |
+
- tensorboard
|
238 |
+
- wandb
|
239 |
+
resume_from_checkpoint:
|
240 |
+
desc: null
|
241 |
+
value: null
|
242 |
+
run_name:
|
243 |
+
desc: null
|
244 |
+
value: ./
|
245 |
+
save_on_each_node:
|
246 |
+
desc: null
|
247 |
+
value: false
|
248 |
+
save_steps:
|
249 |
+
desc: null
|
250 |
+
value: 20000
|
251 |
+
save_strategy:
|
252 |
+
desc: null
|
253 |
+
value: IntervalStrategy.STEPS
|
254 |
+
save_total_limit:
|
255 |
+
desc: null
|
256 |
+
value: 5
|
257 |
+
seed:
|
258 |
+
desc: null
|
259 |
+
value: 42
|
260 |
+
sharded_ddp:
|
261 |
+
desc: null
|
262 |
+
value: []
|
263 |
+
skip_memory_metrics:
|
264 |
+
desc: null
|
265 |
+
value: true
|
266 |
+
tokenizer_name:
|
267 |
+
desc: null
|
268 |
+
value: ./
|
269 |
+
tpu_metrics_debug:
|
270 |
+
desc: null
|
271 |
+
value: false
|
272 |
+
tpu_num_cores:
|
273 |
+
desc: null
|
274 |
+
value: null
|
275 |
+
train_file:
|
276 |
+
desc: null
|
277 |
+
value: null
|
278 |
+
train_ref_file:
|
279 |
+
desc: null
|
280 |
+
value: null
|
281 |
+
use_fast_tokenizer:
|
282 |
+
desc: null
|
283 |
+
value: true
|
284 |
+
use_legacy_prediction_loop:
|
285 |
+
desc: null
|
286 |
+
value: false
|
287 |
+
validation_file:
|
288 |
+
desc: null
|
289 |
+
value: null
|
290 |
+
validation_ref_file:
|
291 |
+
desc: null
|
292 |
+
value: null
|
293 |
+
validation_split_percentage:
|
294 |
+
desc: null
|
295 |
+
value: 5
|
296 |
+
warmup_ratio:
|
297 |
+
desc: null
|
298 |
+
value: 0.0
|
299 |
+
warmup_steps:
|
300 |
+
desc: null
|
301 |
+
value: 10
|
302 |
+
weight_decay:
|
303 |
+
desc: null
|
304 |
+
value: 0.0095
|
wandb/run-20210713_104745-1rl2j7or/files/output.log
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:3114: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in zeros is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
|
2 |
+
lax._check_user_dtype_supported(dtype, "zeros")
|
3 |
+
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:382: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
|
4 |
+
warnings.warn(
|
5 |
+
/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:369: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
|
6 |
+
warnings.warn(
|
7 |
+
Epoch ... (1/5): 0%| | 0/5 [00:00<?, ?it/s]
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
Training...: 60%|██████████████████ | 50/83 [01:32<00:23, 1.40it/s]
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
Epoch ... (1/5): 20%|█████▍ | 1/5 [02:00<08:02, 120.70s/it]
|
20 |
+
|
21 |
+
Training...: 16%|████▋ | 13/83 [00:07<00:53, 1.32it/s]
|
22 |
+
|
23 |
+
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
Training...: 78%|███████████████████████▍ | 65/83 [00:44<00:24, 1.38s/it]
|
30 |
+
|
31 |
+
Epoch ... (1/5): 40%|███████████▏ | 2/5 [03:06<04:25, 88.56s/it]
|
32 |
+
|
33 |
+
Training...: 22%|██████▉ | 18/83 [00:01<00:07, 9.26it/s]
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
Epoch ... (1/5): 60%|████████████████▊ | 3/5 [04:12<02:36, 78.08s/it]s]
|
43 |
+
Step... (150 | Loss: 7.8581647872924805, Learning Rate: 2.256410152767785e-05)
|
44 |
+
|
45 |
+
Training...: 33%|███████████ | 27/83 [00:03<00:06, 9.31it/s]
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
Training...: 93%|███████████████████████████████▌ | 77/83 [00:32<00:04, 1.41it/s]
|
54 |
+
|
55 |
+
Epoch ... (1/5): 80%|██████████████████████▍ | 4/5 [05:18<01:13, 73.25s/it]/it]
|
56 |
+
|
57 |
+
|
wandb/run-20210713_104745-1rl2j7or/files/requirements.txt
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==0.13.0
|
2 |
+
aiohttp==3.7.4.post0
|
3 |
+
astunparse==1.6.3
|
4 |
+
async-timeout==3.0.1
|
5 |
+
attrs==21.2.0
|
6 |
+
cachetools==4.2.2
|
7 |
+
certifi==2021.5.30
|
8 |
+
chardet==4.0.0
|
9 |
+
chex==0.0.8
|
10 |
+
click==8.0.1
|
11 |
+
configparser==5.0.2
|
12 |
+
cycler==0.10.0
|
13 |
+
datasets==1.9.1.dev0
|
14 |
+
dill==0.3.4
|
15 |
+
dm-tree==0.1.6
|
16 |
+
docker-pycreds==0.4.0
|
17 |
+
filelock==3.0.12
|
18 |
+
flatbuffers==1.12
|
19 |
+
flax==0.3.4
|
20 |
+
fsspec==2021.6.1
|
21 |
+
gast==0.4.0
|
22 |
+
gitdb==4.0.7
|
23 |
+
gitpython==3.1.18
|
24 |
+
google-auth-oauthlib==0.4.4
|
25 |
+
google-auth==1.32.1
|
26 |
+
google-pasta==0.2.0
|
27 |
+
grpcio==1.34.1
|
28 |
+
h5py==3.1.0
|
29 |
+
huggingface-hub==0.0.12
|
30 |
+
idna==2.10
|
31 |
+
jax==0.2.16
|
32 |
+
jaxlib==0.1.68
|
33 |
+
joblib==1.0.1
|
34 |
+
keras-nightly==2.5.0.dev2021032900
|
35 |
+
keras-preprocessing==1.1.2
|
36 |
+
kiwisolver==1.3.1
|
37 |
+
libtpu-nightly==0.1.dev20210615
|
38 |
+
markdown==3.3.4
|
39 |
+
matplotlib==3.4.2
|
40 |
+
msgpack==1.0.2
|
41 |
+
multidict==5.1.0
|
42 |
+
multiprocess==0.70.12.2
|
43 |
+
numpy==1.19.5
|
44 |
+
oauthlib==3.1.1
|
45 |
+
opt-einsum==3.3.0
|
46 |
+
optax==0.0.9
|
47 |
+
packaging==21.0
|
48 |
+
pandas==1.3.0
|
49 |
+
pathtools==0.1.2
|
50 |
+
pillow==8.3.1
|
51 |
+
pip==20.0.2
|
52 |
+
pkg-resources==0.0.0
|
53 |
+
promise==2.3
|
54 |
+
protobuf==3.17.3
|
55 |
+
psutil==5.8.0
|
56 |
+
pyarrow==4.0.1
|
57 |
+
pyasn1-modules==0.2.8
|
58 |
+
pyasn1==0.4.8
|
59 |
+
pyparsing==2.4.7
|
60 |
+
python-dateutil==2.8.1
|
61 |
+
pytz==2021.1
|
62 |
+
pyyaml==5.4.1
|
63 |
+
regex==2021.7.6
|
64 |
+
requests-oauthlib==1.3.0
|
65 |
+
requests==2.25.1
|
66 |
+
rsa==4.7.2
|
67 |
+
sacremoses==0.0.45
|
68 |
+
scipy==1.7.0
|
69 |
+
sentry-sdk==1.3.0
|
70 |
+
setuptools==44.0.0
|
71 |
+
shortuuid==1.0.1
|
72 |
+
six==1.15.0
|
73 |
+
smmap==4.0.0
|
74 |
+
subprocess32==3.5.4
|
75 |
+
tensorboard-data-server==0.6.1
|
76 |
+
tensorboard-plugin-wit==1.8.0
|
77 |
+
tensorboard==2.5.0
|
78 |
+
tensorflow-estimator==2.5.0
|
79 |
+
tensorflow==2.5.0
|
80 |
+
termcolor==1.1.0
|
81 |
+
tokenizers==0.10.3
|
82 |
+
toolz==0.11.1
|
83 |
+
tqdm==4.61.2
|
84 |
+
transformers==4.9.0.dev0
|
85 |
+
typing-extensions==3.7.4.3
|
86 |
+
urllib3==1.26.6
|
87 |
+
wandb==0.10.33
|
88 |
+
werkzeug==2.0.1
|
89 |
+
wheel==0.36.2
|
90 |
+
wrapt==1.12.1
|
91 |
+
xxhash==2.0.2
|
92 |
+
yarl==1.6.3
|
wandb/run-20210713_104745-1rl2j7or/files/wandb-metadata.json
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
|
3 |
+
"python": "3.8.10",
|
4 |
+
"heartbeatAt": "2021-07-13T10:47:47.215746",
|
5 |
+
"startedAt": "2021-07-13T10:47:45.129053",
|
6 |
+
"docker": null,
|
7 |
+
"cpu_count": 96,
|
8 |
+
"cuda": null,
|
9 |
+
"args": [
|
10 |
+
"--push_to_hub",
|
11 |
+
"--output_dir=./",
|
12 |
+
"--model_type=big_bird",
|
13 |
+
"--config_name=./",
|
14 |
+
"--tokenizer_name=./",
|
15 |
+
"--max_seq_length=4096",
|
16 |
+
"--weight_decay=0.0095",
|
17 |
+
"--warmup_steps=10",
|
18 |
+
"--overwrite_output_dir",
|
19 |
+
"--adam_beta1=0.9",
|
20 |
+
"--adam_beta2=0.98",
|
21 |
+
"--logging_steps=50",
|
22 |
+
"--eval_steps=100001",
|
23 |
+
"--num_train_epochs=5",
|
24 |
+
"--preprocessing_num_workers=64",
|
25 |
+
"--save_steps=20000",
|
26 |
+
"--learning_rate=5e-5",
|
27 |
+
"--per_device_train_batch_size=2",
|
28 |
+
"--per_device_eval_batch_size=2",
|
29 |
+
"--save_total_limit=5",
|
30 |
+
"--gradient_accumulation_steps=2"
|
31 |
+
],
|
32 |
+
"state": "running",
|
33 |
+
"program": "./run_mlm_flax.py",
|
34 |
+
"codePath": "run_mlm_flax.py",
|
35 |
+
"git": {
|
36 |
+
"remote": "https://huggingface.co/flax-community/pino-roberta-base",
|
37 |
+
"commit": "bc11ccfe77236f87575711b26034b9751449de4b"
|
38 |
+
},
|
39 |
+
"email": null,
|
40 |
+
"root": "/home/dat/pino-roberta-base",
|
41 |
+
"host": "t1v-n-f5c06ea1-w-0",
|
42 |
+
"username": "dat",
|
43 |
+
"executable": "/home/dat/pino/bin/python"
|
44 |
+
}
|
wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"training_step": 200, "learning_rate": 1.0769229447760154e-05, "train_loss": 7.618040084838867, "_runtime": 333, "_timestamp": 1626173598, "_step": 6}
|
wandb/run-20210713_104745-1rl2j7or/logs/debug-internal.log
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
2021-07-13 10:47:45,828 INFO MainThread:342403 [internal.py:wandb_internal():88] W&B internal server running at pid: 342403, started at: 2021-07-13 10:47:45.828158
|
2 |
+
2021-07-13 10:47:45,830 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: check_version
|
3 |
+
2021-07-13 10:47:45,830 INFO WriterThread:342403 [datastore.py:open_for_write():80] open: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/run-1rl2j7or.wandb
|
4 |
+
2021-07-13 10:47:45,831 DEBUG SenderThread:342403 [sender.py:send():179] send: header
|
5 |
+
2021-07-13 10:47:45,831 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: check_version
|
6 |
+
2021-07-13 10:47:45,871 DEBUG SenderThread:342403 [sender.py:send():179] send: run
|
7 |
+
2021-07-13 10:47:46,041 INFO SenderThread:342403 [dir_watcher.py:__init__():168] watching files in: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files
|
8 |
+
2021-07-13 10:47:46,041 INFO SenderThread:342403 [sender.py:_start_run_threads():716] run started: 1rl2j7or with start time 1626173265
|
9 |
+
2021-07-13 10:47:46,041 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
|
10 |
+
2021-07-13 10:47:46,041 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: run_start
|
11 |
+
2021-07-13 10:47:46,042 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
|
12 |
+
2021-07-13 10:47:47,043 INFO Thread-8 :342403 [dir_watcher.py:_on_file_created():216] file/dir created: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
|
13 |
+
2021-07-13 10:47:47,215 DEBUG HandlerThread:342403 [meta.py:__init__():39] meta init
|
14 |
+
2021-07-13 10:47:47,215 DEBUG HandlerThread:342403 [meta.py:__init__():53] meta init done
|
15 |
+
2021-07-13 10:47:47,215 DEBUG HandlerThread:342403 [meta.py:probe():210] probe
|
16 |
+
2021-07-13 10:47:47,217 DEBUG HandlerThread:342403 [meta.py:_setup_git():200] setup git
|
17 |
+
2021-07-13 10:47:47,250 DEBUG HandlerThread:342403 [meta.py:_setup_git():207] setup git done
|
18 |
+
2021-07-13 10:47:47,250 DEBUG HandlerThread:342403 [meta.py:_save_pip():57] save pip
|
19 |
+
2021-07-13 10:47:47,251 DEBUG HandlerThread:342403 [meta.py:_save_pip():71] save pip done
|
20 |
+
2021-07-13 10:47:47,251 DEBUG HandlerThread:342403 [meta.py:probe():252] probe done
|
21 |
+
2021-07-13 10:47:47,255 DEBUG SenderThread:342403 [sender.py:send():179] send: files
|
22 |
+
2021-07-13 10:47:47,255 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-metadata.json with policy now
|
23 |
+
2021-07-13 10:47:47,262 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
24 |
+
2021-07-13 10:47:47,262 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
25 |
+
2021-07-13 10:47:47,394 DEBUG SenderThread:342403 [sender.py:send():179] send: config
|
26 |
+
2021-07-13 10:47:47,394 DEBUG SenderThread:342403 [sender.py:send():179] send: config
|
27 |
+
2021-07-13 10:47:47,394 DEBUG SenderThread:342403 [sender.py:send():179] send: config
|
28 |
+
2021-07-13 10:47:47,719 INFO Thread-11 :342403 [upload_job.py:push():137] Uploaded file /tmp/tmpta17r5ywwandb/1f1555en-wandb-metadata.json
|
29 |
+
2021-07-13 10:47:48,042 INFO Thread-8 :342403 [dir_watcher.py:_on_file_created():216] file/dir created: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-metadata.json
|
30 |
+
2021-07-13 10:47:48,042 INFO Thread-8 :342403 [dir_watcher.py:_on_file_created():216] file/dir created: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/requirements.txt
|
31 |
+
2021-07-13 10:47:48,042 INFO Thread-8 :342403 [dir_watcher.py:_on_file_created():216] file/dir created: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
32 |
+
2021-07-13 10:48:02,047 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
33 |
+
2021-07-13 10:48:02,398 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
34 |
+
2021-07-13 10:48:02,398 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
35 |
+
2021-07-13 10:48:04,048 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
36 |
+
2021-07-13 10:48:15,296 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
|
37 |
+
2021-07-13 10:48:17,054 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/config.yaml
|
38 |
+
2021-07-13 10:48:17,555 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
39 |
+
2021-07-13 10:48:17,556 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
40 |
+
2021-07-13 10:48:32,709 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
41 |
+
2021-07-13 10:48:32,710 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
42 |
+
2021-07-13 10:48:45,371 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
|
43 |
+
2021-07-13 10:48:47,840 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
44 |
+
2021-07-13 10:48:47,840 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
45 |
+
2021-07-13 10:49:02,980 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
46 |
+
2021-07-13 10:49:02,980 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
47 |
+
2021-07-13 10:49:15,445 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
|
48 |
+
2021-07-13 10:49:18,113 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
49 |
+
2021-07-13 10:49:18,113 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
50 |
+
2021-07-13 10:49:24,080 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
51 |
+
2021-07-13 10:49:26,080 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
52 |
+
2021-07-13 10:49:28,081 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
53 |
+
2021-07-13 10:49:30,082 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
54 |
+
2021-07-13 10:49:32,083 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
55 |
+
2021-07-13 10:49:33,242 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
56 |
+
2021-07-13 10:49:33,243 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
57 |
+
2021-07-13 10:49:34,084 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
58 |
+
2021-07-13 10:49:36,084 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
59 |
+
2021-07-13 10:49:45,514 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
|
60 |
+
2021-07-13 10:49:48,375 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
61 |
+
2021-07-13 10:49:48,375 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
62 |
+
2021-07-13 10:49:58,179 DEBUG SenderThread:342403 [sender.py:send():179] send: history
|
63 |
+
2021-07-13 10:49:58,180 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
|
64 |
+
2021-07-13 10:49:58,180 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
|
65 |
+
2021-07-13 10:49:59,093 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
|
66 |
+
2021-07-13 10:50:00,093 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
67 |
+
2021-07-13 10:50:02,094 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
68 |
+
2021-07-13 10:50:03,510 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
69 |
+
2021-07-13 10:50:03,510 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
70 |
+
2021-07-13 10:50:04,095 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
71 |
+
2021-07-13 10:50:15,583 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
|
72 |
+
2021-07-13 10:50:18,643 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
73 |
+
2021-07-13 10:50:18,643 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
74 |
+
2021-07-13 10:50:24,102 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
75 |
+
2021-07-13 10:50:28,758 DEBUG SenderThread:342403 [sender.py:send():179] send: history
|
76 |
+
2021-07-13 10:50:28,759 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
|
77 |
+
2021-07-13 10:50:28,763 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
|
78 |
+
2021-07-13 10:50:29,104 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
|
79 |
+
2021-07-13 10:50:30,105 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
80 |
+
2021-07-13 10:50:32,106 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
81 |
+
2021-07-13 10:50:33,775 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
82 |
+
2021-07-13 10:50:33,776 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
83 |
+
2021-07-13 10:50:34,107 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
84 |
+
2021-07-13 10:50:36,107 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
85 |
+
2021-07-13 10:50:38,108 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
86 |
+
2021-07-13 10:50:40,109 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
87 |
+
2021-07-13 10:50:42,110 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
88 |
+
2021-07-13 10:50:45,653 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
|
89 |
+
2021-07-13 10:50:48,905 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
90 |
+
2021-07-13 10:50:48,906 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
91 |
+
2021-07-13 10:51:04,035 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
92 |
+
2021-07-13 10:51:04,035 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
93 |
+
2021-07-13 10:51:04,964 DEBUG SenderThread:342403 [sender.py:send():179] send: history
|
94 |
+
2021-07-13 10:51:04,964 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
|
95 |
+
2021-07-13 10:51:04,964 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
|
96 |
+
2021-07-13 10:51:05,119 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
|
97 |
+
2021-07-13 10:51:06,119 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
98 |
+
2021-07-13 10:51:08,120 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
99 |
+
2021-07-13 10:51:15,726 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
|
100 |
+
2021-07-13 10:51:19,168 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
101 |
+
2021-07-13 10:51:19,168 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
102 |
+
2021-07-13 10:51:24,126 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
103 |
+
2021-07-13 10:51:26,127 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
104 |
+
2021-07-13 10:51:34,303 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
105 |
+
2021-07-13 10:51:34,303 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
106 |
+
2021-07-13 10:51:35,557 DEBUG SenderThread:342403 [sender.py:send():179] send: history
|
107 |
+
2021-07-13 10:51:35,558 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
|
108 |
+
2021-07-13 10:51:35,558 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
|
109 |
+
2021-07-13 10:51:36,131 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
|
110 |
+
2021-07-13 10:51:36,132 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
111 |
+
2021-07-13 10:51:38,132 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
112 |
+
2021-07-13 10:51:40,133 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
113 |
+
2021-07-13 10:51:42,134 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
114 |
+
2021-07-13 10:51:44,135 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
115 |
+
2021-07-13 10:51:45,797 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
|
116 |
+
2021-07-13 10:51:46,136 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
117 |
+
2021-07-13 10:51:48,137 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
118 |
+
2021-07-13 10:51:49,438 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
119 |
+
2021-07-13 10:51:49,438 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
120 |
+
2021-07-13 10:51:50,137 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
121 |
+
2021-07-13 10:52:04,579 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
122 |
+
2021-07-13 10:52:04,580 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
123 |
+
2021-07-13 10:52:11,761 DEBUG SenderThread:342403 [sender.py:send():179] send: history
|
124 |
+
2021-07-13 10:52:11,762 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
|
125 |
+
2021-07-13 10:52:11,763 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
|
126 |
+
2021-07-13 10:52:12,146 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
|
127 |
+
2021-07-13 10:52:14,147 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
128 |
+
2021-07-13 10:52:15,867 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
|
129 |
+
2021-07-13 10:52:19,709 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
130 |
+
2021-07-13 10:52:19,710 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
131 |
+
2021-07-13 10:52:24,150 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
132 |
+
2021-07-13 10:52:26,151 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
133 |
+
2021-07-13 10:52:34,838 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
134 |
+
2021-07-13 10:52:34,839 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
135 |
+
2021-07-13 10:52:42,378 DEBUG SenderThread:342403 [sender.py:send():179] send: history
|
136 |
+
2021-07-13 10:52:42,378 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
|
137 |
+
2021-07-13 10:52:42,379 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
|
138 |
+
2021-07-13 10:52:43,158 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
|
139 |
+
2021-07-13 10:52:45,159 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
140 |
+
2021-07-13 10:52:45,939 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
|
141 |
+
2021-07-13 10:52:47,160 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
142 |
+
2021-07-13 10:52:49,161 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
143 |
+
2021-07-13 10:52:49,969 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
|
144 |
+
2021-07-13 10:52:49,970 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
|
145 |
+
2021-07-13 10:52:51,161 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
146 |
+
2021-07-13 10:52:53,162 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
147 |
+
2021-07-13 10:52:55,163 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
|
148 |
+
+2021-07-13 10:52:57,164 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
+2021-07-13 10:53:05,101 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
+2021-07-13 10:53:05,101 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
+2021-07-13 10:53:16,014 DEBUG SenderThread:342403 [sender.py:send():179] send: stats
+2021-07-13 10:53:18,580 DEBUG SenderThread:342403 [sender.py:send():179] send: history
+2021-07-13 10:53:18,580 DEBUG SenderThread:342403 [sender.py:send():179] send: summary
+2021-07-13 10:53:18,580 INFO SenderThread:342403 [sender.py:_save_file():841] saving file wandb-summary.json with policy end
+2021-07-13 10:53:19,173 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
+2021-07-13 10:53:20,233 DEBUG HandlerThread:342403 [handler.py:handle_request():124] handle_request: stop_status
+2021-07-13 10:53:20,234 DEBUG SenderThread:342403 [sender.py:send_request():193] send_request: stop_status
+2021-07-13 10:53:21,173 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
+2021-07-13 10:53:25,175 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
+2021-07-13 10:53:27,176 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
+2021-07-13 10:53:29,177 INFO Thread-8 :342403 [dir_watcher.py:_on_file_modified():229] file/dir modified: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
+2021-07-13 10:53:34,237 WARNING MainThread:342403 [internal.py:wandb_internal():147] Internal process interrupt: 1
+2021-07-13 10:53:34,484 WARNING MainThread:342403 [internal.py:wandb_internal():147] Internal process interrupt: 2
+2021-07-13 10:53:34,484 ERROR MainThread:342403 [internal.py:wandb_internal():150] Internal process interrupted.
+2021-07-13 10:53:35,385 INFO WriterThread:342403 [datastore.py:close():288] close: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/run-1rl2j7or.wandb
+2021-07-13 10:53:35,409 INFO SenderThread:342403 [sender.py:finish():945] shutting down sender
+2021-07-13 10:53:35,409 INFO SenderThread:342403 [dir_watcher.py:finish():282] shutting down directory watcher
+2021-07-13 10:53:35,414 INFO HandlerThread:342403 [handler.py:finish():638] shutting down handler
+2021-07-13 10:53:36,180 INFO SenderThread:342403 [dir_watcher.py:finish():312] scan: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files
+2021-07-13 10:53:36,180 INFO SenderThread:342403 [dir_watcher.py:finish():318] scan save: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/requirements.txt requirements.txt
+2021-07-13 10:53:36,180 INFO SenderThread:342403 [dir_watcher.py:finish():318] scan save: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log output.log
+2021-07-13 10:53:36,180 INFO SenderThread:342403 [dir_watcher.py:finish():318] scan save: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-metadata.json wandb-metadata.json
+2021-07-13 10:53:36,180 INFO SenderThread:342403 [dir_watcher.py:finish():318] scan save: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/config.yaml config.yaml
+2021-07-13 10:53:36,181 INFO SenderThread:342403 [dir_watcher.py:finish():318] scan save: /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json wandb-summary.json
+2021-07-13 10:53:36,181 INFO SenderThread:342403 [file_pusher.py:finish():177] shutting down file pusher
+2021-07-13 10:53:36,181 INFO SenderThread:342403 [file_pusher.py:join():182] waiting for file pusher
+2021-07-13 10:53:36,622 INFO Thread-14 :342403 [upload_job.py:push():137] Uploaded file /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/config.yaml
+2021-07-13 10:53:36,624 INFO Thread-15 :342403 [upload_job.py:push():137] Uploaded file /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/wandb-summary.json
+2021-07-13 10:53:36,634 INFO Thread-13 :342403 [upload_job.py:push():137] Uploaded file /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/output.log
+2021-07-13 10:53:36,654 INFO Thread-12 :342403 [upload_job.py:push():137] Uploaded file /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/files/requirements.txt
+2021-07-13 10:53:37,518 INFO MainThread:342403 [internal.py:handle_exit():78] Internal process exited
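Note: the debug-internal.log above ends with the run being interrupted (exit code 255) while the internal wandb process still flushes and uploads config.yaml, output.log and wandb-summary.json. A minimal sketch of how a training script can guarantee that final flush, assuming only the public wandb API (the project name and the train_loop stub are illustrative, not taken from this repo):

import wandb

def train_loop():
    # stand-in for the real training loop in run_mlm_flax.py
    wandb.log({"train_loss": 0.0})

run = wandb.init(project="pino-roberta-base")  # project name is illustrative
try:
    train_loop()
finally:
    # flushes history/summary and shuts down the internal sender process,
    # mirroring the shutdown sequence seen in the log above
    run.finish()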
wandb/run-20210713_104745-1rl2j7or/logs/debug.log
ADDED
@@ -0,0 +1,27 @@
+2021-07-13 10:47:45,130 INFO MainThread:340852 [wandb_setup.py:_flush():69] setting env: {}
+2021-07-13 10:47:45,130 INFO MainThread:340852 [wandb_setup.py:_flush():69] setting login settings: {}
+2021-07-13 10:47:45,130 INFO MainThread:340852 [wandb_init.py:_log_setup():337] Logging user logs to /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/logs/debug.log
+2021-07-13 10:47:45,130 INFO MainThread:340852 [wandb_init.py:_log_setup():338] Logging internal logs to /home/dat/pino-roberta-base/wandb/run-20210713_104745-1rl2j7or/logs/debug-internal.log
+2021-07-13 10:47:45,131 INFO MainThread:340852 [wandb_init.py:init():370] calling init triggers
+2021-07-13 10:47:45,131 INFO MainThread:340852 [wandb_init.py:init():375] wandb.init called with sweep_config: {}
+config: {}
+2021-07-13 10:47:45,131 INFO MainThread:340852 [wandb_init.py:init():419] starting backend
+2021-07-13 10:47:45,131 INFO MainThread:340852 [backend.py:_multiprocessing_setup():70] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
+2021-07-13 10:47:45,179 INFO MainThread:340852 [backend.py:ensure_launched():135] starting backend process...
+2021-07-13 10:47:45,225 INFO MainThread:340852 [backend.py:ensure_launched():139] started backend process with pid: 342403
+2021-07-13 10:47:45,228 INFO MainThread:340852 [wandb_init.py:init():424] backend started and connected
+2021-07-13 10:47:45,231 INFO MainThread:340852 [wandb_init.py:init():472] updated telemetry
+2021-07-13 10:47:45,231 INFO MainThread:340852 [wandb_init.py:init():491] communicating current version
+2021-07-13 10:47:45,870 INFO MainThread:340852 [wandb_init.py:init():496] got version response
+2021-07-13 10:47:45,870 INFO MainThread:340852 [wandb_init.py:init():504] communicating run to backend with 30 second timeout
+2021-07-13 10:47:46,040 INFO MainThread:340852 [wandb_init.py:init():529] starting run threads in backend
+2021-07-13 10:47:47,259 INFO MainThread:340852 [wandb_run.py:_console_start():1623] atexit reg
+2021-07-13 10:47:47,260 INFO MainThread:340852 [wandb_run.py:_redirect():1497] redirect: SettingsConsole.REDIRECT
+2021-07-13 10:47:47,261 INFO MainThread:340852 [wandb_run.py:_redirect():1502] Redirecting console.
+2021-07-13 10:47:47,262 INFO MainThread:340852 [wandb_run.py:_redirect():1558] Redirects installed.
+2021-07-13 10:47:47,262 INFO MainThread:340852 [wandb_init.py:init():554] run started, returning control to user process
+2021-07-13 10:47:47,268 INFO MainThread:340852 [wandb_run.py:_config_callback():872] config_cb None None {'output_dir': './', 'overwrite_output_dir': True, 'do_train': False, 'do_eval': False, 'do_predict': False, 'evaluation_strategy': 'IntervalStrategy.NO', 'prediction_loss_only': False, 'per_device_train_batch_size': 2, 'per_device_eval_batch_size': 2, 'per_gpu_train_batch_size': None, 'per_gpu_eval_batch_size': None, 'gradient_accumulation_steps': 2, 'eval_accumulation_steps': None, 'learning_rate': 5e-05, 'weight_decay': 0.0095, 'adam_beta1': 0.9, 'adam_beta2': 0.98, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'num_train_epochs': 5.0, 'max_steps': -1, 'lr_scheduler_type': 'SchedulerType.LINEAR', 'warmup_ratio': 0.0, 'warmup_steps': 10, 'log_level': -1, 'log_level_replica': -1, 'log_on_each_node': True, 'logging_dir': './runs/Jul13_10-47-16_t1v-n-f5c06ea1-w-0', 'logging_strategy': 'IntervalStrategy.STEPS', 'logging_first_step': False, 'logging_steps': 50, 'save_strategy': 'IntervalStrategy.STEPS', 'save_steps': 20000, 'save_total_limit': 5, 'save_on_each_node': False, 'no_cuda': False, 'seed': 42, 'fp16': False, 'fp16_opt_level': 'O1', 'fp16_backend': 'auto', 'fp16_full_eval': False, 'local_rank': -1, 'tpu_num_cores': None, 'tpu_metrics_debug': False, 'debug': [], 'dataloader_drop_last': False, 'eval_steps': 100001, 'dataloader_num_workers': 0, 'past_index': -1, 'run_name': './', 'disable_tqdm': False, 'remove_unused_columns': True, 'label_names': None, 'load_best_model_at_end': False, 'metric_for_best_model': None, 'greater_is_better': None, 'ignore_data_skip': False, 'sharded_ddp': [], 'deepspeed': None, 'label_smoothing_factor': 0.0, 'adafactor': False, 'group_by_length': False, 'length_column_name': 'length', 'report_to': ['tensorboard', 'wandb'], 'ddp_find_unused_parameters': None, 'dataloader_pin_memory': True, 'skip_memory_metrics': True, 'use_legacy_prediction_loop': False, 'push_to_hub': True, 'resume_from_checkpoint': None, 'push_to_hub_model_id': '', 'push_to_hub_organization': None, 'push_to_hub_token': None, 'mp_parameters': ''}
+2021-07-13 10:47:47,270 INFO MainThread:340852 [wandb_run.py:_config_callback():872] config_cb None None {'model_name_or_path': None, 'model_type': 'big_bird', 'config_name': './', 'tokenizer_name': './', 'cache_dir': None, 'use_fast_tokenizer': True, 'dtype': 'float32'}
+2021-07-13 10:47:47,271 INFO MainThread:340852 [wandb_run.py:_config_callback():872] config_cb None None {'dataset_name': None, 'dataset_config_name': None, 'train_file': None, 'validation_file': None, 'train_ref_file': None, 'validation_ref_file': None, 'overwrite_cache': False, 'validation_split_percentage': 5, 'max_seq_length': 4096, 'preprocessing_num_workers': 64, 'mlm_probability': 0.15, 'pad_to_max_length': False, 'line_by_line': False}
+2021-07-13 10:53:34,760 INFO MainThread:340852 [wandb_run.py:_atexit_cleanup():1593] got exitcode: 255
+2021-07-13 10:53:34,761 INFO MainThread:340852 [wandb_run.py:_restore():1565] restore
wandb/run-20210713_104745-1rl2j7or/run-1rl2j7or.wandb
ADDED
Binary file (14.8 kB)
wandb/run-20210713_110212-594z6oo0/files/config.yaml
ADDED
@@ -0,0 +1,307 @@
+wandb_version: 1
+
+_wandb:
+  desc: null
+  value:
+    cli_version: 0.10.33
+    framework: huggingface
+    huggingface_version: 4.9.0.dev0
+    is_jupyter_run: false
+    is_kaggle_kernel: false
+    python_version: 3.8.10
+    t:
+      1:
+      - 3
+      - 11
+      2:
+      - 3
+      - 11
+      4: 3.8.10
+      5: 0.10.33
+      6: 4.9.0.dev0
+      8:
+      - 5
+adafactor:
+  desc: null
+  value: false
+adam_beta1:
+  desc: null
+  value: 0.9
+adam_beta2:
+  desc: null
+  value: 0.98
+adam_epsilon:
+  desc: null
+  value: 1.0e-08
+cache_dir:
+  desc: null
+  value: null
+config_name:
+  desc: null
+  value: ./
+dataloader_drop_last:
+  desc: null
+  value: false
+dataloader_num_workers:
+  desc: null
+  value: 0
+dataloader_pin_memory:
+  desc: null
+  value: true
+dataset_config_name:
+  desc: null
+  value: null
+dataset_name:
+  desc: null
+  value: null
+ddp_find_unused_parameters:
+  desc: null
+  value: null
+debug:
+  desc: null
+  value: []
+deepspeed:
+  desc: null
+  value: null
+disable_tqdm:
+  desc: null
+  value: false
+do_eval:
+  desc: null
+  value: false
+do_predict:
+  desc: null
+  value: false
+do_train:
+  desc: null
+  value: false
+dtype:
+  desc: null
+  value: float32
+eval_accumulation_steps:
+  desc: null
+  value: null
+eval_steps:
+  desc: null
+  value: 100001
+evaluation_strategy:
+  desc: null
+  value: IntervalStrategy.NO
+fp16:
+  desc: null
+  value: false
+fp16_backend:
+  desc: null
+  value: auto
+fp16_full_eval:
+  desc: null
+  value: false
+fp16_opt_level:
+  desc: null
+  value: O1
+gradient_accumulation_steps:
+  desc: null
+  value: 2
+greater_is_better:
+  desc: null
+  value: null
+group_by_length:
+  desc: null
+  value: false
+ignore_data_skip:
+  desc: null
+  value: false
+label_names:
+  desc: null
+  value: null
+label_smoothing_factor:
+  desc: null
+  value: 0.0
+learning_rate:
+  desc: null
+  value: 5.0e-05
+length_column_name:
+  desc: null
+  value: length
+line_by_line:
+  desc: null
+  value: false
+load_best_model_at_end:
+  desc: null
+  value: false
+local_rank:
+  desc: null
+  value: -1
+log_level:
+  desc: null
+  value: -1
+log_level_replica:
+  desc: null
+  value: -1
+log_on_each_node:
+  desc: null
+  value: true
+logging_dir:
+  desc: null
+  value: ./runs/Jul13_11-01-24_t1v-n-f5c06ea1-w-0
+logging_first_step:
+  desc: null
+  value: false
+logging_steps:
+  desc: null
+  value: 500
+logging_strategy:
+  desc: null
+  value: IntervalStrategy.STEPS
+lr_scheduler_type:
+  desc: null
+  value: SchedulerType.LINEAR
+max_grad_norm:
+  desc: null
+  value: 1.0
+max_seq_length:
+  desc: null
+  value: 4096
+max_steps:
+  desc: null
+  value: -1
+metric_for_best_model:
+  desc: null
+  value: null
+mlm_probability:
+  desc: null
+  value: 0.15
+model_name_or_path:
+  desc: null
+  value: null
+model_type:
+  desc: null
+  value: big_bird
+mp_parameters:
+  desc: null
+  value: ''
+no_cuda:
+  desc: null
+  value: false
+num_train_epochs:
+  desc: null
+  value: 5.0
+output_dir:
+  desc: null
+  value: ./
+overwrite_cache:
+  desc: null
+  value: false
+overwrite_output_dir:
+  desc: null
+  value: true
+pad_to_max_length:
+  desc: null
+  value: false
+past_index:
+  desc: null
+  value: -1
+per_device_eval_batch_size:
+  desc: null
+  value: 2
+per_device_train_batch_size:
+  desc: null
+  value: 2
+per_gpu_eval_batch_size:
+  desc: null
+  value: null
+per_gpu_train_batch_size:
+  desc: null
+  value: null
+prediction_loss_only:
+  desc: null
+  value: false
+preprocessing_num_workers:
+  desc: null
+  value: 64
+push_to_hub:
+  desc: null
+  value: true
+push_to_hub_model_id:
+  desc: null
+  value: ''
+push_to_hub_organization:
+  desc: null
+  value: null
+push_to_hub_token:
+  desc: null
+  value: null
+remove_unused_columns:
+  desc: null
+  value: true
+report_to:
+  desc: null
+  value:
+  - tensorboard
+  - wandb
+resume_from_checkpoint:
+  desc: null
+  value: null
+run_name:
+  desc: null
+  value: ./
+save_on_each_node:
+  desc: null
+  value: false
+save_steps:
+  desc: null
+  value: 20000
+save_strategy:
+  desc: null
+  value: IntervalStrategy.STEPS
+save_total_limit:
+  desc: null
+  value: 5
+seed:
+  desc: null
+  value: 42
+sharded_ddp:
+  desc: null
+  value: []
+skip_memory_metrics:
+  desc: null
+  value: true
+tokenizer_name:
+  desc: null
+  value: ./
+tpu_metrics_debug:
+  desc: null
+  value: false
+tpu_num_cores:
+  desc: null
+  value: null
+train_file:
+  desc: null
+  value: null
+train_ref_file:
+  desc: null
+  value: null
+use_fast_tokenizer:
+  desc: null
+  value: true
+use_legacy_prediction_loop:
+  desc: null
+  value: false
+validation_file:
+  desc: null
+  value: null
+validation_ref_file:
+  desc: null
+  value: null
+validation_split_percentage:
+  desc: null
+  value: 5
+warmup_ratio:
+  desc: null
+  value: 0.0
+warmup_steps:
+  desc: null
+  value: 10
+weight_decay:
+  desc: null
+  value: 0.0095
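Note: the config.yaml above stores every hyperparameter as a {desc, value} pair, which is how wandb serializes the run config. A minimal sketch (an assumption, not code from this repo) for flattening it back into plain key/value pairs with PyYAML:

import yaml

# path taken from the run directory shown above
with open("wandb/run-20210713_110212-594z6oo0/files/config.yaml") as f:
    raw = yaml.safe_load(f)

# keep only entries that follow the {desc, value} layout (skips wandb_version)
config = {
    key: entry["value"]
    for key, entry in raw.items()
    if isinstance(entry, dict) and "value" in entry
}
print(config["per_device_train_batch_size"], config["max_seq_length"])  # 2 4096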
wandb/run-20210713_110212-594z6oo0/files/output.log
ADDED
@@ -0,0 +1,39 @@
+/home/dat/pino/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:3114: UserWarning: Explicitly requested dtype <class 'jax._src.numpy.lax_numpy.int64'> requested in zeros is not available, and will be truncated to dtype int32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
+  lax._check_user_dtype_supported(dtype, "zeros")
+/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:382: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
+  warnings.warn(
+/home/dat/pino/lib/python3.8/site-packages/jax/lib/xla_bridge.py:369: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
+  warnings.warn(
+Epoch ... (1/5): 0%| | 0/5 [00:00<?, ?it/s]
+Training...: 0%| | 0/92767 [01:25<?, ?it/s]
+Epoch ... (1/5): 0%| | 0/5 [02:57<?, ?it/s]
+Traceback (most recent call last):
+  File "./run_mlm_flax.py", line 712, in <module>
+    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
+  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
+    return fun(*args, **kwargs)
+  File "/home/dat/pino/lib/python3.8/site-packages/jax/_src/api.py", line 1647, in f_pmapped
+    out = pxla.xla_pmap(
+  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1620, in bind
+    return call_bind(self, fun, *args, **params)
+  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1551, in call_bind
+    outs = primitive.process(top_trace, fun, tracers, params)
+  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 1623, in process
+    return trace.process_map(self, fun, tracers, params)
+  File "/home/dat/pino/lib/python3.8/site-packages/jax/core.py", line 606, in process_call
+    return primitive.impl(f, *tracers, **params)
+  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 637, in xla_pmap_impl
+    return compiled_fun(*args)
+  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1152, in execute_replicated
+    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
+jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Resource exhausted: Attempting to reserve 12.60G at the bottom of memory. That was not possible. There are 12.15G free, 0B reserved, and 12.13G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
+The stack trace below excludes JAX-internal frames.
+The preceding is the original exception that occurred, unmodified.
+--------------------
+The above exception was the direct cause of the following exception:
+Traceback (most recent call last):
+  File "./run_mlm_flax.py", line 712, in <module>
+    state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
+  File "/home/dat/pino/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 1152, in execute_replicated
+    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
+RuntimeError: Resource exhausted: Attempting to reserve 12.60G at the bottom of memory. That was not possible. There are 12.15G free, 0B reserved, and 12.13G reservable.: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
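Note: the output.log above shows the pmapped train step failing with a "Resource exhausted" error for this run (per_device_train_batch_size 2, max_seq_length 4096, gradient_accumulation_steps 2 in the config above). A minimal sketch of one common workaround, not the fix used in this repo: halve the per-device batch and double gradient accumulation so the effective batch size is unchanged while peak per-device activation memory drops.

import jax

per_device_batch = 2   # from per_device_train_batch_size above
grad_accum_steps = 2   # from gradient_accumulation_steps above
devices = jax.local_device_count()

effective_batch = per_device_batch * grad_accum_steps * devices
print(f"effective batch size: {effective_batch}")

# halve the per-device batch, double accumulation: same effective batch,
# roughly half the peak activation memory per device
per_device_batch, grad_accum_steps = per_device_batch // 2, grad_accum_steps * 2
assert per_device_batch * grad_accum_steps * devices == effective_batch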