dat
commited on
Commit
•
0446688
1
Parent(s):
fc71740
add all
Browse files- .ipynb_checkpoints/Load data & train tokenizer-checkpoint.ipynb +3 -0
- Load data & train tokenizer.ipynb +488 -0
- config.json +31 -0
- run.sh +30 -0
- run_mlm_flax.py +787 -0
- tokenizer.json +0 -0
.ipynb_checkpoints/Load data & train tokenizer-checkpoint.ipynb
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1fc562ad584b6a6a4c02dcbf860f644f153519efb0996ddbe7a8c6861fb254b7
|
3 |
+
size 11997
|
Load data & train tokenizer.ipynb
ADDED
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "723b5d4d",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import jax\n",
|
11 |
+
"import optax\n",
|
12 |
+
"import flax\n",
|
13 |
+
"import jax.numpy as jnp\n",
|
14 |
+
"import datasets\n",
|
15 |
+
"from flax.training import train_state\n",
|
16 |
+
"from flax.training.common_utils import get_metrics, onehot, shard\n",
|
17 |
+
"from datasets import load_dataset\n",
|
18 |
+
"from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer\n",
|
19 |
+
"from pathlib import Path\n",
|
20 |
+
"import numpy as np\n",
|
21 |
+
"import transformers\n",
|
22 |
+
"from tqdm.notebook import tqdm\n",
|
23 |
+
"from pathlib import Path\n",
|
24 |
+
"from transformers import AutoConfig\n",
|
25 |
+
"from typing import Dict, List, Optional, Tuple\n",
|
26 |
+
"from transformers import AutoTokenizer\n",
|
27 |
+
"from transformers import PreTrainedTokenizerBase\n",
|
28 |
+
"from transformers import FlaxAutoModelForMaskedLM\n",
|
29 |
+
"from dataclasses import dataclass, field\n",
|
30 |
+
"import time\n",
|
31 |
+
"import glob\n",
|
32 |
+
"import random"
|
33 |
+
]
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"cell_type": "code",
|
37 |
+
"execution_count": 2,
|
38 |
+
"id": "f4a5edee",
|
39 |
+
"metadata": {},
|
40 |
+
"outputs": [],
|
41 |
+
"source": [
|
42 |
+
"from transformers import AutoConfig\n"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": 3,
|
48 |
+
"id": "48daf2ec",
|
49 |
+
"metadata": {},
|
50 |
+
"outputs": [],
|
51 |
+
"source": [
|
52 |
+
"\n",
|
53 |
+
"\n",
|
54 |
+
"config = AutoConfig.from_pretrained(\"google/bigbird-roberta-base\")"
|
55 |
+
]
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"cell_type": "code",
|
59 |
+
"execution_count": 4,
|
60 |
+
"id": "fc816572",
|
61 |
+
"metadata": {},
|
62 |
+
"outputs": [],
|
63 |
+
"source": [
|
64 |
+
"config.save_pretrained(\"./\")"
|
65 |
+
]
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"cell_type": "code",
|
69 |
+
"execution_count": null,
|
70 |
+
"id": "39b9fc3d",
|
71 |
+
"metadata": {},
|
72 |
+
"outputs": [],
|
73 |
+
"source": []
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": null,
|
78 |
+
"id": "ba855add",
|
79 |
+
"metadata": {},
|
80 |
+
"outputs": [],
|
81 |
+
"source": []
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"cell_type": "code",
|
85 |
+
"execution_count": 11,
|
86 |
+
"id": "59076aa7",
|
87 |
+
"metadata": {},
|
88 |
+
"outputs": [
|
89 |
+
{
|
90 |
+
"name": "stdout",
|
91 |
+
"output_type": "stream",
|
92 |
+
"text": [
|
93 |
+
"Number of files 20 after adding /data/c4_cleaned\n"
|
94 |
+
]
|
95 |
+
}
|
96 |
+
],
|
97 |
+
"source": [
|
98 |
+
"#59G c4_cleaned compressed\n",
|
99 |
+
"#937M nrc_uniq_cleaned_20210223 compressed\n",
|
100 |
+
"#410M nu_uniq_cleaned_20210225 compressed\n",
|
101 |
+
"#9.9G oscar_nl_cleaned compressed\n",
|
102 |
+
"\n",
|
103 |
+
"\n",
|
104 |
+
"\n",
|
105 |
+
"data_files = []\n",
|
106 |
+
"SEED=42\n",
|
107 |
+
"def add_jsonlines_dir(path):\n",
|
108 |
+
" global data_files\n",
|
109 |
+
" #data_files += glob.glob(f\"{path}/*47*.gz\")\n",
|
110 |
+
" #data_files += glob.glob(f\"{path}/*32*.gz\")\n",
|
111 |
+
" #data_files += glob.glob(f\"{path}/*59*.gz\")\n",
|
112 |
+
" data_files += glob.glob(f\"{path}/*11*.gz\")\n",
|
113 |
+
" print(f\"Number of files {len(data_files)} after adding {path}\")\n",
|
114 |
+
" \n",
|
115 |
+
"add_jsonlines_dir(\"/data/c4_cleaned\")\n",
|
116 |
+
"#add_jsonlines_dir(\"/data/nrc_uniq_cleaned_20210223\")\n",
|
117 |
+
"#add_jsonlines_dir(\"/data/nu_uniq_cleaned_20210225\")\n",
|
118 |
+
"#add_jsonlines_dir(\"/data/oscar_nl_cleaned\") This one gives an error like field url not in \n",
|
119 |
+
"\n"
|
120 |
+
]
|
121 |
+
},
|
122 |
+
{
|
123 |
+
"cell_type": "code",
|
124 |
+
"execution_count": 40,
|
125 |
+
"id": "fc9519d2",
|
126 |
+
"metadata": {},
|
127 |
+
"outputs": [
|
128 |
+
{
|
129 |
+
"name": "stdout",
|
130 |
+
"output_type": "stream",
|
131 |
+
"text": [
|
132 |
+
"Number of files 209 after adding /data/oscar_nl_cleaned\n",
|
133 |
+
"95%: 199\n",
|
134 |
+
"Got 199 training files and 10 validation files\n"
|
135 |
+
]
|
136 |
+
},
|
137 |
+
{
|
138 |
+
"name": "stderr",
|
139 |
+
"output_type": "stream",
|
140 |
+
"text": [
|
141 |
+
"Using custom data configuration default-00e4c1e272015fdb\n"
|
142 |
+
]
|
143 |
+
},
|
144 |
+
{
|
145 |
+
"name": "stdout",
|
146 |
+
"output_type": "stream",
|
147 |
+
"text": [
|
148 |
+
"Downloading and preparing dataset json/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/dat/.cache/huggingface/datasets/json/default-00e4c1e272015fdb/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723...\n"
|
149 |
+
]
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"data": {
|
153 |
+
"application/vnd.jupyter.widget-view+json": {
|
154 |
+
"model_id": "7fc9159a741a4853abb8fa1abcb8bd4c",
|
155 |
+
"version_major": 2,
|
156 |
+
"version_minor": 0
|
157 |
+
},
|
158 |
+
"text/plain": [
|
159 |
+
"0 tables [00:00, ? tables/s]"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
"metadata": {},
|
163 |
+
"output_type": "display_data"
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"data": {
|
167 |
+
"application/vnd.jupyter.widget-view+json": {
|
168 |
+
"model_id": "db9fc4eb87094fa9aef909f8e8d41124",
|
169 |
+
"version_major": 2,
|
170 |
+
"version_minor": 0
|
171 |
+
},
|
172 |
+
"text/plain": [
|
173 |
+
"0 tables [00:00, ? tables/s]"
|
174 |
+
]
|
175 |
+
},
|
176 |
+
"metadata": {},
|
177 |
+
"output_type": "display_data"
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"name": "stdout",
|
181 |
+
"output_type": "stream",
|
182 |
+
"text": [
|
183 |
+
"Dataset json downloaded and prepared to /home/dat/.cache/huggingface/datasets/json/default-00e4c1e272015fdb/0.0.0/f92a4de297ac644ad9781979b79064b0e222b3af766f8ea3bee32390dca23723. Subsequent calls will reuse this data.\n"
|
184 |
+
]
|
185 |
+
}
|
186 |
+
],
|
187 |
+
"source": [
|
188 |
+
"#59G c4_cleaned compressed\n",
|
189 |
+
"#937M nrc_uniq_cleaned_20210223 compressed\n",
|
190 |
+
"#410M nu_uniq_cleaned_20210225 compressed\n",
|
191 |
+
"#9.9G oscar_nl_cleaned compressed\n",
|
192 |
+
"\n",
|
193 |
+
"\n",
|
194 |
+
"\n",
|
195 |
+
"data_files = []\n",
|
196 |
+
"SEED=42\n",
|
197 |
+
"def add_jsonlines_dir(path,filespec):\n",
|
198 |
+
" global data_files\n",
|
199 |
+
" data_files += glob.glob(f\"{path}/{filespec}\")\n",
|
200 |
+
" print(f\"Number of files {len(data_files)} after adding {path}\")\n",
|
201 |
+
" \n",
|
202 |
+
"#add_jsonlines_dir(\"/home/dat/subset_c4_cleannl\",\"*.gz\") \n",
|
203 |
+
"add_jsonlines_dir(\"/data/oscar_nl_cleaned\",\"*.gz\")\n",
|
204 |
+
"#add_jsonlines_dir(\"/data/nrc_cleaned_idtextfmt\",\"*.gz\")\n",
|
205 |
+
"#add_jsonlines_dir(\"/data/nu_cleaned_idtextfmt\",\"*.gz\")\n",
|
206 |
+
"random.Random(SEED).shuffle(data_files)\n",
|
207 |
+
"total = len(data_files)\n",
|
208 |
+
"val_size = int(0.05 * total)\n",
|
209 |
+
"train_size = total - val_size\n",
|
210 |
+
"print(f\"95%: {train_size}\")\n",
|
211 |
+
"train = data_files[:train_size]\n",
|
212 |
+
"val = data_files[train_size:]\n",
|
213 |
+
"print(f\"Got {len(train)} training files and {len(val)} validation files\")\n",
|
214 |
+
"assert list(set(train) & set(val)) == [], \"Train overlaps with test\"\n",
|
215 |
+
"datasets = load_dataset('json', data_files={'train': train, 'validation': val})\n",
|
216 |
+
"\n",
|
217 |
+
"\n",
|
218 |
+
"assert list(set(train) & set(val)) == [], 'train overlaps with test'\n"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"cell_type": "code",
|
223 |
+
"execution_count": 41,
|
224 |
+
"id": "865a9642",
|
225 |
+
"metadata": {},
|
226 |
+
"outputs": [],
|
227 |
+
"source": [
|
228 |
+
"dataset_iterator = iter(datasets['train'])"
|
229 |
+
]
|
230 |
+
},
|
231 |
+
{
|
232 |
+
"cell_type": "code",
|
233 |
+
"execution_count": 78,
|
234 |
+
"id": "523b0fc2",
|
235 |
+
"metadata": {},
|
236 |
+
"outputs": [
|
237 |
+
{
|
238 |
+
"name": "stdout",
|
239 |
+
"output_type": "stream",
|
240 |
+
"text": [
|
241 |
+
"Zo stel ik het me voor. Tegen iedere conventie in. Och wat heeft de burgerij gemopperd en schande gesproken. Dat was in die dagen. Nu nog steeds, maar anders. Daarover later meer. En wat zullen ze van u gehouden hebben in de kleine kring van liefhebbers.\n",
|
242 |
+
"Jaren geleden, toen ik nog op de academie zat bestudeerde ik uw werk. Vooral de paar overgebleven foto’s van uw Merzbau in Hannover troffen mij. Zo vrij en swingend en onconventioneel.\n",
|
243 |
+
"Ze werden opgeslagen in een afgelegen kamer in mijn geheugen, want eigentijdse choreografen en filmmakers en schilders uit de vroege renaissance vroegen om voorrang.\n",
|
244 |
+
"Toen u het huis van uw ouders in Hannover betrok transformeerde u acht kamers tot een betoverende sculptuur. Merzbau! Kathedrale des erotischen Elend.\n",
|
245 |
+
"In abstracte vlakken en vormen kruipen de volumes chaotisch omhoog langs de muren. Meestal wit. Er vormen zich ruimtes en grotachtige structuren. Hier en daar een typografisch detail of een herkenbaar object, dat uit zijn context geslingerd, vooral vragen oproept. Met hier en daar een antwoord of een vermoeden daarvan.\n",
|
246 |
+
"Soms verborg u zich in het kleine orgelkamertje bovenin als er gasten kwamen, om de reactie op hun gezichten te lezen als ze uw gedichten of het karnavals-achtige nummer Du lieber Augustin door de fantastische ruimte hoorden schallen, een lied vol humor en boerse middeleeuwse wreedheid, maar ook melancholie.\n",
|
247 |
+
"Banale liedjes laten horen in een ruimte die verschillende betekenissen kan hebben. Ik herken dat zo. Wij deden dat ook in het theater.\n",
|
248 |
+
"Ik vraag nu toch uw hand, zo’n beetje dwars door de tijd, om een paar pirouettes te draaien of misschien beter een twist.\n",
|
249 |
+
"Het gewicht van de tapijten of het zeil waaronder ik zowat bezwijk, de inspanning om hoog in de opstelling een klosje op te hangen… Op een gegeven moment raak ik in een staat waarin ik niet meer nadenk. Dan doe ik de ingreep die een beeld uiteindelijk af maakt. Grappig niet?\n",
|
250 |
+
"Ik vermoed dat u dat ook heeft, dat zware fysieke werken aan Merzbau; dat dat fijn is, dat het zo echt is daardoor en dat je uiteindelijk in trance raakt.\n",
|
251 |
+
"Daar leefde u van werken in opdracht; portretten en landschappen. Beeldschoon werk, maar u deed niet anders dan erop mopperen.\n",
|
252 |
+
"Ondertussen begon u een nieuwe Merzbau in een schuur op het platteland. U groef er een verdieping onder en begon daar te merzen. Weer die zware fysieke arbeid. Dat beschouwde u als uw echte werk. Daar legde u ‘connecties tussen alles in uw wereld’, al uw werk ‘een levenslange ervaring’.\n",
|
253 |
+
"Maar uw landschappen hoorden daar niet bij. Dat is nu vreemd, jammer zelfs. Tenminste, gezien vanuit mijn perspectief, vanuit het heden. Ze komen immers uit dezelfde bron. Is het omdat ze niet abstract zijn?\n",
|
254 |
+
"Per Kirkeby is een beroemd Deens schilder en beeldhouwer, graficus en dichter. Nu tachtig jaar oud. U zou hem weten te waarderen. Ook niet binnen een -isme te vangen. Hij heeft heel mooi over zuivere en onzuivere kunst gesproken. Dit klinkt een beetje eng maar ging over zuiver in de zin van kaal en zonder betekenis en in het onzuivere zaten alle associaties en verwijzingen.\n",
|
255 |
+
"In míjn werk houd ik van de associaties en verwijzingen. Maar we leven nu in een andere tijd. Pure abstractie wordt zeker nog gevierd door sommige kunstenaars, en zeker niet de minsten, maar de revolutie die het in uw tijd ontketende is uitgewoed.\n",
|
256 |
+
"Ik houd ervan dat in mijn werk niks helemaal lijkt te kloppen, maar er is wel samenhang. De objecten zijn volgens een innerlijke logica gekozen. Maar het mag geen surealisme worden. Daar houd ik niet van. Het is een smalle marge waarin ze mogen bestaan.\n",
|
257 |
+
"Het gaat vreemd genoeg volgens schilderkunstige principes, al komt er geen verf aan te pas. Ik bouw mijn opstellingen laag voor laag op. Vanuit de achtergrond. Ik doe weg, of bedek wat te makkelijk te duiden is en daarmee het beeld plat slaat, of wat ik te mooi of esthetisch vind. Soms draait het zich om, behoud ik juist wat mooi of betekenisvol is. Ik zet voortdurend voetangels en klemmen voor mijzelf. En ik geloof dat dat de kwaliteit van het werk uitmaakt.\n",
|
258 |
+
"Ik vraag me af in hoeverre dit een wet is die voor alle kunst opgaat. Ik geloof het wel. Al gebeurt het soms alleen in het denkproces dat vooraf gaat aan de uitvoering van het werk.\n",
|
259 |
+
"Ik ken het in ieder geval heel goed uit mijn theaterwerk. Dat schaven aan een productie tot alle puzzelstukken op hun plaats vallen.\n",
|
260 |
+
"Ik kan mij voorstellen dat dat zelfs bij Mondriaan gebeurde. Zijn Victory Boogy Woogy heeft zo iets magisch ongrijpbaars. En toch staan alle vlakken gewoon op hun plek. Daar is zoveel jaar werk voor nodig geweest!\n",
|
261 |
+
"In zijn vroege werken, ook landschappen en bomen, proef je wat er allemaal in zit. In die man bedoel ik en in die doeken.\n",
|
262 |
+
"Ik wil maar zeggen, die landschappen van u zijn denk ik toch met dezelfde mentaliteit gemaakt als uw dichtwerk of Merzbau. Ze zijn in ieder geval door u gemaakt. Met uw hand, uw geest, uw afwegingen tijdens het schilderen. Dit wel, dit niet.\n",
|
263 |
+
"Maar niet mystiek of transcendent? Ik lees in andere bronnen over Dada’s grondslag; Boeddhisme, Taoisme, vroegchristelijke mystici, en over filosofen als Bergson, Nietzsche en Descartes. Nogal tegenstrijdig allemaal.\n",
|
264 |
+
"En dat DaDa niets is, dat wil zeggen alles, of het niet-iets, of een vogel op vier poten, of een levensverzekering of een ladder zonder sporten….\n",
|
265 |
+
"Ik heb een leven lang studie en kijken en nog eens kijken voor me, om dit alles te doorvorsen. Maar begrijpen doe ik het al. Op m’n intuïtie.\n"
|
266 |
+
]
|
267 |
+
}
|
268 |
+
],
|
269 |
+
"source": [
|
270 |
+
"print(next(dataset_iterator)['text'])"
|
271 |
+
]
|
272 |
+
},
|
273 |
+
{
|
274 |
+
"cell_type": "code",
|
275 |
+
"execution_count": 31,
|
276 |
+
"id": "b5839c79",
|
277 |
+
"metadata": {},
|
278 |
+
"outputs": [
|
279 |
+
{
|
280 |
+
"ename": "IndentationError",
|
281 |
+
"evalue": "unexpected indent (1021262509.py, line 15)",
|
282 |
+
"output_type": "error",
|
283 |
+
"traceback": [
|
284 |
+
"\u001b[0;36m File \u001b[0;32m\"/tmp/ipykernel_309684/1021262509.py\"\u001b[0;36m, line \u001b[0;32m15\u001b[0m\n\u001b[0;31m train, val = train_val_files()\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mIndentationError\u001b[0m\u001b[0;31m:\u001b[0m unexpected indent\n"
|
285 |
+
]
|
286 |
+
}
|
287 |
+
],
|
288 |
+
"source": [
|
289 |
+
"\n",
|
290 |
+
" add_jsonlines_dir(\"/home/dat/subset_c4_cleannl\") \n",
|
291 |
+
" add_jsonlines_dir(\"/data/oscar_nl_cleaned\")\n",
|
292 |
+
" add_jsonlines_dir(\"/data/nrc_cleaned_idtextfmt\")\n",
|
293 |
+
" add_jsonlines_dir(\"/data/nu_cleaned_idtextfmt\")\n",
|
294 |
+
" random.Random(SEED).shuffle(data_files)\n",
|
295 |
+
" total = len(data_files)\n",
|
296 |
+
" val_size = int(0.05 * total)\n",
|
297 |
+
" train_size = total - val_size\n",
|
298 |
+
" print(f\"95%: {train_size}\")\n",
|
299 |
+
" train = data_files\n",
|
300 |
+
" val = data_files\n",
|
301 |
+
" print(f\"Got {len(train)} training files and {len(val)} validation files\")\n",
|
302 |
+
" assert list(set(train) & set(val)) == [], \"Train overlaps with test\"\n",
|
303 |
+
" return train, val\n",
|
304 |
+
" train, val = train_val_files()\n",
|
305 |
+
" datasets = load_dataset('json', data_files={'train': train, 'validation': val})"
|
306 |
+
]
|
307 |
+
},
|
308 |
+
{
|
309 |
+
"cell_type": "code",
|
310 |
+
"execution_count": 4,
|
311 |
+
"id": "6685589f",
|
312 |
+
"metadata": {},
|
313 |
+
"outputs": [
|
314 |
+
{
|
315 |
+
"name": "stdout",
|
316 |
+
"output_type": "stream",
|
317 |
+
"text": [
|
318 |
+
"\n",
|
319 |
+
"\n",
|
320 |
+
"\n"
|
321 |
+
]
|
322 |
+
}
|
323 |
+
],
|
324 |
+
"source": [
|
325 |
+
"from tokenizers import ByteLevelBPETokenizer\n",
|
326 |
+
"tokenizer = ByteLevelBPETokenizer()\n",
|
327 |
+
"\n",
|
328 |
+
"def batch_iterator(batch_size=1000):\n",
|
329 |
+
" for i in range(0, len(datasets), batch_size):\n",
|
330 |
+
" yield datasets[\"train\"][i: i + batch_size][\"text\"]\n",
|
331 |
+
"\n",
|
332 |
+
"tokenizer.train_from_iterator(batch_iterator(), vocab_size=50358, min_frequency=2, special_tokens=[\n",
|
333 |
+
" \"<s>\",\n",
|
334 |
+
" \"<pad>\",\n",
|
335 |
+
" \"</s>\",\n",
|
336 |
+
" \"<unk>\",\n",
|
337 |
+
" \"<mask>\",\n",
|
338 |
+
"])"
|
339 |
+
]
|
340 |
+
},
|
341 |
+
{
|
342 |
+
"cell_type": "code",
|
343 |
+
"execution_count": 5,
|
344 |
+
"id": "5fed49b4",
|
345 |
+
"metadata": {},
|
346 |
+
"outputs": [
|
347 |
+
{
|
348 |
+
"data": {
|
349 |
+
"text/plain": [
|
350 |
+
"39503"
|
351 |
+
]
|
352 |
+
},
|
353 |
+
"execution_count": 5,
|
354 |
+
"metadata": {},
|
355 |
+
"output_type": "execute_result"
|
356 |
+
}
|
357 |
+
],
|
358 |
+
"source": [
|
359 |
+
"tokenizer.get_vocab_size()"
|
360 |
+
]
|
361 |
+
},
|
362 |
+
{
|
363 |
+
"cell_type": "code",
|
364 |
+
"execution_count": 6,
|
365 |
+
"id": "69401680",
|
366 |
+
"metadata": {},
|
367 |
+
"outputs": [
|
368 |
+
{
|
369 |
+
"name": "stdout",
|
370 |
+
"output_type": "stream",
|
371 |
+
"text": [
|
372 |
+
"/home/dat/pino-roberta-base\n"
|
373 |
+
]
|
374 |
+
}
|
375 |
+
],
|
376 |
+
"source": [
|
377 |
+
"cd ~/pino-roberta-base"
|
378 |
+
]
|
379 |
+
},
|
380 |
+
{
|
381 |
+
"cell_type": "code",
|
382 |
+
"execution_count": 7,
|
383 |
+
"id": "7a98d754",
|
384 |
+
"metadata": {},
|
385 |
+
"outputs": [],
|
386 |
+
"source": [
|
387 |
+
"tokenizer.save(\"tokenizer.json\")"
|
388 |
+
]
|
389 |
+
},
|
390 |
+
{
|
391 |
+
"cell_type": "code",
|
392 |
+
"execution_count": null,
|
393 |
+
"id": "e686b9c8",
|
394 |
+
"metadata": {},
|
395 |
+
"outputs": [
|
396 |
+
{
|
397 |
+
"name": "stderr",
|
398 |
+
"output_type": "stream",
|
399 |
+
"text": [
|
400 |
+
"Using custom data configuration nl-lang=nl\n"
|
401 |
+
]
|
402 |
+
},
|
403 |
+
{
|
404 |
+
"name": "stdout",
|
405 |
+
"output_type": "stream",
|
406 |
+
"text": [
|
407 |
+
"Downloading and preparing dataset cc100/nl (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/dat/.cache/huggingface/datasets/cc100/nl-lang=nl/0.0.0/b583dd47b0dd43a3c3773075abd993be12d0eee93dbd2cfe15a0e4e94d481e80...\n"
|
408 |
+
]
|
409 |
+
},
|
410 |
+
{
|
411 |
+
"data": {
|
412 |
+
"application/vnd.jupyter.widget-view+json": {
|
413 |
+
"model_id": "8bb6155775084c42841d5a786a3f014c",
|
414 |
+
"version_major": 2,
|
415 |
+
"version_minor": 0
|
416 |
+
},
|
417 |
+
"text/plain": [
|
418 |
+
"Downloading: 0%| | 0.00/8.42G [00:00<?, ?B/s]"
|
419 |
+
]
|
420 |
+
},
|
421 |
+
"metadata": {},
|
422 |
+
"output_type": "display_data"
|
423 |
+
}
|
424 |
+
],
|
425 |
+
"source": [
|
426 |
+
"dataset1 = load_dataset(\"mc4\", \"nl\", streaming=True)\n",
|
427 |
+
"dataset2 = load_dataset(\"oscar\", \"unshuffled_deduplicated_nl\",streaming=True)\n",
|
428 |
+
"dataset3 = load_dataset(\"cc100\", lang=\"nl\")\n",
|
429 |
+
"\n"
|
430 |
+
]
|
431 |
+
},
|
432 |
+
{
|
433 |
+
"cell_type": "code",
|
434 |
+
"execution_count": 14,
|
435 |
+
"id": "1e1498d1",
|
436 |
+
"metadata": {},
|
437 |
+
"outputs": [
|
438 |
+
{
|
439 |
+
"name": "stderr",
|
440 |
+
"output_type": "stream",
|
441 |
+
"text": [
|
442 |
+
"INFO:absl:Starting the local TPU driver.\n",
|
443 |
+
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
|
444 |
+
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host TPU\n",
|
445 |
+
"Some weights of FlaxBigBirdModel were not initialized from the model checkpoint at flax-community/pino-roberta-base and are newly initialized: {('pooler', 'kernel'), ('pooler', 'bias')}\n",
|
446 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
447 |
+
]
|
448 |
+
}
|
449 |
+
],
|
450 |
+
"source": [
|
451 |
+
"from transformers import AutoTokenizer, RobertaModel\n",
|
452 |
+
"from transformers import BigBirdForSequenceClassification,FlaxBigBirdModel,FlaxBigBirdForMaskedLM\n",
|
453 |
+
"\n",
|
454 |
+
"model = FlaxBigBirdModel.from_pretrained(\"flax-community/pino-roberta-base\")\n",
|
455 |
+
"model.save_pretrained('exported_pytorch_model')"
|
456 |
+
]
|
457 |
+
},
|
458 |
+
{
|
459 |
+
"cell_type": "code",
|
460 |
+
"execution_count": null,
|
461 |
+
"id": "82f2a9b7",
|
462 |
+
"metadata": {},
|
463 |
+
"outputs": [],
|
464 |
+
"source": []
|
465 |
+
}
|
466 |
+
],
|
467 |
+
"metadata": {
|
468 |
+
"kernelspec": {
|
469 |
+
"display_name": "Python 3 (ipykernel)",
|
470 |
+
"language": "python",
|
471 |
+
"name": "python3"
|
472 |
+
},
|
473 |
+
"language_info": {
|
474 |
+
"codemirror_mode": {
|
475 |
+
"name": "ipython",
|
476 |
+
"version": 3
|
477 |
+
},
|
478 |
+
"file_extension": ".py",
|
479 |
+
"mimetype": "text/x-python",
|
480 |
+
"name": "python",
|
481 |
+
"nbconvert_exporter": "python",
|
482 |
+
"pygments_lexer": "ipython3",
|
483 |
+
"version": "3.8.10"
|
484 |
+
}
|
485 |
+
},
|
486 |
+
"nbformat": 4,
|
487 |
+
"nbformat_minor": 5
|
488 |
+
}
|
config.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"BigBirdForPreTraining"
|
4 |
+
],
|
5 |
+
"attention_probs_dropout_prob": 0.1,
|
6 |
+
"attention_type": "block_sparse",
|
7 |
+
"block_size": 128,
|
8 |
+
"bos_token_id": 1,
|
9 |
+
"eos_token_id": 2,
|
10 |
+
"gradient_checkpointing": false,
|
11 |
+
"hidden_act": "gelu_new",
|
12 |
+
"hidden_dropout_prob": 0.1,
|
13 |
+
"hidden_size": 768,
|
14 |
+
"initializer_range": 0.02,
|
15 |
+
"intermediate_size": 3072,
|
16 |
+
"layer_norm_eps": 1e-12,
|
17 |
+
"max_position_embeddings": 4096,
|
18 |
+
"model_type": "big_bird",
|
19 |
+
"num_attention_heads": 12,
|
20 |
+
"num_hidden_layers": 12,
|
21 |
+
"num_random_blocks": 3,
|
22 |
+
"pad_token_id": 0,
|
23 |
+
"position_embedding_type": "absolute",
|
24 |
+
"rescale_embeddings": false,
|
25 |
+
"sep_token_id": 66,
|
26 |
+
"transformers_version": "4.9.0.dev0",
|
27 |
+
"type_vocab_size": 2,
|
28 |
+
"use_bias": true,
|
29 |
+
"use_cache": true,
|
30 |
+
"vocab_size": 50358
|
31 |
+
}
|
run.sh
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
export TOKENIZERS_PARALLELISM=0
|
4 |
+
|
5 |
+
python ./run_mlm_flax.py \
|
6 |
+
--push_to_hub \
|
7 |
+
--output_dir="./" \
|
8 |
+
--model_type="big_bird" \
|
9 |
+
--config_name="./" \
|
10 |
+
--tokenizer_name="./" \
|
11 |
+
--max_seq_length="4096" \
|
12 |
+
--weight_decay="0.0095" \
|
13 |
+
--warmup_steps="5000" \
|
14 |
+
--overwrite_output_dir \
|
15 |
+
--adam_beta1="0.9" \
|
16 |
+
--adam_beta2="0.98" \
|
17 |
+
--logging_steps="500" \
|
18 |
+
--eval_steps="92768" \
|
19 |
+
--num_train_epochs="5" \
|
20 |
+
--preprocessing_num_workers="64" \
|
21 |
+
--save_steps="20000" \
|
22 |
+
--adafactor \
|
23 |
+
--learning_rate="5e-5" \
|
24 |
+
--per_device_train_batch_size="2" \
|
25 |
+
--per_device_eval_batch_size="2" \
|
26 |
+
--save_total_limit="5"\
|
27 |
+
--dtype="bfloat16" \
|
28 |
+
#--resume_from_checkpoint="./"\
|
29 |
+
#--gradient_accumulation_steps="4" \
|
30 |
+
|
run_mlm_flax.py
ADDED
@@ -0,0 +1,787 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2021 The HuggingFace Team All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
|
18 |
+
text file or a dataset.
|
19 |
+
|
20 |
+
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
21 |
+
https://huggingface.co/models?filter=masked-lm
|
22 |
+
"""
|
23 |
+
import shutil
|
24 |
+
import logging
|
25 |
+
import os
|
26 |
+
import sys
|
27 |
+
import time
|
28 |
+
from dataclasses import dataclass, field
|
29 |
+
from ast import Str
|
30 |
+
|
31 |
+
# You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
|
32 |
+
from pathlib import Path
|
33 |
+
from typing import Dict, List, Optional, Tuple
|
34 |
+
|
35 |
+
import numpy as np
|
36 |
+
from datasets import load_dataset
|
37 |
+
from tqdm import tqdm
|
38 |
+
|
39 |
+
import flax
|
40 |
+
import jax
|
41 |
+
import jax.numpy as jnp
|
42 |
+
import optax
|
43 |
+
from flax import jax_utils, traverse_util
|
44 |
+
from flax.training import train_state
|
45 |
+
from flax.training.common_utils import get_metrics, onehot, shard
|
46 |
+
from transformers import (
|
47 |
+
CONFIG_MAPPING,
|
48 |
+
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
49 |
+
AutoConfig,
|
50 |
+
AutoTokenizer,
|
51 |
+
FlaxAutoModelForMaskedLM,
|
52 |
+
HfArgumentParser,
|
53 |
+
PreTrainedTokenizerBase,
|
54 |
+
TensorType,
|
55 |
+
TrainingArguments,
|
56 |
+
is_tensorboard_available,
|
57 |
+
set_seed,
|
58 |
+
)
|
59 |
+
from transformers.testing_utils import CaptureLogger
|
60 |
+
from flax.serialization import to_bytes, from_bytes
|
61 |
+
from importlib.util import find_spec
|
62 |
+
from flax.training import checkpoints
|
63 |
+
from flax.jax_utils import unreplicate
|
64 |
+
from flax.training.checkpoints import save_checkpoint, restore_checkpoint
|
65 |
+
import json
|
66 |
+
|
67 |
+
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
|
68 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
69 |
+
|
70 |
+
|
71 |
+
@dataclass
|
72 |
+
class ModelArguments:
|
73 |
+
"""
|
74 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
75 |
+
"""
|
76 |
+
|
77 |
+
model_name_or_path: Optional[str] = field(
|
78 |
+
default=None,
|
79 |
+
metadata={
|
80 |
+
"help": "The model checkpoint for weights initialization."
|
81 |
+
"Don't set if you want to train a model from scratch."
|
82 |
+
},
|
83 |
+
)
|
84 |
+
model_type: Optional[str] = field(
|
85 |
+
default=None,
|
86 |
+
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
87 |
+
)
|
88 |
+
config_name: Optional[str] = field(
|
89 |
+
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
90 |
+
)
|
91 |
+
tokenizer_name: Optional[str] = field(
|
92 |
+
default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
93 |
+
)
|
94 |
+
cache_dir: Optional[str] = field(
|
95 |
+
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
|
96 |
+
)
|
97 |
+
use_fast_tokenizer: bool = field(
|
98 |
+
default=True,
|
99 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
100 |
+
)
|
101 |
+
dtype: Optional[str] = field(
|
102 |
+
default="float32",
|
103 |
+
metadata={
|
104 |
+
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
105 |
+
},
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
@dataclass
|
112 |
+
class DataTrainingArguments:
|
113 |
+
"""
|
114 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
115 |
+
"""
|
116 |
+
|
117 |
+
dataset_name: Optional[str] = field(
|
118 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
119 |
+
)
|
120 |
+
dataset_config_name: Optional[str] = field(
|
121 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
122 |
+
)
|
123 |
+
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
124 |
+
validation_file: Optional[str] = field(
|
125 |
+
default=None,
|
126 |
+
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
127 |
+
)
|
128 |
+
train_ref_file: Optional[str] = field(
|
129 |
+
default=None,
|
130 |
+
metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
|
131 |
+
)
|
132 |
+
validation_ref_file: Optional[str] = field(
|
133 |
+
default=None,
|
134 |
+
metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
|
135 |
+
)
|
136 |
+
overwrite_cache: bool = field(
|
137 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
138 |
+
)
|
139 |
+
validation_split_percentage: Optional[int] = field(
|
140 |
+
default=5,
|
141 |
+
metadata={
|
142 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
143 |
+
},
|
144 |
+
)
|
145 |
+
max_seq_length: Optional[int] = field(
|
146 |
+
default=None,
|
147 |
+
metadata={
|
148 |
+
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
149 |
+
"than this will be truncated. Default to the max input length of the model."
|
150 |
+
},
|
151 |
+
)
|
152 |
+
preprocessing_num_workers: Optional[int] = field(
|
153 |
+
default=None,
|
154 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
155 |
+
)
|
156 |
+
mlm_probability: float = field(
|
157 |
+
default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
|
158 |
+
)
|
159 |
+
pad_to_max_length: bool = field(
|
160 |
+
default=False,
|
161 |
+
metadata={
|
162 |
+
"help": "Whether to pad all samples to `max_seq_length`. "
|
163 |
+
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
164 |
+
},
|
165 |
+
)
|
166 |
+
line_by_line: bool = field(
|
167 |
+
default=False,
|
168 |
+
metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
|
169 |
+
)
|
170 |
+
|
171 |
+
|
172 |
+
@flax.struct.dataclass
|
173 |
+
class FlaxDataCollatorForLanguageModeling:
|
174 |
+
"""
|
175 |
+
Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
|
176 |
+
are not all of the same length.
|
177 |
+
|
178 |
+
Args:
|
179 |
+
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
|
180 |
+
The tokenizer used for encoding the data.
|
181 |
+
mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
|
182 |
+
The probability with which to (randomly) mask tokens in the input.
|
183 |
+
|
184 |
+
.. note::
|
185 |
+
|
186 |
+
For best performance, this data collator should be used with a dataset having items that are dictionaries or
|
187 |
+
BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
|
188 |
+
:class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
|
189 |
+
argument :obj:`return_special_tokens_mask=True`.
|
190 |
+
"""
|
191 |
+
|
192 |
+
tokenizer: PreTrainedTokenizerBase
|
193 |
+
mlm_probability: float = 0.15
|
194 |
+
|
195 |
+
def __post_init__(self):
|
196 |
+
if self.tokenizer.mask_token is None:
|
197 |
+
raise ValueError(
|
198 |
+
"This tokenizer does not have a mask token which is necessary for masked language modeling. "
|
199 |
+
"You should pass `mlm=False` to train on causal language modeling instead."
|
200 |
+
)
|
201 |
+
|
202 |
+
def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
|
203 |
+
# Handle dict or lists with proper padding and conversion to tensor.
|
204 |
+
batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
|
205 |
+
|
206 |
+
# If special token mask has been preprocessed, pop it from the dict.
|
207 |
+
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
208 |
+
|
209 |
+
batch["input_ids"], batch["labels"] = self.mask_tokens(
|
210 |
+
batch["input_ids"], special_tokens_mask=special_tokens_mask
|
211 |
+
)
|
212 |
+
return batch
|
213 |
+
|
214 |
+
def mask_tokens(
|
215 |
+
self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
|
216 |
+
) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
217 |
+
"""
|
218 |
+
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
|
219 |
+
"""
|
220 |
+
labels = inputs.copy()
|
221 |
+
# We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
|
222 |
+
probability_matrix = np.full(labels.shape, self.mlm_probability)
|
223 |
+
special_tokens_mask = special_tokens_mask.astype("bool")
|
224 |
+
|
225 |
+
probability_matrix[special_tokens_mask] = 0.0
|
226 |
+
masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
|
227 |
+
labels[~masked_indices] = -100 # We only compute loss on masked tokens
|
228 |
+
|
229 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
230 |
+
indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
|
231 |
+
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
|
232 |
+
|
233 |
+
# 10% of the time, we replace masked input tokens with random word
|
234 |
+
indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
|
235 |
+
indices_random &= masked_indices & ~indices_replaced
|
236 |
+
|
237 |
+
random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
|
238 |
+
inputs[indices_random] = random_words[indices_random]
|
239 |
+
|
240 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
241 |
+
return inputs, labels
|
242 |
+
|
243 |
+
|
244 |
+
def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
|
245 |
+
num_samples = len(samples_idx)
|
246 |
+
samples_to_remove = num_samples % batch_size
|
247 |
+
|
248 |
+
if samples_to_remove != 0:
|
249 |
+
samples_idx = samples_idx[:-samples_to_remove]
|
250 |
+
sections_split = num_samples // batch_size
|
251 |
+
batch_idx = np.split(samples_idx, sections_split)
|
252 |
+
return batch_idx
|
253 |
+
|
254 |
+
|
255 |
+
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
256 |
+
summary_writer.scalar("train_time", train_time, step)
|
257 |
+
|
258 |
+
train_metrics = get_metrics(train_metrics)
|
259 |
+
for key, vals in train_metrics.items():
|
260 |
+
tag = f"train_{key}"
|
261 |
+
for i, val in enumerate(vals):
|
262 |
+
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
263 |
+
|
264 |
+
|
265 |
+
def write_eval_metric(summary_writer, eval_metrics, step):
|
266 |
+
for metric_name, value in eval_metrics.items():
|
267 |
+
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
268 |
+
|
269 |
+
def mb_item(x):
|
270 |
+
return x.item() if hasattr(x, "item") else x
|
271 |
+
|
272 |
+
#checkpoint functions
|
273 |
+
|
274 |
+
|
275 |
+
|
276 |
+
|
277 |
+
|
278 |
+
def rotate_checkpoints(ckpt_dir: str, save_total_limit: int):
|
279 |
+
"Removes older checkpoints so that `save_total_limit` checkpoints are kept"
|
280 |
+
# TODO: what to remove is decided using step number only, we might want to improve that
|
281 |
+
ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt-*")]
|
282 |
+
# sort checkpoints by step
|
283 |
+
ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split('-')[-1]))
|
284 |
+
ckpts_to_delete = ckpts_sorted[:-save_total_limit]
|
285 |
+
for ckpt in ckpts_to_delete:
|
286 |
+
logger.info(f"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})")
|
287 |
+
shutil.rmtree(ckpt)
|
288 |
+
|
289 |
+
|
290 |
+
|
291 |
+
|
292 |
+
|
293 |
+
|
294 |
+
|
295 |
+
if __name__ == "__main__":
|
296 |
+
# See all possible arguments in src/transformers/training_args.py
|
297 |
+
# or by passing the --help flag to this script.
|
298 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
299 |
+
|
300 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
301 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
302 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
303 |
+
# let's parse it to get our arguments.
|
304 |
+
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
305 |
+
else:
|
306 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
307 |
+
|
308 |
+
if (
|
309 |
+
os.path.exists(training_args.output_dir)
|
310 |
+
and os.listdir(training_args.output_dir)
|
311 |
+
and training_args.do_train
|
312 |
+
and not training_args.overwrite_output_dir
|
313 |
+
):
|
314 |
+
raise ValueError(
|
315 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty."
|
316 |
+
"Use --overwrite_output_dir to overcome."
|
317 |
+
)
|
318 |
+
|
319 |
+
# Setup logging
|
320 |
+
logging.basicConfig(
|
321 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
322 |
+
level="NOTSET",
|
323 |
+
datefmt="[%X]",
|
324 |
+
)
|
325 |
+
|
326 |
+
# Log on each process the small summary:
|
327 |
+
logger = logging.getLogger(__name__)
|
328 |
+
|
329 |
+
# Set the verbosity to info of the Transformers logger (on main process only):
|
330 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
331 |
+
|
332 |
+
# Set seed before initializing model.
|
333 |
+
set_seed(training_args.seed)
|
334 |
+
|
335 |
+
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
336 |
+
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
337 |
+
# (the dataset will be downloaded automatically from the datasets Hub).
|
338 |
+
#
|
339 |
+
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
340 |
+
# 'text' is found. You can easily tweak this behavior (see below).
|
341 |
+
#
|
342 |
+
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
343 |
+
# download the dataset.
|
344 |
+
if data_args.dataset_name is not None:
|
345 |
+
# Downloading and loading a dataset from the hub.
|
346 |
+
datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
|
347 |
+
|
348 |
+
if "validation" not in datasets.keys():
|
349 |
+
datasets["validation"] = load_dataset(
|
350 |
+
data_args.dataset_name,
|
351 |
+
data_args.dataset_config_name,
|
352 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
353 |
+
cache_dir=model_args.cache_dir,
|
354 |
+
)
|
355 |
+
datasets["train"] = load_dataset(
|
356 |
+
data_args.dataset_name,
|
357 |
+
data_args.dataset_config_name,
|
358 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
359 |
+
cache_dir=model_args.cache_dir,
|
360 |
+
)
|
361 |
+
else:
|
362 |
+
#data_files = {}
|
363 |
+
#if data_args.train_file is not None:
|
364 |
+
# data_files["train"] = data_args.train_file
|
365 |
+
#if data_args.validation_file is not None:
|
366 |
+
# data_files["validation"] = data_args.validation_file
|
367 |
+
#extension = data_args.train_file.split(".")[-1]
|
368 |
+
#if extension == "txt":
|
369 |
+
# extension = "text"
|
370 |
+
#datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
|
371 |
+
|
372 |
+
#data_dir = "/home/yeb"
|
373 |
+
# data_dir = "/home/yeb/Developer/data"
|
374 |
+
data_files = []
|
375 |
+
def train_val_files():
|
376 |
+
import glob
|
377 |
+
import random
|
378 |
+
SEED = 42
|
379 |
+
def add_jsonlines_dir(path):
|
380 |
+
global data_files
|
381 |
+
data_files += glob.glob(f"{path}/*.gz")
|
382 |
+
|
383 |
+
add_jsonlines_dir("/home/dat/subset_c4_cleannl")
|
384 |
+
add_jsonlines_dir("/data/oscar_nl_cleaned")
|
385 |
+
add_jsonlines_dir("/data/nrc_cleaned_idtextfmt")
|
386 |
+
add_jsonlines_dir("/data/nu_cleaned_idtextfmt")
|
387 |
+
random.Random(SEED).shuffle(data_files)
|
388 |
+
total = len(data_files)
|
389 |
+
val_size = int(0.05 * total)
|
390 |
+
train_size = total - val_size
|
391 |
+
print(f"95%: {train_size}")
|
392 |
+
train = data_files[:train_size]
|
393 |
+
val = data_files[train_size:]
|
394 |
+
print(f"Got {len(train)} training files and {len(val)} validation files")
|
395 |
+
assert list(set(train) & set(val)) == [], "Train overlaps with test"
|
396 |
+
return train, val
|
397 |
+
train, val = train_val_files()
|
398 |
+
datasets = load_dataset('json', data_files={'train': train, 'validation': val})
|
399 |
+
datasets["train"] = datasets["train"].select(range(int(0.8*len(datasets["train"]))))
|
400 |
+
datasets["validation"] = datasets["validation"].select(range(int(0.8*len(datasets["validation"]))))
|
401 |
+
|
402 |
+
|
403 |
+
|
404 |
+
|
405 |
+
if model_args.config_name:
|
406 |
+
config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
|
407 |
+
elif model_args.model_name_or_path:
|
408 |
+
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
|
409 |
+
else:
|
410 |
+
config = CONFIG_MAPPING[model_args.model_type]()
|
411 |
+
logger.warning("You are instantiating a new config instance from scratch.")
|
412 |
+
|
413 |
+
if model_args.tokenizer_name:
|
414 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
415 |
+
model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
416 |
+
)
|
417 |
+
elif model_args.model_name_or_path:
|
418 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
419 |
+
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
|
420 |
+
)
|
421 |
+
else:
|
422 |
+
raise ValueError(
|
423 |
+
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
424 |
+
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
425 |
+
)
|
426 |
+
|
427 |
+
# Preprocessing the datasets.
|
428 |
+
# First we tokenize all the texts.
|
429 |
+
if training_args.do_train:
|
430 |
+
column_names = datasets["train"].column_names
|
431 |
+
else:
|
432 |
+
column_names = datasets["validation"].column_names
|
433 |
+
text_column_name = "text" if "text" in column_names else column_names[0]
|
434 |
+
|
435 |
+
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
436 |
+
|
437 |
+
|
438 |
+
if data_args.line_by_line:
|
439 |
+
# When using line_by_line, we just tokenize each nonempty line.
|
440 |
+
padding = "max_length" if data_args.pad_to_max_length else False
|
441 |
+
|
442 |
+
def tokenize_function(examples):
|
443 |
+
# Remove empty lines
|
444 |
+
examples = [line for line in examples if len(line) > 0 and not line.isspace()]
|
445 |
+
return tokenizer(
|
446 |
+
examples,
|
447 |
+
return_special_tokens_mask=True,
|
448 |
+
padding=padding,
|
449 |
+
truncation=True,
|
450 |
+
max_length=max_seq_length,
|
451 |
+
)
|
452 |
+
|
453 |
+
tokenized_datasets = datasets.map(
|
454 |
+
tokenize_function,
|
455 |
+
input_columns=[text_column_name],
|
456 |
+
batched=True,
|
457 |
+
num_proc=data_args.preprocessing_num_workers,
|
458 |
+
remove_columns=column_names,
|
459 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
460 |
+
)
|
461 |
+
|
462 |
+
else:
|
463 |
+
# Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
|
464 |
+
# We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
|
465 |
+
# efficient when it receives the `special_tokens_mask`.
|
466 |
+
def tokenize_function(examples):
|
467 |
+
return tokenizer(examples[text_column_name], return_special_tokens_mask=True)
|
468 |
+
|
469 |
+
tokenized_datasets = datasets.map(
|
470 |
+
tokenize_function,
|
471 |
+
batched=True,
|
472 |
+
num_proc=data_args.preprocessing_num_workers,
|
473 |
+
remove_columns=column_names,
|
474 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
475 |
+
)
|
476 |
+
|
477 |
+
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
478 |
+
# max_seq_length.
|
479 |
+
def group_texts(examples):
|
480 |
+
# Concatenate all texts.
|
481 |
+
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
|
482 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
483 |
+
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
484 |
+
# customize this part to your needs.
|
485 |
+
if total_length >= max_seq_length:
|
486 |
+
total_length = (total_length // max_seq_length) * max_seq_length
|
487 |
+
# Split by chunks of max_len.
|
488 |
+
result = {
|
489 |
+
k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
|
490 |
+
for k, t in concatenated_examples.items()
|
491 |
+
}
|
492 |
+
return result
|
493 |
+
|
494 |
+
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
|
495 |
+
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
|
496 |
+
# might be slower to preprocess.
|
497 |
+
#
|
498 |
+
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
499 |
+
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
500 |
+
lm_datasets = tokenized_datasets.map(
|
501 |
+
group_texts,
|
502 |
+
batched=True,
|
503 |
+
batch_size=100,
|
504 |
+
num_proc=data_args.preprocessing_num_workers,
|
505 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
506 |
+
)
|
507 |
+
train_dataset = lm_datasets["train"]
|
508 |
+
eval_dataset = lm_datasets["validation"]
|
509 |
+
|
510 |
+
|
511 |
+
|
512 |
+
|
513 |
+
# Enable tensorboard only on the master node
|
514 |
+
has_tensorboard = is_tensorboard_available()
|
515 |
+
if has_tensorboard and jax.process_index() == 0:
|
516 |
+
try:
|
517 |
+
from flax.metrics.tensorboard import SummaryWriter
|
518 |
+
|
519 |
+
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
520 |
+
except ImportError as ie:
|
521 |
+
has_tensorboard = False
|
522 |
+
logger.warning(
|
523 |
+
f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
|
524 |
+
)
|
525 |
+
else:
|
526 |
+
logger.warning(
|
527 |
+
"Unable to display metrics through TensorBoard because the package is not installed: "
|
528 |
+
"Please run pip install tensorboard to enable."
|
529 |
+
)
|
530 |
+
# enable wandb tracking
|
531 |
+
has_wandb = find_spec("wandb") is not None
|
532 |
+
if jax.process_index() == 0 and has_wandb and ("wandb" in training_args.report_to):
|
533 |
+
try:
|
534 |
+
import wandb
|
535 |
+
wandb.init(
|
536 |
+
entity="wandb",
|
537 |
+
project="hf-flax-pino-roberta",
|
538 |
+
sync_tensorboard=True
|
539 |
+
)
|
540 |
+
wandb.config.update(training_args)
|
541 |
+
wandb.config.update(model_args)
|
542 |
+
wandb.config.update(data_args)
|
543 |
+
except ImportError as e:
|
544 |
+
print(e)
|
545 |
+
has_wandb = False
|
546 |
+
|
547 |
+
# Data collator
|
548 |
+
# This one will take care of randomly masking the tokens.
|
549 |
+
data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
|
550 |
+
|
551 |
+
# Initialize our training
|
552 |
+
rng = jax.random.PRNGKey(training_args.seed)
|
553 |
+
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
554 |
+
|
555 |
+
if model_args.model_name_or_path:
|
556 |
+
model = FlaxAutoModelForMaskedLM.from_pretrained(
|
557 |
+
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
558 |
+
)
|
559 |
+
else:
|
560 |
+
model = FlaxAutoModelForMaskedLM.from_config(
|
561 |
+
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
562 |
+
)
|
563 |
+
|
564 |
+
# Store some constant
|
565 |
+
num_epochs = int(training_args.num_train_epochs)
|
566 |
+
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * training_args.gradient_accumulation_steps
|
567 |
+
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
|
568 |
+
|
569 |
+
num_train_steps = len(train_dataset) // train_batch_size * num_epochs
|
570 |
+
|
571 |
+
# Create learning rate schedule
|
572 |
+
warmup_fn = optax.linear_schedule(
|
573 |
+
init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
|
574 |
+
)
|
575 |
+
decay_fn = optax.linear_schedule(
|
576 |
+
init_value=training_args.learning_rate,
|
577 |
+
end_value=0,
|
578 |
+
transition_steps=num_train_steps - training_args.warmup_steps,
|
579 |
+
)
|
580 |
+
linear_decay_lr_schedule_fn = optax.join_schedules(
|
581 |
+
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
|
582 |
+
)
|
583 |
+
|
584 |
+
# We use Optax's "masking" functionality to not apply weight decay
|
585 |
+
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
586 |
+
# mask boolean with the same structure as the parameters.
|
587 |
+
# The mask is True for parameters that should be decayed.
|
588 |
+
# Note that this mask is specifically adapted for FlaxBERT-like models.
|
589 |
+
# For other models, one should correct the layer norm parameter naming
|
590 |
+
# accordingly.
|
591 |
+
def decay_mask_fn(params):
|
592 |
+
flat_params = traverse_util.flatten_dict(params)
|
593 |
+
flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
|
594 |
+
return traverse_util.unflatten_dict(flat_mask)
|
595 |
+
|
596 |
+
# create adam optimizer
|
597 |
+
if training_args.adafactor:
|
598 |
+
# We use the default parameters here to initialize adafactor,
|
599 |
+
# For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
|
600 |
+
optimizer = optax.adafactor(
|
601 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
602 |
+
)
|
603 |
+
else:
|
604 |
+
optimizer = optax.adamw(
|
605 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
606 |
+
b1=training_args.adam_beta1,
|
607 |
+
b2=training_args.adam_beta2,
|
608 |
+
eps=training_args.adam_epsilon,
|
609 |
+
weight_decay=training_args.weight_decay,
|
610 |
+
mask=decay_mask_fn,
|
611 |
+
)
|
612 |
+
|
613 |
+
if training_args.gradient_accumulation_steps > 1:
|
614 |
+
optimizer = optax.MultiSteps(optimizer, training_args.gradient_accumulation_steps)
|
615 |
+
grad_accum_steps = training_args.gradient_accumulation_steps
|
616 |
+
|
617 |
+
# Setup train state
|
618 |
+
|
619 |
+
|
620 |
+
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)
|
621 |
+
|
622 |
+
if training_args.resume_from_checkpoint:
|
623 |
+
state = restore_checkpoint(training_args.resume_from_checkpoint, state)
|
624 |
+
resume_step = mb_item(state.step.item())
|
625 |
+
else:
|
626 |
+
resume_step = 0
|
627 |
+
|
628 |
+
|
629 |
+
# Define gradient update step fn
|
630 |
+
def train_step(state, batch, dropout_rng):
|
631 |
+
dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
|
632 |
+
|
633 |
+
def loss_fn(params):
|
634 |
+
labels = batch.pop("labels")
|
635 |
+
|
636 |
+
logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
|
637 |
+
|
638 |
+
# compute loss, ignore padded input tokens
|
639 |
+
label_mask = jnp.where(labels > 0, 1.0, 0.0)
|
640 |
+
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
641 |
+
|
642 |
+
# take average
|
643 |
+
loss = loss.sum() / label_mask.sum()
|
644 |
+
|
645 |
+
return loss
|
646 |
+
|
647 |
+
grad_fn = jax.value_and_grad(loss_fn)
|
648 |
+
loss, grad = grad_fn(state.params)
|
649 |
+
grad = jax.lax.pmean(grad, "batch")
|
650 |
+
new_state = state.apply_gradients(grads=grad)
|
651 |
+
|
652 |
+
metrics = jax.lax.pmean(
|
653 |
+
{"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step // grad_accum_steps)}, axis_name="batch"
|
654 |
+
)
|
655 |
+
|
656 |
+
return new_state, metrics, new_dropout_rng
|
657 |
+
|
658 |
+
# Create parallel version of the train step
|
659 |
+
p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
|
660 |
+
|
661 |
+
# Define eval fn
|
662 |
+
def eval_step(params, batch):
|
663 |
+
labels = batch.pop("labels")
|
664 |
+
|
665 |
+
logits = model(**batch, params=params, train=False)[0]
|
666 |
+
|
667 |
+
# compute loss, ignore padded input tokens
|
668 |
+
label_mask = jnp.where(labels > 0, 1.0, 0.0)
|
669 |
+
loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
|
670 |
+
|
671 |
+
# compute accuracy
|
672 |
+
accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
|
673 |
+
|
674 |
+
# summarize metrics
|
675 |
+
metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
|
676 |
+
metrics = jax.lax.psum(metrics, axis_name="batch")
|
677 |
+
|
678 |
+
return metrics
|
679 |
+
|
680 |
+
p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
|
681 |
+
|
682 |
+
# Replicate the train state on each device
|
683 |
+
state = jax_utils.replicate(state)
|
684 |
+
|
685 |
+
train_time = 0
|
686 |
+
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
687 |
+
for epoch in epochs:
|
688 |
+
# ======================== Training ================================
|
689 |
+
train_start = time.time()
|
690 |
+
train_metrics = []
|
691 |
+
|
692 |
+
# Create sampling rng
|
693 |
+
rng, input_rng = jax.random.split(rng)
|
694 |
+
steps_per_epoch = len(train_dataset) // train_batch_size
|
695 |
+
|
696 |
+
# Generate an epoch by shuffling sampling indices from the train dataset
|
697 |
+
num_train_samples = len(train_dataset)
|
698 |
+
train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
|
699 |
+
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size // grad_accum_steps)
|
700 |
+
|
701 |
+
# Gather the indexes for creating the batch and do a training step
|
702 |
+
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1,initial=resume_step)):
|
703 |
+
samples = [train_dataset[int(idx)] for idx in batch_idx]
|
704 |
+
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
705 |
+
|
706 |
+
|
707 |
+
# Model forward
|
708 |
+
model_inputs = shard(model_inputs.data)
|
709 |
+
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
710 |
+
train_metrics.append(train_metric)
|
711 |
+
|
712 |
+
cur_step = epoch * (num_train_samples // train_batch_size) + step
|
713 |
+
if cur_step < resume_step:
|
714 |
+
continue
|
715 |
+
|
716 |
+
if (cur_step % training_args.logging_steps * grad_accum_steps) == 0 and cur_step > 0:
|
717 |
+
# Save metrics
|
718 |
+
train_metric = jax_utils.unreplicate(train_metric)
|
719 |
+
train_time += time.time() - train_start
|
720 |
+
if has_tensorboard and jax.process_index() == 0:
|
721 |
+
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
722 |
+
if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
|
723 |
+
# TODO: add accumulation of metrics
|
724 |
+
_metrics = {k if k=="learning_rate" else f"train_{k}":mb_item(v.mean()) for k, v in train_metric.items()}
|
725 |
+
wandb.log({"training_step":cur_step, **_metrics}, commit=True)
|
726 |
+
|
727 |
+
epochs.write(
|
728 |
+
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
729 |
+
)
|
730 |
+
|
731 |
+
train_metrics = []
|
732 |
+
|
733 |
+
if cur_step % (training_args.eval_steps * grad_accum_steps) == 0 and cur_step > 0:
|
734 |
+
# ======================== Evaluating ==============================
|
735 |
+
num_eval_samples = len(eval_dataset)
|
736 |
+
eval_samples_idx = jnp.arange(num_eval_samples)
|
737 |
+
eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
|
738 |
+
|
739 |
+
eval_metrics = []
|
740 |
+
for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
|
741 |
+
samples = [eval_dataset[int(idx)] for idx in batch_idx]
|
742 |
+
model_inputs = data_collator(samples, pad_to_multiple_of=16)
|
743 |
+
|
744 |
+
# Model forward
|
745 |
+
model_inputs = shard(model_inputs.data)
|
746 |
+
metrics = p_eval_step(state.params, model_inputs)
|
747 |
+
eval_metrics.append(metrics)
|
748 |
+
|
749 |
+
# normalize eval metrics
|
750 |
+
eval_metrics = get_metrics(eval_metrics)
|
751 |
+
eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
|
752 |
+
eval_normalizer = eval_metrics.pop("normalizer")
|
753 |
+
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
754 |
+
|
755 |
+
# Update progress bar
|
756 |
+
epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
757 |
+
|
758 |
+
# Save metrics
|
759 |
+
if has_tensorboard and jax.process_index() == 0:
|
760 |
+
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
761 |
+
|
762 |
+
if has_wandb and jax.process_index() == 0 and ("wandb" in training_args.report_to):
|
763 |
+
_metrics = {f"eval_{k}":mb_item(v) for k, v in eval_metrics.items()}
|
764 |
+
wandb.log({"eval_step":cur_step, **_metrics})
|
765 |
+
|
766 |
+
if (cur_step % training_args.save_steps == 0 * grad_accum_steps) and cur_step > 0:
|
767 |
+
# save checkpoint after each epoch and push checkpoint to the hub
|
768 |
+
if jax.process_index() == 0:
|
769 |
+
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
770 |
+
model.save_pretrained(
|
771 |
+
training_args.output_dir,
|
772 |
+
params=params,
|
773 |
+
push_to_hub=training_args.push_to_hub,
|
774 |
+
commit_message=f"Saving weights and logs of step {cur_step}",
|
775 |
+
)
|
776 |
+
save_checkpoint(training_args.output_dir, jax_utils.unreplicate(state), cur_step, keep=training_args.save_total_limit, overwrite=True)
|
777 |
+
if training_args.save_total_limit is not None:
|
778 |
+
rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
|
779 |
+
|
780 |
+
if jax.process_index() == 0:
|
781 |
+
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
782 |
+
model.save_pretrained(
|
783 |
+
training_args.output_dir,
|
784 |
+
params=params,
|
785 |
+
push_to_hub=training_args.push_to_hub,
|
786 |
+
commit_message=f"Saving weights and logs of step {cur_step}",
|
787 |
+
)
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|