Update scripts. Add dataset load check notebook
- Load_token_group_dataset.ipynb +567 -0
- run_t5.sh +6 -5
- run_t5_mlm_flax_custom_dataset.py +18 -6
Load_token_group_dataset.ipynb
ADDED: new notebook, 567 lines. Cells and outputs:
In [71]:
data_files = []
data_dir = "."
def train_val_files():
    import glob
    import random
    SEED = 12345

    def add_jsonlines_dir(path, filespec):
        global data_files
        data_files += glob.glob(f"{path}/{filespec}")
        data_files = list(set(data_files))
        print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")

    # add_jsonlines_dir(f"{data_dir}/oscar_nl_cleaned")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*73*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*47*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*12*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*29*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*74*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*26*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*54*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*68*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*57*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*46*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*35*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*13*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*41*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*52*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*63*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*85*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*81*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*96*.gz")
    add_jsonlines_dir(f"{data_dir}/nrc_uniq_cleaned_20210223", "*.gz")
    add_jsonlines_dir(f"{data_dir}/nu_uniq_cleaned_20210225", "*.gz")
    random.Random(SEED).shuffle(data_files)

    total = len(data_files)
    print(total)
    perc = 0.05
    val_size = int(perc * total)
    train_size = total - val_size
    train = data_files[:train_size]
    val = data_files[train_size:]
    print(f"Got {len(train)} training files and {perc*100} % {len(val)} validation files")

    assert list(set(train) & set(val)) == [], "Train overlaps with test"

    return train, val

train, val = train_val_files()

stdout:
Number of files 20 after adding ./c4_cleaned glob *73*.gz
Number of files 39 after adding ./c4_cleaned glob *47*.gz
Number of files 60 after adding ./c4_cleaned glob *12*.gz
Number of files 79 after adding ./c4_cleaned glob *29*.gz
Number of files 97 after adding ./c4_cleaned glob *74*.gz
Number of files 116 after adding ./c4_cleaned glob *26*.gz
Number of files 135 after adding ./c4_cleaned glob *54*.gz
Number of files 154 after adding ./c4_cleaned glob *68*.gz
Number of files 172 after adding ./c4_cleaned glob *57*.gz
Number of files 189 after adding ./c4_cleaned glob *46*.gz
Number of files 206 after adding ./c4_cleaned glob *35*.gz
Number of files 226 after adding ./c4_cleaned glob *13*.gz
Number of files 242 after adding ./c4_cleaned glob *41*.gz
Number of files 259 after adding ./c4_cleaned glob *52*.gz
Number of files 276 after adding ./c4_cleaned glob *63*.gz
Number of files 292 after adding ./c4_cleaned glob *85*.gz
Number of files 309 after adding ./c4_cleaned glob *81*.gz
Number of files 326 after adding ./c4_cleaned glob *96*.gz
Number of files 526 after adding ./nrc_uniq_cleaned_20210223 glob *.gz
Number of files 726 after adding ./nu_uniq_cleaned_20210225 glob *.gz
726
Got 690 training files and 5.0 % 36 validation files
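A quick check of the split arithmetic reported above (726 files, 5% held out) confirms the printed counts:

total = 726                      # files after the last add_jsonlines_dir call
perc = 0.05
val_size = int(perc * total)     # 36
train_size = total - val_size    # 690
print(train_size, val_size)      # 690 36, matching the "Got 690 training files ... 36 validation files" line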
In [72]:
from datasets import load_dataset
datasets = load_dataset('json', data_files={'train': train, 'validation': val})

stderr:
Using custom data configuration default-ce92ec7dc3732df4

stdout:
Downloading and preparing dataset json/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/yeb/.cache/huggingface/datasets/json/default-ce92ec7dc3732df4/0.0.0/793d004298099bd3c4e61eb7878475bcf1dc212bf2e34437d85126758720d7f9...
0 tables [00:00, ? tables/s]
0 tables [00:00, ? tables/s]
Dataset json downloaded and prepared to /home/yeb/.cache/huggingface/datasets/json/default-ce92ec7dc3732df4/0.0.0/793d004298099bd3c4e61eb7878475bcf1dc212bf2e34437d85126758720d7f9. Subsequent calls will reuse this data.
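Converting 726 gzipped JSON-lines files to an Arrow cache up front is slow. Depending on the installed datasets version, the same files can also be read lazily with streaming=True; a sketch of that alternative (not what this notebook does):

from datasets import load_dataset

# assumption: a datasets version that supports streaming local json files
streamed = load_dataset('json', data_files={'train': train, 'validation': val}, streaming=True)
print(next(iter(streamed['train'])))   # yields raw records without building the Arrow cache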
In [73]:
print(f"Num examples = {len(datasets['train'])}")

stdout:
Num examples = 21153916
In [74]:
from transformers import (
    CONFIG_MAPPING,
    FLAX_MODEL_FOR_MASKED_LM_MAPPING,
    BatchEncoding,
    FlaxT5ForConditionalGeneration,
    T5ForConditionalGeneration,
    HfArgumentParser,
    PreTrainedTokenizerBase,
    T5Config,
    T5TokenizerFast,
    TrainingArguments,
    is_tensorboard_available,
    set_seed,
)

In [75]:
tokenizer = T5TokenizerFast.from_pretrained("./t5-base-dutch")

def tokenize_function(examples):
    return tokenizer(examples['text'], return_attention_mask=False)

column_names = datasets["train"].column_names
print(f"Start tokenization, remove_column_names = {column_names}")

tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=96,
    remove_columns=column_names,
    load_from_cache_file=True,
)

stderr:
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

stdout:
Start tokenization, remove_column_names = ['url', 'timestamp', 'text']
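A small sanity check on the tokenized dataset could look like the following sketch (hypothetical inspection code, not part of the notebook); each row now carries only variable-length input_ids, since the original columns were removed and return_attention_mask=False:

sample = tokenized_datasets["train"][0]
print(list(sample.keys()))                          # ['input_ids']
print(len(sample["input_ids"]))                     # length varies per document before grouping
print(tokenizer.decode(sample["input_ids"][:20]))   # first few tokens back as text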
In [76]:
len(tokenized_datasets["train"])

Out[76]:
21153916
In [77]:
def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
    """This function is a copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .

    Training parameters to avoid padding with random_spans_noise_mask.
    When training a model with random_spans_noise_mask, we would like to set the other
    training hyperparameters in a way that avoids padding.
    This function helps us compute these hyperparameters.
    We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
    and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
    This function tells us the required number of tokens in the raw example (for split_tokens())
    as well as the length of the encoded targets. Note that this function assumes
    the inputs and targets will have EOS appended and includes that in the reported length.

    Args:
        inputs_length: an integer - desired length of the tokenized inputs sequence
        noise_density: a float
        mean_noise_span_length: a float
    Returns:
        tokens_length: length of original text in tokens
        targets_length: an integer - length in tokens of encoded targets sequence
    """

    def _tokens_length_to_inputs_length_targets_length(tokens_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # inputs contain all nonnoise tokens, sentinels for all noise spans
        # and one EOS token.
        _input_length = num_nonnoise_tokens + num_noise_spans + 1
        _output_length = num_noise_tokens + num_noise_spans + 1
        return _input_length, _output_length

    tokens_length = inputs_length

    while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
        tokens_length += 1

    inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)

    # minor hack to get the targets length to be equal to inputs length
    # which is more likely to have been set to a nice round number.
    if noise_density == 0.5 and targets_length > inputs_length:
        tokens_length -= 1
        targets_length -= 1
    return tokens_length, targets_length

# T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
# To ensure that the input length is `max_seq_length`, we need to increase the maximum length
# according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
expanded_inputs_length, targets_length = compute_input_and_target_lengths(
    inputs_length=128,
    noise_density=0.15,
    mean_noise_span_length=3.0,
)

print(f"Expanded_inputs_length: {expanded_inputs_length}, targets_length: {targets_length}")
print(f"Start group_texts")

# Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= expanded_inputs_length:
        total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
        for k, t in concatenated_examples.items()
    }
    return result

# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
# might be slower to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
grouped_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=200,
    num_proc=96,
    load_from_cache_file=True,
)

stdout:
Expanded_inputs_length: 141, targets_length: 29
Start group_texts

stderr (abridged, worker SIGTERM noise from the num_proc=96 map workers):
*** SIGTERM received by PID 47670 (TID 47670) on cpu 70 from PID 33223; stack trace: ***
*** SIGTERM received by PID 47686 (TID 47686) on cpu 71 from PID 33223; stack trace: ***
*** SIGTERM received by PID 47673 (TID 47673) on cpu 16 from PID 33223; stack trace: ***
*** SIGTERM received by PID 47665 (TID 47665) on cpu 67 from PID 33223; stack trace: ***
E0710 11:59:41.025238 47673 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.
E0710 11:59:41.033730 47673 process_state.cc:771] RAW: Raising signal 15 with default behavior
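The printed 141/29 pair can be checked by hand with the helper's own formula: for 141 raw tokens at noise_density 0.15, round(141 * 0.15) = 21 noise tokens fall into round(21 / 3.0) = 7 spans, so the packed input is 120 non-noise tokens + 7 sentinels + 1 EOS = 128 and the target is 21 + 7 + 1 = 29. A toy call of group_texts (assuming the cell above has run, so group_texts and expanded_inputs_length = 141 are in scope) illustrates the chunking:

# two fake "documents" of 300 and 160 token ids: 460 tokens in total
toy_batch = {"input_ids": [list(range(300)), list(range(160))]}
chunks = group_texts(toy_batch)
print([len(c) for c in chunks["input_ids"]])   # [141, 141, 141]; 460 // 141 = 3 chunks, 37-token remainder dropped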
In [78]:
examples = len(grouped_datasets["train"])
examples

Out[78]:
86438919

In [79]:
it = iter(grouped_datasets["train"])

In [80]:
print(next(it))

stdout:
{'input_ids': [256, 3, 20, 18452, 6690, 7757, 1286, 43, 10, 4942, 1286, 80, 12, 4782, 5442, 39, 5385, 33, 4, 5, 3, 2924, 117, 5669, 228, 21, 193, 9030, 511, 24, 11, 5, 665, 165, 4218, 7, 26, 264, 1528, 35, 105, 3, 19653, 12, 9661, 17156, 13955, 4, 132, 5, 611, 959, 961, 146, 6522, 7757, 1286, 89, 7500, 9716, 11, 5, 4868, 107, 13604, 12, 12836, 13368, 11, 611, 959, 4, 3, 69, 99, 12, 13132, 6690, 590, 5, 1803, 1867, 69, 7, 924, 10, 1762, 4, 3, 69, 538, 489, 14, 1149, 16, 3, 11384, 199, 116, 399, 4782, 291, 3, 6, 237, 13, 2629, 3, 8987, 291, 4, 69, 5, 3, 27, 72, 20, 325, 3, 2924, 133, 21, 105, 9030, 10, 1149, 242, 16, 144, 13572, 11, 9, 13401, 20, 7951, 8, 165, 4218, 4, 5, 1910]}

In [81]:
tokens = next(it)['input_ids']

In [82]:
len(tokens)

Out[82]:
141

In [83]:
tokenizer.decode(tokens)

Out[83]:
"werden volgens getuigen vergezeld door een boomlange bodyguard. ook hing er een gordijntje om de tafel, zodat beyoncé in alle rust van de show kon genieten. volgens de bron verliet knowles pas om 03.30 uur's ochtends de hippe club.</s> utrecht - in de schouwburg van utrecht gaat vrijdagavond de musical 'joseph and the amazing technicolor dreamcoat' in première. voor het eerst in nederland. een voorloper van het geesteskind van andrew lloyd webber werd al in 1967 voor het eerst op een school in groot-brittannië uitgeprobeerd. twaalf jaar later werd het in"

Each grouped example is exactly expanded_inputs_length = 141 tokens long, and the decoded text shows where two source documents were concatenated (the </s> between them).
In [84]:
while (example := next(it, None)) is not None:
    if len(example['input_ids']) == 141:
        continue
    else:
        print(example)
        break

error (KeyboardInterrupt, the scan was interrupted by hand):
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
/tmp/ipykernel_33223/1050159500.py in <module>
----> 1 while (example := next(it, None)) is not None:
      2     if len(example['input_ids']) == 141:
      3         continue
      4     else:
      5         print(example)

~/datasets/src/datasets/arrow_dataset.py in __iter__(self)
-> 1266                 yield self._getitem(
~/datasets/src/datasets/arrow_dataset.py in _getitem(self, key, format_type, format_columns, output_all_columns, format_kwargs)
-> 1509         pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
~/datasets/src/datasets/formatting/formatting.py in query_table(table, key, indices)
--> 371         pa_subtable = _query_table(table, key)
~/datasets/src/datasets/formatting/formatting.py in _query_table(table, key)
---> 79         return table.fast_slice(key % table.num_rows, 1)
~/datasets/src/datasets/table.py in fast_slice(self, offset, length)
--> 129             i = _interpolation_search(self._offsets, offset)
~/datasets/src/datasets/table.py in _interpolation_search(arr, x)
---> 86         if arr[k] <= x < arr[k + 1]:

KeyboardInterrupt:
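The interrupted loop above walks the 86M grouped rows one by one in Python. The same length check could be expressed as a (still expensive, but parallelizable) filter over the grouped dataset; a sketch using the datasets Dataset.filter API:

# count grouped training examples whose length differs from expanded_inputs_length (141)
bad = grouped_datasets["train"].filter(
    lambda ex: len(ex["input_ids"]) != 141,
    num_proc=96,            # same worker count used for the map calls above
)
print(len(bad))             # expected to be 0 if group_texts dropped all remainders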
A trailing empty code cell and the notebook metadata (kernel "Python 3 (ipykernel)", Python 3.8.10, nbformat 4.5) close the file.
run_t5.sh
CHANGED
@@ -6,7 +6,7 @@ mkdir -p "${MODEL_DIR}/runs"
 
 # T5 paper lr 0.01 with batch size 128
 # We have a batch size of 8 devices * 32 = 256, so lr = 0.01/2
-# Warmup steps is set to
+# Warmup steps is set to 4% of the training steps
 
 ./run_t5_mlm_flax_custom_dataset.py \
     --output_dir="${MODEL_DIR}" \
@@ -19,20 +19,21 @@ mkdir -p "${MODEL_DIR}/runs"
     --max_seq_length="512" \
     --per_device_train_batch_size="32" \
     --per_device_eval_batch_size="32" \
-    --learning_rate="
+    --learning_rate="5e-3" \
     --dtype="bfloat16" \
     --overwrite_output_dir \
     --num_train_epochs="1" \
     --logging_steps="50" \
-    --save_steps="
+    --save_steps="2000" \
     --eval_steps="1000000" \
-    --resume_from_checkpoint="${MODEL_DIR}/ckpt-1500" \
     --push_to_hub
 
+# --resume_from_checkpoint="${MODEL_DIR}/ckpt-1500" \
+
+
 #git add pytorch_model.bin
 #git commit -m "Update pytorch model after training"
 #git push origin main
 
-# --learning_rate="5e-3" \
 # --gradient_accumulation_steps="2" \
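The batch-size comment at the top of the script works out to the learning rate now pinned on the command line; a quick check of the arithmetic:

devices = 8
per_device_train_batch_size = 32
train_batch_size = devices * per_device_train_batch_size   # 256, i.e. twice the 128 used in the T5 paper
learning_rate = 0.01 / 2                                    # 0.005 = 5e-3, matching --learning_rate="5e-3"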
run_t5_mlm_flax_custom_dataset.py
CHANGED
@@ -543,13 +543,24 @@ if __name__ == "__main__":
        print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")
 
    # add_jsonlines_dir(f"{data_dir}/oscar_nl_cleaned")
-    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*47*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*73*.gz")
+    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*47*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*12*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*29*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*74*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*26*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*54*.gz")
+    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*68*.gz")
+    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*57*.gz")
+    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*46*.gz")
+    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*35*.gz")
+    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*13*.gz")
+    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*41*.gz")
+    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*52*.gz")
+    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*63*.gz")
+    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*85*.gz")
+    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*81*.gz")
+    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*96*.gz")
    add_jsonlines_dir(f"{data_dir}/nrc_uniq_cleaned_20210223", "*.gz")
    add_jsonlines_dir(f"{data_dir}/nu_uniq_cleaned_20210225", "*.gz")
    random.Random(SEED).shuffle(data_files)
@@ -740,9 +751,9 @@ if __name__ == "__main__":
 
    # Create learning rate schedule
 
-    # See https://arxiv.org/pdf/2104.07705.pdf for rationale of choosing the peak at
-    warmup_steps = int(0.
-    logging.info(f"Warmup steps set to
+    # See https://arxiv.org/pdf/2104.07705.pdf for rationale of choosing the peak at 4% of training steps
+    warmup_steps = int(0.04 * num_train_steps)
+    logging.info(f"Warmup steps set to 4% = {warmup_steps} of total train steps {num_train_steps}")
 
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=training_args.learning_rate, transition_steps=warmup_steps
@@ -796,6 +807,8 @@ if __name__ == "__main__":
    else:
        resume_step = 0
 
+    logger.info("")
+
    # Define gradient update step fn
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
@@ -849,10 +862,9 @@ if __name__ == "__main__":
    # Replicate the train state on each device
    state = jax_utils.replicate(state)
 
-
-
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(datasets['train'])}")
+    logger.info(f"  Num tokenized group examples {len(tokenized_datasets['train'])}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed and grad_accum) = {train_batch_size}")
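For context on the warmup hunk: the script builds a linear warmup with optax.linear_schedule and, as in the upstream run_t5_mlm_flax example it is derived from, joins it with a decay phase further down (not shown in this diff). A minimal sketch of that pattern, assuming optax is available and using a placeholder num_train_steps:

import optax

num_train_steps = 100_000                     # assumed here purely for illustration
learning_rate = 5e-3
warmup_steps = int(0.04 * num_train_steps)    # 4% of training, as in the hunk above

warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
)
decay_fn = optax.linear_schedule(
    init_value=learning_rate, end_value=0.0,
    transition_steps=num_train_steps - warmup_steps,
)
# one schedule: warm up for warmup_steps, then decay linearly for the rest of training
linear_decay_lr_schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
)
print(linear_decay_lr_schedule_fn(warmup_steps))   # peak learning rate at the boundary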