Saving weights and logs of step 16000
- Determine_batch_size.ipynb +289 -0
- Load_preprocessed_dataset.ipynb +147 -0
- flax_model.msgpack +1 -1
- pytorch_model.bin +1 -1
- run_t5.sh +4 -4
- run_t5_mlm_flax_custom_dataset.py +80 -70
- runs/Jul10_12-03-45_t1v-n-0e7426e8-w-0/events.out.tfevents.1625920526.t1v-n-0e7426e8-w-0.48005.3.v2 +0 -3
- runs/Jul10_12-39-58_t1v-n-0e7426e8-w-0/events.out.tfevents.1625922498.t1v-n-0e7426e8-w-0.52901.3.v2 +2 -2
- runs/{Jul10_07-37-20_t1v-n-0e7426e8-w-0/events.out.tfevents.1625902752.t1v-n-0e7426e8-w-0.18397.3.v2 → Jul11_09-15-07_t1v-n-0e7426e8-w-0/events.out.tfevents.1625995853.t1v-n-0e7426e8-w-0.145718.3.v2} +2 -2
Determine_batch_size.ipynb
ADDED
@@ -0,0 +1,289 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "kqOqTZuKeJoa",
+    "outputId": "9f63819c-9bc1-4c15-e9cd-9c1121edd2a6"
+   },
+   "outputs": [],
+   "source": [
+    "#!pip install \"jax[tpu]>=0.2.16\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "SQ-lhEVFeY4d",
+    "outputId": "7346c6b8-1848-4755-c114-94d6de50b50d"
+   },
+   "outputs": [],
+   "source": [
+    "#!git clone https://github.com/huggingface/transformers.git"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "9qSTMLvFfBVs",
+    "outputId": "40659f61-86d4-4ae5-9262-501557737705"
+   },
+   "outputs": [],
+   "source": [
+    "#!pip install ./transformers"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "id": "7Og7zRTrfm08"
+   },
+   "outputs": [],
+   "source": [
+    "#!pip install jaxlib>=0.2.9"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "id": "nmQv7VMaf1L8"
+   },
+   "outputs": [],
+   "source": [
+    "#!pip install flax>=0.3.4"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {
+    "id": "MT6jpop-f4dc"
+   },
+   "outputs": [],
+   "source": [
+    "#!pip install optax>=0.0.9"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# %%capture\n",
+    "# !pip install jupyterlab_widgets\n",
+    "# !pip install ipywidgets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "id": "-F5NIqDmfDLb"
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2021-07-08 10:11:37.310929: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n"
+     ]
+    }
+   ],
+   "source": [
+    "from transformers import (\n",
+    "    CONFIG_MAPPING,\n",
+    "    FLAX_MODEL_FOR_MASKED_LM_MAPPING,\n",
+    "    BatchEncoding,\n",
+    "    FlaxT5ForConditionalGeneration,\n",
+    "    T5ForConditionalGeneration,\n",
+    "    HfArgumentParser,\n",
+    "    PreTrainedTokenizerBase,\n",
+    "    T5Config,\n",
+    "    T5TokenizerFast,\n",
+    "    TrainingArguments,\n",
+    "    is_tensorboard_available,\n",
+    "    set_seed,\n",
+    ")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "id": "aInICxY6gREQ"
+   },
+   "outputs": [],
+   "source": [
+    "import flax\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "import optax\n",
+    "from flax import jax_utils, traverse_util\n",
+    "from flax.training import train_state\n",
+    "from flax.training.common_utils import get_metrics, onehot, shard\n",
+    "from transformers import (\n",
+    "    CONFIG_MAPPING,\n",
+    "    FLAX_MODEL_FOR_MASKED_LM_MAPPING,\n",
+    "    BatchEncoding,\n",
+    "    FlaxT5ForConditionalGeneration,\n",
+    "    T5ForConditionalGeneration,\n",
+    "    HfArgumentParser,\n",
+    "    PreTrainedTokenizerBase,\n",
+    "    T5Config,\n",
+    "    T5TokenizerFast,\n",
+    "    TrainingArguments,\n",
+    "    is_tensorboard_available,\n",
+    "    set_seed,\n",
+    ")\n",
+    "from transformers.models.t5.modeling_flax_t5 import shift_tokens_right\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {
+    "id": "iEqVlHptfOCT"
+   },
+   "outputs": [],
+   "source": [
+    "tokenizer = T5TokenizerFast.from_pretrained(\"t5-small\")\n",
+    "config = T5Config.from_pretrained(\"t5-small\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "id": "LNETw3cWfjbr",
+    "outputId": "95c0e750-c087-46dd-92fa-39f8ff0238f2"
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "INFO:absl:Starting the local TPU driver.\n",
+      "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
+      "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: TPU Interpreter Host\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = FlaxT5ForConditionalGeneration(config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {
+    "id": "T5F3BEA2f6xE"
+   },
+   "outputs": [],
+   "source": [
+    "input_ids = np.asarray(208 * [512 * [1]], dtype=np.int32)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def run_forward(input_ids, params):\n",
+    "    return model(input_ids, decoder_input_ids=input_ids).logits"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "jitted_forward = jax.jit(run_forward)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "logits = jitted_forward(input_ids, model.params)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "accelerator": "TPU",
+  "colab": {
+   "name": "Untitled1.ipynb",
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
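Determine_batch_size.ipynb above probes whether a dummy (208, 512) batch fits on the TPU by running a jitted forward pass of FlaxT5ForConditionalGeneration. A minimal sketch of how that probe could be wrapped in a search over candidate batch sizes follows; the candidate sizes and the assumption that an out-of-memory failure surfaces as a RuntimeError are illustrative, not something the notebook shows.

import jax
import numpy as np
from transformers import FlaxT5ForConditionalGeneration, T5Config

config = T5Config.from_pretrained("t5-small")
model = FlaxT5ForConditionalGeneration(config)

def run_forward(input_ids, params):
    # Same forward pass as in the notebook: encoder and decoder fed the same dummy ids.
    return model(input_ids, decoder_input_ids=input_ids, params=params).logits

jitted_forward = jax.jit(run_forward)

for batch_size in (64, 128, 208, 256):  # candidate per-host batch sizes (illustrative)
    input_ids = np.ones((batch_size, 512), dtype=np.int32)
    try:
        jitted_forward(input_ids, model.params).block_until_ready()
        print(f"batch_size={batch_size}: forward pass fits")
    except RuntimeError as err:  # assumption: TPU OOM surfaces as a RuntimeError here
        print(f"batch_size={batch_size}: failed ({err})")
        break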
Load_preprocessed_dataset.ipynb
ADDED
@@ -0,0 +1,147 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "cf148030-7287-4c9e-ae32-8d1e1c47be30",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import Dataset, DatasetDict"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "5161b4ba-e8cf-43e1-b67e-503c29aa4271",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "datasets = DatasetDict.load_from_disk(\"./grouped_dataset\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "15f9d047-ac35-43d7-ab55-9f9afe96dd07",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "DatasetDict({\n",
+       "    train: Dataset({\n",
+       "        features: ['input_ids'],\n",
+       "        num_rows: 86438919\n",
+       "    })\n",
+       "    validation: Dataset({\n",
+       "        features: ['input_ids'],\n",
+       "        num_rows: 4735324\n",
+       "    })\n",
+       "})"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "datasets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "d1d1218e-142e-441a-b20d-d300b13b172a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train = datasets['train']"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9eaddfb1-242f-4a25-8789-efe97b2a5712",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "8aabb26f-19ca-467a-b383-3a693be70cac",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "86438919\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(len(train))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f3176986-5b34-4ed6-a643-e342db9a2ce8",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "1205bbef-ba9d-4ddc-af2e-602d56b7dd64",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'input_ids': [256, 3, 20, 18452, 6690, 7757, 1286, 43, 10, 4942, 1286, 80, 12, 4782, 5442, 39, 5385, 33, 4, 5, 3, 2924, 117, 5669, 228, 21, 193, 9030, 511, 24, 11, 5, 665, 165, 4218, 7, 26, 264, 1528, 35, 105, 3, 19653, 12, 9661, 17156, 13955, 4, 132, 5, 611, 959, 961, 146, 6522, 7757, 1286, 89, 7500, 9716, 11, 5, 4868, 107, 13604, 12, 12836, 13368, 11, 611, 959, 4, 3, 69, 99, 12, 13132, 6690, 590, 5, 1803, 1867, 69, 7, 924, 10, 1762, 4, 3, 69, 538, 489, 14, 1149, 16, 3, 11384, 199, 116, 399, 4782, 291, 3, 6, 237, 13, 2629, 3, 8987, 291, 4, 69, 5, 3, 27, 72, 20, 325, 3, 2924, 133, 21, 105, 9030, 10, 1149, 242, 16, 144, 13572, 11, 9, 13401, 20, 7951, 8, 165, 4218, 4, 5, 1910]}\n"
+     ]
+    }
+   ],
+   "source": [
+    "it = iter(train)\n",
+    "\n",
+    "print(next(it))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f5d4e8de-419c-4c70-896e-fbd640bb7321",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
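Load_preprocessed_dataset.ipynb above only reloads a dataset that was written to ./grouped_dataset earlier. A minimal sketch of the matching save/reload round trip with the datasets library follows; the toy data and the demo path are assumptions for illustration.

from datasets import Dataset, DatasetDict

# Toy stand-in for the real tokenized and grouped dataset.
grouped = DatasetDict({
    "train": Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5, 6]]}),
    "validation": Dataset.from_dict({"input_ids": [[7, 8, 9]]}),
})

grouped.save_to_disk("./grouped_dataset_demo")  # demo path, not the path used in the notebook

reloaded = DatasetDict.load_from_disk("./grouped_dataset_demo")
print(reloaded)                 # DatasetDict with train/validation splits
print(len(reloaded["train"]))   # 2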
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:ed68bf4bf2ba245a90ae31d71a56c0f85d1ef7665f0748bd10826c688e5de825
size 891548548
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:1a8f60fdc3ad43a82bab7ec3dcaf1138179d7508798267becb15426d86b9385f
size 891650495
run_t5.sh
CHANGED
@@ -6,7 +6,6 @@ mkdir -p "${MODEL_DIR}/runs"

# T5 paper lr 0.01 with batch size 128
# We have a batch size of 8 devices * 32 = 256, so lr = 0.01/2
-# Warmup steps is set to 4% of the training steps

./run_t5_mlm_flax_custom_dataset.py \
    --output_dir="${MODEL_DIR}" \
@@ -23,12 +22,13 @@ mkdir -p "${MODEL_DIR}/runs"
    --dtype="bfloat16" \
    --overwrite_output_dir \
    --num_train_epochs="1" \
-    --logging_steps="
+    --logging_steps="20" \
    --save_steps="2000" \
-    --eval_steps="
+    --eval_steps="10000000" \
+    --resume_from_checkpoint="${MODEL_DIR}/ckpt-14000" \
+    --warmup_steps="3413" \
    --push_to_hub

-# --resume_from_checkpoint="${MODEL_DIR}/ckpt-1500" \


#git add pytorch_model.bin
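The comments in run_t5.sh motivate the learning rate and the new --warmup_steps value. A minimal sketch of that arithmetic follows; the warmup ratio and total step count are assumptions for illustration, and only the device count, per-device batch size, and base learning rate come from the script's comments.

num_devices = 8
per_device_batch_size = 32
effective_batch_size = num_devices * per_device_batch_size  # 256, as stated in the comment
base_lr_for_batch_128 = 0.01                                # T5 paper setting referenced above
learning_rate = base_lr_for_batch_128 / 2                   # 0.005, "so lr = 0.01/2"

warmup_ratio = 0.04        # assumption: the removed comment mentioned 4% of training steps
num_train_steps = 85_000   # assumption: illustrative total number of optimizer steps
warmup_steps = int(warmup_ratio * num_train_steps)          # 3400, close to --warmup_steps="3413"
print(effective_batch_size, learning_rate, warmup_steps)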
run_t5_mlm_flax_custom_dataset.py
CHANGED
@@ -31,7 +31,7 @@ from pathlib import Path
 from typing import Dict, List, Optional

 import numpy as np
-from datasets import load_dataset
+from datasets import load_dataset, DatasetDict
 from tqdm import tqdm

 import flax
@@ -552,15 +552,15 @@
     add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*54*.gz")
     add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*68*.gz")
     add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*57*.gz")
-    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*46*.gz")
-    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*35*.gz")
-    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*13*.gz")
-    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*41*.gz")
-    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*52*.gz")
-    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*63*.gz")
-    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*85*.gz")
-    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*81*.gz")
-    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*96*.gz")
+    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*46*.gz")
+    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*35*.gz")
+    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*13*.gz")
+    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*41*.gz")
+    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*52*.gz")
+    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*63*.gz")
+    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*85*.gz")
+    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*81*.gz")
+    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*96*.gz")
     add_jsonlines_dir(f"{data_dir}/nrc_uniq_cleaned_20210223", "*.gz")
     add_jsonlines_dir(f"{data_dir}/nu_uniq_cleaned_20210225", "*.gz")
     random.Random(SEED).shuffle(data_files)
@@ -580,7 +580,10 @@

     train, val = train_val_files()

-
+    load_grouped = False
+
+    if not load_grouped:
+        datasets = load_dataset('json', data_files={'train': train, 'validation': val})

     # data_files = {}
     # if data_args.train_file is not None:
@@ -623,31 +626,8 @@
     config = CONFIG_MAPPING[model_args.model_type]()
     logger.warning("You are instantiating a new config instance from scratch.")

-    # Preprocessing the datasets.
-    # First we tokenize all the texts.
-    if training_args.do_train:
-        column_names = datasets["train"].column_names
-    else:
-        column_names = datasets["validation"].column_names
-    text_column_name = "text" if "text" in column_names else column_names[0]
-
     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

-    # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
-    # Since we make sure that all sequences are of the same length, no attention_mask is needed.
-    def tokenize_function(examples):
-        return tokenizer(examples[text_column_name], return_attention_mask=False)
-
-    logger.info(f"Start tokenization, remove_column_names = {column_names}")
-
-    tokenized_datasets = datasets.map(
-        tokenize_function,
-        batched=True,
-        num_proc=data_args.preprocessing_num_workers,
-        remove_columns=column_names,
-        load_from_cache_file=not data_args.overwrite_cache,
-    )
-
     # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
     # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
     # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
@@ -656,40 +636,64 @@
         noise_density=data_args.mlm_probability,
         mean_noise_span_length=data_args.mean_noise_span_length,
     )
+    logger.info(f"Max seq length: {max_seq_length}, expanded_inputs_length: {expanded_inputs_length}, targets_length: {targets_length}")

-
-
-
-
-
-
-
-
-
-    #
-    #
-
-
-
-
-
-
-
-
+    # Preprocessing the datasets.
+    # First we tokenize all the texts.
+    if not load_grouped:
+        if training_args.do_train:
+            column_names = datasets["train"].column_names
+        else:
+            column_names = datasets["validation"].column_names
+        text_column_name = "text" if "text" in column_names else column_names[0]
+
+        # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
+        # Since we make sure that all sequences are of the same length, no attention_mask is needed.
+        def tokenize_function(examples):
+            return tokenizer(examples[text_column_name], return_attention_mask=False)
+
+        logger.info(f"Start tokenization, remove_column_names = {column_names}")
+        tokenized_datasets = datasets.map(
+            tokenize_function,
+            batched=True,
+            num_proc=data_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not data_args.overwrite_cache,
+        )

-
-
-
-
-
-
-
-
-
-
-
-
-
+        # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
+        def group_texts(examples):
+            # Concatenate all texts.
+            concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
+            total_length = len(concatenated_examples[list(examples.keys())[0]])
+            # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
+            # customize this part to your needs.
+            if total_length >= expanded_inputs_length:
+                total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
+            # Split by chunks of max_len.
+            result = {
+                k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
+                for k, t in concatenated_examples.items()
+            }
+            return result
+
+        # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
+        # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
+        # might be slower to preprocess.
+        #
+        # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+        # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+        logger.info(f"Start group_texts")
+        tokenized_datasets = tokenized_datasets.map(
+            group_texts,
+            batched=True,
+            batch_size=200,
+            num_proc=data_args.preprocessing_num_workers,
+            load_from_cache_file=not data_args.overwrite_cache,
+        )
+    else:
+        logger.info("Loading tokenized and grouped dataset")
+        tokenized_datasets = DatasetDict.load_from_disk("/home/yeb/grouped_datasets")

     # Enable tensorboard only on the master node
     has_tensorboard = is_tensorboard_available()
@@ -751,9 +755,14 @@

     # Create learning rate schedule

-
-
-
+    if training_args.warmup_steps:
+        warmup_steps = training_args.warmup_steps
+    elif training_args.warmup_ratio:
+        # See https://arxiv.org/pdf/2104.07705.pdf for rationale of choosing the peak at % of training steps
+        warmup_steps = int(training_args.warmup_ratio * num_train_steps)
+        logging.info(f"Warmup steps set to {100*training_args.warmup_ratio}% = {warmup_steps} of total train steps {num_train_steps}")
+    else:
+        raise Exception("Need either --warmup_steps or --warmup_ratio")

     warmup_fn = optax.linear_schedule(
         init_value=0.0, end_value=training_args.learning_rate, transition_steps=warmup_steps
@@ -863,7 +872,8 @@
     state = jax_utils.replicate(state)

     logger.info("***** Running training *****")
-
+    if not load_grouped:
+        logger.info(f" Num examples = {len(datasets['train'])}")
     logger.info(f" Num tokenized group examples {len(tokenized_datasets['train'])}")
     logger.info(f" Num Epochs = {num_epochs}")
     logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
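The group_texts function added above concatenates the tokenized examples and re-splits them into chunks of expanded_inputs_length, dropping the remainder. A minimal self-contained sketch of that behaviour follows; the chunk length of 8 and the toy batch are assumptions for illustration (the real script derives the length from max_seq_length and the noise settings).

expanded_inputs_length = 8  # toy value for illustration

def group_texts(examples):
    # Concatenate all token lists, then split into fixed-size chunks and drop the remainder.
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    if total_length >= expanded_inputs_length:
        total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
    return {
        k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
        for k, t in concatenated.items()
    }

batch = {"input_ids": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11], [12, 13, 14, 15, 16, 17, 18]]}
print(group_texts(batch))
# {'input_ids': [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]]}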
runs/Jul10_12-03-45_t1v-n-0e7426e8-w-0/events.out.tfevents.1625920526.t1v-n-0e7426e8-w-0.48005.3.v2
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ac6e43e7a7661df39374a0364897650b948e816446aa7cfb526ff2f0f51b9e1e
-size 40
runs/Jul10_12-39-58_t1v-n-0e7426e8-w-0/events.out.tfevents.1625922498.t1v-n-0e7426e8-w-0.52901.3.v2
CHANGED
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:62ea16c934f55451bfb14bd666567f0c8837fead4ad2e1f6a8adbd8d11fd25a6
+size 2359167
runs/{Jul10_07-37-20_t1v-n-0e7426e8-w-0/events.out.tfevents.1625902752.t1v-n-0e7426e8-w-0.18397.3.v2 → Jul11_09-15-07_t1v-n-0e7426e8-w-0/events.out.tfevents.1625995853.t1v-n-0e7426e8-w-0.145718.3.v2}
RENAMED
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:ecdd317adb51d2b44773888aaa52793f97b5af475a8f35560774d02bd6ae20a2
+size 300940