Yeb Havinga committed
Commit 49e8767
1 Parent(s): 2c7b7d9

Replace scripts and model with improved version

Load_preprocessed_dataset.ipynb DELETED
@@ -1,165 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 4,
6
- "id": "cf148030-7287-4c9e-ae32-8d1e1c47be30",
7
- "metadata": {},
8
- "outputs": [],
9
- "source": [
10
- "from datasets import Dataset, DatasetDict"
11
- ]
12
- },
13
- {
14
- "cell_type": "code",
15
- "execution_count": 7,
16
- "id": "5161b4ba-e8cf-43e1-b67e-503c29aa4271",
17
- "metadata": {},
18
- "outputs": [
19
- {
20
- "ename": "FileNotFoundError",
21
- "evalue": "[Errno 2] No such file or directory: '/home/yeb/grouped_dataset/dataset_dict.json'",
22
- "output_type": "error",
23
- "traceback": [
24
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
25
- "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
26
- "\u001b[0;32m/tmp/ipykernel_574434/3668239933.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdatasets\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDatasetDict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_from_disk\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"/home/yeb/grouped_dataset\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
27
- "\u001b[0;32m~/datasets/src/datasets/dataset_dict.py\u001b[0m in \u001b[0;36mload_from_disk\u001b[0;34m(dataset_dict_path, fs, keep_in_memory)\u001b[0m\n\u001b[1;32m 727\u001b[0m \u001b[0;34mf\"No such file or directory: '{dataset_dict_json_path}'. Expected to load a DatasetDict object, but got a Dataset. Please use datasets.load_from_disk instead.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 728\u001b[0m )\n\u001b[0;32m--> 729\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mjson\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset_dict_json_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"r\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mencoding\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"utf-8\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"splits\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 730\u001b[0m dataset_dict_split_path = (\n\u001b[1;32m 731\u001b[0m \u001b[0mdataset_dict_path\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"://\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"://\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mPath\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdest_dataset_dict_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mas_posix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
28
- "\u001b[0;32m~/venv/lib/python3.8/site-packages/fsspec/spec.py\u001b[0m in \u001b[0;36mopen\u001b[0;34m(self, path, mode, block_size, cache_options, **kwargs)\u001b[0m\n\u001b[1;32m 956\u001b[0m }\n\u001b[1;32m 957\u001b[0m return io.TextIOWrapper(\n\u001b[0;32m--> 958\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mblock_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mtext_kwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 959\u001b[0m )\n\u001b[1;32m 960\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
29
- "\u001b[0;32m~/venv/lib/python3.8/site-packages/fsspec/spec.py\u001b[0m in \u001b[0;36mopen\u001b[0;34m(self, path, mode, block_size, cache_options, **kwargs)\u001b[0m\n\u001b[1;32m 960\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 961\u001b[0m \u001b[0mac\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"autocommit\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_intrans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 962\u001b[0;31m f = self._open(\n\u001b[0m\u001b[1;32m 963\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 964\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
30
- "\u001b[0;32m~/venv/lib/python3.8/site-packages/fsspec/implementations/local.py\u001b[0m in \u001b[0;36m_open\u001b[0;34m(self, path, mode, block_size, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_mkdir\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;34m\"w\"\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmakedirs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexist_ok\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 144\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mLocalFileOpener\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 145\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtouch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
31
- "\u001b[0;32m~/venv/lib/python3.8/site-packages/fsspec/implementations/local.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, path, mode, autocommit, fs, compression, **kwargs)\u001b[0m\n\u001b[1;32m 233\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompression\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_compression\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcompression\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 234\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mblocksize\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDEFAULT_BUFFER_SIZE\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 235\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_open\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 236\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_open\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
32
- "\u001b[0;32m~/venv/lib/python3.8/site-packages/fsspec/implementations/local.py\u001b[0m in \u001b[0;36m_open\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 238\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclosed\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautocommit\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m\"w\"\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 240\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 241\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompression\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 242\u001b[0m \u001b[0mcompress\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompression\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
33
- "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/home/yeb/grouped_dataset/dataset_dict.json'"
34
- ]
35
- }
36
- ],
37
- "source": [
38
- "datasets = DatasetDict.load_from_disk(\"/home/yeb/grouped_dataset\")"
39
- ]
40
- },
41
- {
42
- "cell_type": "code",
43
- "execution_count": 12,
44
- "id": "15f9d047-ac35-43d7-ab55-9f9afe96dd07",
45
- "metadata": {},
46
- "outputs": [
47
- {
48
- "data": {
49
- "text/plain": [
50
- "DatasetDict({\n",
51
- " train: Dataset({\n",
52
- " features: ['input_ids'],\n",
53
- " num_rows: 86438919\n",
54
- " })\n",
55
- " validation: Dataset({\n",
56
- " features: ['input_ids'],\n",
57
- " num_rows: 4735324\n",
58
- " })\n",
59
- "})"
60
- ]
61
- },
62
- "execution_count": 12,
63
- "metadata": {},
64
- "output_type": "execute_result"
65
- }
66
- ],
67
- "source": [
68
- "datasets"
69
- ]
70
- },
71
- {
72
- "cell_type": "code",
73
- "execution_count": 14,
74
- "id": "d1d1218e-142e-441a-b20d-d300b13b172a",
75
- "metadata": {},
76
- "outputs": [],
77
- "source": [
78
- "train = datasets['train']"
79
- ]
80
- },
81
- {
82
- "cell_type": "code",
83
- "execution_count": null,
84
- "id": "9eaddfb1-242f-4a25-8789-efe97b2a5712",
85
- "metadata": {},
86
- "outputs": [],
87
- "source": []
88
- },
89
- {
90
- "cell_type": "code",
91
- "execution_count": 15,
92
- "id": "8aabb26f-19ca-467a-b383-3a693be70cac",
93
- "metadata": {},
94
- "outputs": [
95
- {
96
- "name": "stdout",
97
- "output_type": "stream",
98
- "text": [
99
- "86438919\n"
100
- ]
101
- }
102
- ],
103
- "source": [
104
- "print(len(train))"
105
- ]
106
- },
107
- {
108
- "cell_type": "code",
109
- "execution_count": null,
110
- "id": "f3176986-5b34-4ed6-a643-e342db9a2ce8",
111
- "metadata": {},
112
- "outputs": [],
113
- "source": []
114
- },
115
- {
116
- "cell_type": "code",
117
- "execution_count": 16,
118
- "id": "1205bbef-ba9d-4ddc-af2e-602d56b7dd64",
119
- "metadata": {},
120
- "outputs": [
121
- {
122
- "name": "stdout",
123
- "output_type": "stream",
124
- "text": [
125
- "{'input_ids': [256, 3, 20, 18452, 6690, 7757, 1286, 43, 10, 4942, 1286, 80, 12, 4782, 5442, 39, 5385, 33, 4, 5, 3, 2924, 117, 5669, 228, 21, 193, 9030, 511, 24, 11, 5, 665, 165, 4218, 7, 26, 264, 1528, 35, 105, 3, 19653, 12, 9661, 17156, 13955, 4, 132, 5, 611, 959, 961, 146, 6522, 7757, 1286, 89, 7500, 9716, 11, 5, 4868, 107, 13604, 12, 12836, 13368, 11, 611, 959, 4, 3, 69, 99, 12, 13132, 6690, 590, 5, 1803, 1867, 69, 7, 924, 10, 1762, 4, 3, 69, 538, 489, 14, 1149, 16, 3, 11384, 199, 116, 399, 4782, 291, 3, 6, 237, 13, 2629, 3, 8987, 291, 4, 69, 5, 3, 27, 72, 20, 325, 3, 2924, 133, 21, 105, 9030, 10, 1149, 242, 16, 144, 13572, 11, 9, 13401, 20, 7951, 8, 165, 4218, 4, 5, 1910]}\n"
126
- ]
127
- }
128
- ],
129
- "source": [
130
- "it = iter(train)\n",
131
- "\n",
132
- "print(next(it))"
133
- ]
134
- },
135
- {
136
- "cell_type": "code",
137
- "execution_count": null,
138
- "id": "f5d4e8de-419c-4c70-896e-fbd640bb7321",
139
- "metadata": {},
140
- "outputs": [],
141
- "source": []
142
- }
143
- ],
144
- "metadata": {
145
- "kernelspec": {
146
- "display_name": "Python 3 (ipykernel)",
147
- "language": "python",
148
- "name": "python3"
149
- },
150
- "language_info": {
151
- "codemirror_mode": {
152
- "name": "ipython",
153
- "version": 3
154
- },
155
- "file_extension": ".py",
156
- "mimetype": "text/x-python",
157
- "name": "python",
158
- "nbconvert_exporter": "python",
159
- "pygments_lexer": "ipython3",
160
- "version": "3.8.10"
161
- }
162
- },
163
- "nbformat": 4,
164
- "nbformat_minor": 5
165
- }
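For reference, the `load_from_disk` call in the first cell above expects a directory previously written with `DatasetDict.save_to_disk`, which is what creates the `dataset_dict.json` reported missing in the traceback. A minimal sketch of that round trip, with a placeholder path and toy rows (not the project's data):

```python
# Minimal sketch of the save/load round trip the notebook above relies on.
# The directory name and the toy rows are placeholders, not the project's data.
from datasets import Dataset, DatasetDict

splits = DatasetDict({
    "train": Dataset.from_dict({"input_ids": [[256, 3, 20], [18452, 6690, 7757]]}),
    "validation": Dataset.from_dict({"input_ids": [[1286, 43, 10]]}),
})
splits.save_to_disk("grouped_dataset")            # writes grouped_dataset/dataset_dict.json
reloaded = DatasetDict.load_from_disk("grouped_dataset")
print(reloaded)                                   # DatasetDict with train/validation splits
```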
Load_token_group_dataset.ipynb DELETED
@@ -1,567 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 71,
6
- "id": "d7f2bdb5-95c2-4a57-80e8-8f1a30a138b0",
7
- "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stdout",
11
- "output_type": "stream",
12
- "text": [
13
- "Number of files 20 after adding ./c4_cleaned glob *73*.gz\n",
14
- "Number of files 39 after adding ./c4_cleaned glob *47*.gz\n",
15
- "Number of files 60 after adding ./c4_cleaned glob *12*.gz\n",
16
- "Number of files 79 after adding ./c4_cleaned glob *29*.gz\n",
17
- "Number of files 97 after adding ./c4_cleaned glob *74*.gz\n",
18
- "Number of files 116 after adding ./c4_cleaned glob *26*.gz\n",
19
- "Number of files 135 after adding ./c4_cleaned glob *54*.gz\n",
20
- "Number of files 154 after adding ./c4_cleaned glob *68*.gz\n",
21
- "Number of files 172 after adding ./c4_cleaned glob *57*.gz\n",
22
- "Number of files 189 after adding ./c4_cleaned glob *46*.gz\n",
23
- "Number of files 206 after adding ./c4_cleaned glob *35*.gz\n",
24
- "Number of files 226 after adding ./c4_cleaned glob *13*.gz\n",
25
- "Number of files 242 after adding ./c4_cleaned glob *41*.gz\n",
26
- "Number of files 259 after adding ./c4_cleaned glob *52*.gz\n",
27
- "Number of files 276 after adding ./c4_cleaned glob *63*.gz\n",
28
- "Number of files 292 after adding ./c4_cleaned glob *85*.gz\n",
29
- "Number of files 309 after adding ./c4_cleaned glob *81*.gz\n",
30
- "Number of files 326 after adding ./c4_cleaned glob *96*.gz\n",
31
- "Number of files 526 after adding ./nrc_uniq_cleaned_20210223 glob *.gz\n",
32
- "Number of files 726 after adding ./nu_uniq_cleaned_20210225 glob *.gz\n",
33
- "726\n",
34
- "Got 690 training files and 5.0 % 36 validation files\n"
35
- ]
36
- }
37
- ],
38
- "source": [
39
- "data_files = []\n",
40
- "data_dir=\".\"\n",
41
- "def train_val_files():\n",
42
- " import glob\n",
43
- " import random\n",
44
- " SEED = 12345\n",
45
- "\n",
46
- " def add_jsonlines_dir(path, filespec):\n",
47
- " global data_files\n",
48
- " data_files += glob.glob(f\"{path}/{filespec}\")\n",
49
- " data_files = list(set(data_files))\n",
50
- " print(f\"Number of files {len(data_files)} after adding {path} glob {filespec}\")\n",
51
- "\n",
52
- " # add_jsonlines_dir(f\"{data_dir}/oscar_nl_cleaned\")\n",
53
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*73*.gz\")\n",
54
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*47*.gz\")\n",
55
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*12*.gz\")\n",
56
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*29*.gz\")\n",
57
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*74*.gz\")\n",
58
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*26*.gz\")\n",
59
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*54*.gz\")\n",
60
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*68*.gz\")\n",
61
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*57*.gz\")\n",
62
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*46*.gz\")\n",
63
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*35*.gz\")\n",
64
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*13*.gz\")\n",
65
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*41*.gz\")\n",
66
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*52*.gz\")\n",
67
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*63*.gz\")\n",
68
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*85*.gz\")\n",
69
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*81*.gz\")\n",
70
- " add_jsonlines_dir(f\"{data_dir}/c4_cleaned\", \"*96*.gz\")\n",
71
- " add_jsonlines_dir(f\"{data_dir}/nrc_uniq_cleaned_20210223\", \"*.gz\")\n",
72
- " add_jsonlines_dir(f\"{data_dir}/nu_uniq_cleaned_20210225\", \"*.gz\")\n",
73
- " random.Random(SEED).shuffle(data_files)\n",
74
- "\n",
75
- " total = len(data_files)\n",
76
- " print(total)\n",
77
- " perc = 0.05\n",
78
- " val_size = int(perc * total)\n",
79
- " train_size = total - val_size\n",
80
- " train = data_files[:train_size]\n",
81
- " val = data_files[train_size:]\n",
82
- " print(f\"Got {len(train)} training files and {perc*100} % {len(val)} validation files\")\n",
83
- "\n",
84
- " assert list(set(train) & set(val)) == [], \"Train overlaps with test\"\n",
85
- "\n",
86
- " return train, val\n",
87
- "\n",
88
- "train, val = train_val_files()"
89
- ]
90
- },
91
- {
92
- "cell_type": "code",
93
- "execution_count": 72,
94
- "id": "66a923c6-1c7e-4ac2-9aec-e75c572104dd",
95
- "metadata": {},
96
- "outputs": [
97
- {
98
- "name": "stderr",
99
- "output_type": "stream",
100
- "text": [
101
- "Using custom data configuration default-ce92ec7dc3732df4\n"
102
- ]
103
- },
104
- {
105
- "name": "stdout",
106
- "output_type": "stream",
107
- "text": [
108
- "Downloading and preparing dataset json/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/yeb/.cache/huggingface/datasets/json/default-ce92ec7dc3732df4/0.0.0/793d004298099bd3c4e61eb7878475bcf1dc212bf2e34437d85126758720d7f9...\n"
109
- ]
110
- },
111
- {
112
- "data": {
113
- "application/vnd.jupyter.widget-view+json": {
114
- "model_id": "",
115
- "version_major": 2,
116
- "version_minor": 0
117
- },
118
- "text/plain": [
119
- "0 tables [00:00, ? tables/s]"
120
- ]
121
- },
122
- "metadata": {},
123
- "output_type": "display_data"
124
- },
125
- {
126
- "data": {
127
- "application/vnd.jupyter.widget-view+json": {
128
- "model_id": "",
129
- "version_major": 2,
130
- "version_minor": 0
131
- },
132
- "text/plain": [
133
- "0 tables [00:00, ? tables/s]"
134
- ]
135
- },
136
- "metadata": {},
137
- "output_type": "display_data"
138
- },
139
- {
140
- "name": "stdout",
141
- "output_type": "stream",
142
- "text": [
143
- "Dataset json downloaded and prepared to /home/yeb/.cache/huggingface/datasets/json/default-ce92ec7dc3732df4/0.0.0/793d004298099bd3c4e61eb7878475bcf1dc212bf2e34437d85126758720d7f9. Subsequent calls will reuse this data.\n"
144
- ]
145
- }
146
- ],
147
- "source": [
148
- "from datasets import load_dataset\n",
149
- "datasets = load_dataset('json', data_files={'train': train, 'validation': val})"
150
- ]
151
- },
152
- {
153
- "cell_type": "code",
154
- "execution_count": 73,
155
- "id": "4a6d6009-00e7-4b30-b577-6805dd849b8a",
156
- "metadata": {},
157
- "outputs": [
158
- {
159
- "name": "stdout",
160
- "output_type": "stream",
161
- "text": [
162
- "Num examples = 21153916\n"
163
- ]
164
- }
165
- ],
166
- "source": [
167
- "print(f\"Num examples = {len(datasets['train'])}\")"
168
- ]
169
- },
170
- {
171
- "cell_type": "code",
172
- "execution_count": 74,
173
- "id": "c6186d88-4296-4d1d-b7cd-d0196f0b0f97",
174
- "metadata": {},
175
- "outputs": [],
176
- "source": [
177
- "from transformers import (\n",
178
- " CONFIG_MAPPING,\n",
179
- " FLAX_MODEL_FOR_MASKED_LM_MAPPING,\n",
180
- " BatchEncoding,\n",
181
- " FlaxT5ForConditionalGeneration,\n",
182
- " T5ForConditionalGeneration,\n",
183
- " HfArgumentParser,\n",
184
- " PreTrainedTokenizerBase,\n",
185
- " T5Config,\n",
186
- " T5TokenizerFast,\n",
187
- " TrainingArguments,\n",
188
- " is_tensorboard_available,\n",
189
- " set_seed,\n",
190
- ")"
191
- ]
192
- },
193
- {
194
- "cell_type": "code",
195
- "execution_count": 75,
196
- "id": "10d90997-6eb6-4399-b1a7-8a858ae4738c",
197
- "metadata": {},
198
- "outputs": [
199
- {
200
- "name": "stderr",
201
- "output_type": "stream",
202
- "text": [
203
- "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
204
- ]
205
- },
206
- {
207
- "name": "stdout",
208
- "output_type": "stream",
209
- "text": [
210
- "Start tokenization, remove_column_names = ['url', 'timestamp', 'text']\n"
211
- ]
212
- }
213
- ],
214
- "source": [
215
- "tokenizer = T5TokenizerFast.from_pretrained(\"./t5-base-dutch\")\n",
216
- "\n",
217
- "def tokenize_function(examples):\n",
218
- " return tokenizer(examples['text'], return_attention_mask=False)\n",
219
- "\n",
220
- "column_names = datasets[\"train\"].column_names\n",
221
- "print(f\"Start tokenization, remove_column_names = {column_names}\")\n",
222
- "\n",
223
- "tokenized_datasets = datasets.map(\n",
224
- " tokenize_function,\n",
225
- " batched=True,\n",
226
- " num_proc=96,\n",
227
- " remove_columns=column_names,\n",
228
- " load_from_cache_file=True,\n",
229
- ")\n",
230
- "\n"
231
- ]
232
- },
233
- {
234
- "cell_type": "code",
235
- "execution_count": 76,
236
- "id": "de7983e1-775d-4ee3-bf66-681f731501fb",
237
- "metadata": {},
238
- "outputs": [
239
- {
240
- "data": {
241
- "text/plain": [
242
- "21153916"
243
- ]
244
- },
245
- "execution_count": 76,
246
- "metadata": {},
247
- "output_type": "execute_result"
248
- }
249
- ],
250
- "source": [
251
- "len(tokenized_datasets[\"train\"])"
252
- ]
253
- },
254
- {
255
- "cell_type": "code",
256
- "execution_count": 77,
257
- "id": "5721ad35-8373-4999-8ac5-02c6f759373f",
258
- "metadata": {},
259
- "outputs": [
260
- {
261
- "name": "stdout",
262
- "output_type": "stream",
263
- "text": [
264
- "Expanded_inputs_length: 141, targets_length: 29\n",
265
- "Start group_texts\n"
266
- ]
267
- },
268
- {
269
- "name": "stderr",
270
- "output_type": "stream",
271
- "text": [
272
- "https://symbolize.stripped_domain/r/?trace=https://symbolize.stripped_domain/r/?trace=503811,5cca55,7fe2dabc120f,7fe2dabc120f,90641f90b85f&map=&map= \n",
273
- " \n",
274
- "*** SIGTERM received by PID 47670 (TID 47670) on cpu 70 from PID 33223; stack trace: ***\n",
275
- "*** SIGTERM received by PID 47686 (TID 47686) on cpu 71 from PID 33223; stack trace: ***\n",
276
- "https://symbolize.stripped_domain/r/?trace=56a4e1,7fe2dabc120f&map= \n",
277
- "https://symbolize.stripped_domain/r/?trace=*** SIGTERM received by PID 47673 (TID 47673) on cpu 16 from PID 33223; stack trace: ***\n",
278
- "56a682,7fe2dabc120f,7fdfb4cf751f,90b3ff&map= \n",
279
- "*** SIGTERM received by PID 47665 (TID 47665) on cpu 67 from PID 33223; stack trace: ***\n",
280
- "PC: @ 0x503811 (unknown) (unknown)\n",
281
- "PC: @ 0x56a4e1 (unknown) _PyEval_EvalFrameDefault\n",
282
- "PC: @ 0x5cca55 (unknown) (unknown)\n",
283
- " @ 0x7fde2703b800 976 (unknown)\n",
284
- " @ 0x7fde2703b800 976 (unknown)\n",
285
- " @ 0x7fe2dabc1210 (unknown) (unknown)\n",
286
- " @ ... and at least 1 more frames\n",
287
- "https://symbolize.stripped_domain/r/?trace= @ 0x7fe2dabc1210 852927808 (unknown)\n",
288
- "56a4e1,7fde2703b7ff,7fe2dabc120f&map=2a762cd764e70bc90ae4c7f9747c08d7:7fde1a0f9000-7fde2737a280 \n",
289
- "E0710 11:59:41.025238 47673 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.\n",
290
- " @ 0x7fde2703b800 976 (unknown)\n",
291
- " @ 0x7fe2dabc1210 850855568 (unknown)\n",
292
- " @ 0x90b860 (unknown) (unknown)\n",
293
- "https://symbolize.stripped_domain/r/?trace=5cca55,7fde2703b7ff,7fe2dabc120f,90b85f&map=2a762cd764e70bc90ae4c7f9747c08d7:7fde1a0f9000-7fde2737a280 \n",
294
- "E0710 11:59:41.030755 47686 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.\n",
295
- " @ 0x906420 (unknown) (unknown)\n",
296
- "https://symbolize.stripped_domain/r/?trace=503811,7fde2703b7ff,7fe2dabc120f,90641f&map=2a762cd764e70bc90ae4c7f9747c08d7:7fde1a0f9000-7fde2737a280 \n",
297
- "E0710 11:59:41.033184 47670 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.\n",
298
- "E0710 11:59:41.033730 47673 process_state.cc:771] RAW: Raising signal 15 with default behavior\n",
299
- "PC: @ 0x56a682 (unknown) _PyEval_EvalFrameDefault\n",
300
- " @ 0x7fde2703b800 976 (unknown)\n",
301
- " @ 0x7fe2dabc1210 (unknown) (unknown)\n",
302
- " @ 0x7fdfb4cf7520 (unknown) (unknown)\n",
303
- "E0710 11:59:41.057700 47670 process_state.cc:771] RAW: Raising signal 15 with default behavior\n",
304
- "E0710 11:59:41.063730 47686 process_state.cc:771] RAW: Raising signal 15 with default behavior\n",
305
- " @ 0x90b400 (unknown) (unknown)\n",
306
- "https://symbolize.stripped_domain/r/?trace=56a682,7fde2703b7ff,7fe2dabc120f,7fdfb4cf751f,90b3ff&map=2a762cd764e70bc90ae4c7f9747c08d7:7fde1a0f9000-7fde2737a280 \n",
307
- "E0710 11:59:41.064237 47665 coredump_hook.cc:250] RAW: Remote crash gathering disabled for SIGTERM.\n",
308
- "E0710 11:59:41.091833 47665 process_state.cc:771] RAW: Raising signal 15 with default behavior\n"
309
- ]
310
- }
311
- ],
312
- "source": [
313
- "def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):\n",
314
- " \"\"\"This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .\n",
315
- "\n",
316
- " Training parameters to avoid padding with random_spans_noise_mask.\n",
317
- " When training a model with random_spans_noise_mask, we would like to set the other\n",
318
- " training hyperparmeters in a way that avoids padding.\n",
319
- " This function helps us compute these hyperparameters.\n",
320
- " We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,\n",
321
- " and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.\n",
322
- " This function tells us the required number of tokens in the raw example (for split_tokens())\n",
323
- " as well as the length of the encoded targets. Note that this function assumes\n",
324
- " the inputs and targets will have EOS appended and includes that in the reported length.\n",
325
- "\n",
326
- " Args:\n",
327
- " inputs_length: an integer - desired length of the tokenized inputs sequence\n",
328
- " noise_density: a float\n",
329
- " mean_noise_span_length: a float\n",
330
- " Returns:\n",
331
- " tokens_length: length of original text in tokens\n",
332
- " targets_length: an integer - length in tokens of encoded targets sequence\n",
333
- " \"\"\"\n",
334
- "\n",
335
- " def _tokens_length_to_inputs_length_targets_length(tokens_length):\n",
336
- " num_noise_tokens = int(round(tokens_length * noise_density))\n",
337
- " num_nonnoise_tokens = tokens_length - num_noise_tokens\n",
338
- " num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))\n",
339
- " # inputs contain all nonnoise tokens, sentinels for all noise spans\n",
340
- " # and one EOS token.\n",
341
- " _input_length = num_nonnoise_tokens + num_noise_spans + 1\n",
342
- " _output_length = num_noise_tokens + num_noise_spans + 1\n",
343
- " return _input_length, _output_length\n",
344
- "\n",
345
- " tokens_length = inputs_length\n",
346
- "\n",
347
- " while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:\n",
348
- " tokens_length += 1\n",
349
- "\n",
350
- " inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)\n",
351
- "\n",
352
- " # minor hack to get the targets length to be equal to inputs length\n",
353
- " # which is more likely to have been set to a nice round number.\n",
354
- " if noise_density == 0.5 and targets_length > inputs_length:\n",
355
- " tokens_length -= 1\n",
356
- " targets_length -= 1\n",
357
- " return tokens_length, targets_length\n",
358
- "\n",
359
- "# T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.\n",
360
- "# To ensure that the input length is `max_seq_length`, we need to increase the maximum length\n",
361
- "# according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.\n",
362
- "expanded_inputs_length, targets_length = compute_input_and_target_lengths(\n",
363
- " inputs_length=128,\n",
364
- " noise_density=0.15,\n",
365
- " mean_noise_span_length=3.0,\n",
366
- ")\n",
367
- "\n",
368
- "print(f\"Expanded_inputs_length: {expanded_inputs_length}, targets_length: {targets_length}\")\n",
369
- "print(f\"Start group_texts\")\n",
370
- "\n",
371
- "# Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.\n",
372
- "def group_texts(examples):\n",
373
- " # Concatenate all texts.\n",
374
- " concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n",
375
- " total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
376
- " # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n",
377
- " # customize this part to your needs.\n",
378
- " if total_length >= expanded_inputs_length:\n",
379
- " total_length = (total_length // expanded_inputs_length) * expanded_inputs_length\n",
380
- " # Split by chunks of max_len.\n",
381
- " result = {\n",
382
- " k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]\n",
383
- " for k, t in concatenated_examples.items()\n",
384
- " }\n",
385
- " return result\n",
386
- "\n",
387
- "# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a\n",
388
- "# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value\n",
389
- "# might be slower to preprocess.\n",
390
- "#\n",
391
- "# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:\n",
392
- "# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map\n",
393
- "grouped_datasets = tokenized_datasets.map(\n",
394
- " group_texts,\n",
395
- " batched=True,\n",
396
- " batch_size=200,\n",
397
- " num_proc=96,\n",
398
- " load_from_cache_file=True,\n",
399
- ")\n"
400
- ]
401
- },
402
- {
403
- "cell_type": "code",
404
- "execution_count": 78,
405
- "id": "f37e7559-fcc1-436b-a4ee-45adb856869e",
406
- "metadata": {},
407
- "outputs": [
408
- {
409
- "data": {
410
- "text/plain": [
411
- "86438919"
412
- ]
413
- },
414
- "execution_count": 78,
415
- "metadata": {},
416
- "output_type": "execute_result"
417
- }
418
- ],
419
- "source": [
420
- "examples = len(grouped_datasets[\"train\"])\n",
421
- "examples"
422
- ]
423
- },
424
- {
425
- "cell_type": "code",
426
- "execution_count": 79,
427
- "id": "21aac2aa-9dc2-4b7a-8c46-62cfa47f18a7",
428
- "metadata": {},
429
- "outputs": [],
430
- "source": [
431
- "it = iter(grouped_datasets[\"train\"])"
432
- ]
433
- },
434
- {
435
- "cell_type": "code",
436
- "execution_count": 80,
437
- "id": "011a6a07-5fe0-441a-b032-79cf8664b5c5",
438
- "metadata": {},
439
- "outputs": [
440
- {
441
- "name": "stdout",
442
- "output_type": "stream",
443
- "text": [
444
- "{'input_ids': [256, 3, 20, 18452, 6690, 7757, 1286, 43, 10, 4942, 1286, 80, 12, 4782, 5442, 39, 5385, 33, 4, 5, 3, 2924, 117, 5669, 228, 21, 193, 9030, 511, 24, 11, 5, 665, 165, 4218, 7, 26, 264, 1528, 35, 105, 3, 19653, 12, 9661, 17156, 13955, 4, 132, 5, 611, 959, 961, 146, 6522, 7757, 1286, 89, 7500, 9716, 11, 5, 4868, 107, 13604, 12, 12836, 13368, 11, 611, 959, 4, 3, 69, 99, 12, 13132, 6690, 590, 5, 1803, 1867, 69, 7, 924, 10, 1762, 4, 3, 69, 538, 489, 14, 1149, 16, 3, 11384, 199, 116, 399, 4782, 291, 3, 6, 237, 13, 2629, 3, 8987, 291, 4, 69, 5, 3, 27, 72, 20, 325, 3, 2924, 133, 21, 105, 9030, 10, 1149, 242, 16, 144, 13572, 11, 9, 13401, 20, 7951, 8, 165, 4218, 4, 5, 1910]}\n"
445
- ]
446
- }
447
- ],
448
- "source": [
449
- "print(next(it))"
450
- ]
451
- },
452
- {
453
- "cell_type": "code",
454
- "execution_count": 81,
455
- "id": "f20d3da2-0132-4ecc-b9b9-c2b5ec06f031",
456
- "metadata": {},
457
- "outputs": [],
458
- "source": [
459
- "tokens = next(it)['input_ids']\n"
460
- ]
461
- },
462
- {
463
- "cell_type": "code",
464
- "execution_count": 82,
465
- "id": "2bad87cd-06e1-4c52-b2d6-d61fcb96e35d",
466
- "metadata": {},
467
- "outputs": [
468
- {
469
- "data": {
470
- "text/plain": [
471
- "141"
472
- ]
473
- },
474
- "execution_count": 82,
475
- "metadata": {},
476
- "output_type": "execute_result"
477
- }
478
- ],
479
- "source": [
480
- "len(tokens)"
481
- ]
482
- },
483
- {
484
- "cell_type": "code",
485
- "execution_count": 83,
486
- "id": "4e0f573a-0abc-4f8f-b59a-a281fb306425",
487
- "metadata": {},
488
- "outputs": [
489
- {
490
- "data": {
491
- "text/plain": [
492
- "\"werden volgens getuigen vergezeld door een boomlange bodyguard. ook hing er een gordijntje om de tafel, zodat beyoncé in alle rust van de show kon genieten. volgens de bron verliet knowles pas om 03.30 uur's ochtends de hippe club.</s> utrecht - in de schouwburg van utrecht gaat vrijdagavond de musical 'joseph and the amazing technicolor dreamcoat' in première. voor het eerst in nederland. een voorloper van het geesteskind van andrew lloyd webber werd al in 1967 voor het eerst op een school in groot-brittannië uitgeprobeerd. twaalf jaar later werd het in\""
493
- ]
494
- },
495
- "execution_count": 83,
496
- "metadata": {},
497
- "output_type": "execute_result"
498
- }
499
- ],
500
- "source": [
501
- "tokenizer.decode(tokens)"
502
- ]
503
- },
504
- {
505
- "cell_type": "code",
506
- "execution_count": 84,
507
- "id": "ab853c1b-0e0f-4ae8-b1cb-053f76a7d9d7",
508
- "metadata": {},
509
- "outputs": [
510
- {
511
- "ename": "KeyboardInterrupt",
512
- "evalue": "",
513
- "output_type": "error",
514
- "traceback": [
515
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
516
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
517
- "\u001b[0;32m/tmp/ipykernel_33223/1050159500.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mexample\u001b[0m \u001b[0;34m:=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexample\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'input_ids'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m141\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexample\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
518
- "\u001b[0;32m~/datasets/src/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1264\u001b[0m \u001b[0moutput_all_columns\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_output_all_columns\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1265\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_rows\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1266\u001b[0;31m yield self._getitem(\n\u001b[0m\u001b[1;32m 1267\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1268\u001b[0m \u001b[0mformat_type\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mformat_type\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
519
- "\u001b[0;32m~/datasets/src/datasets/arrow_dataset.py\u001b[0m in \u001b[0;36m_getitem\u001b[0;34m(self, key, format_type, format_columns, output_all_columns, format_kwargs)\u001b[0m\n\u001b[1;32m 1507\u001b[0m \u001b[0mformat_kwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mformat_kwargs\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mformat_kwargs\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1508\u001b[0m \u001b[0mformatter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_formatter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mformat_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mformat_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1509\u001b[0;31m \u001b[0mpa_subtable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mquery_table\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_indices\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_indices\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1510\u001b[0m formatted_output = format_table(\n\u001b[1;32m 1511\u001b[0m \u001b[0mpa_subtable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformatter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mformatter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat_columns\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mformat_columns\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_all_columns\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_all_columns\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
520
- "\u001b[0;32m~/datasets/src/datasets/formatting/formatting.py\u001b[0m in \u001b[0;36mquery_table\u001b[0;34m(table, key, indices)\u001b[0m\n\u001b[1;32m 369\u001b[0m \u001b[0;31m# Query the main table\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 370\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mindices\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 371\u001b[0;31m \u001b[0mpa_subtable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_query_table\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 372\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 373\u001b[0m \u001b[0mpa_subtable\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_query_table_with_indices_mapping\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
521
- "\u001b[0;32m~/datasets/src/datasets/formatting/formatting.py\u001b[0m in \u001b[0;36m_query_table\u001b[0;34m(table, key)\u001b[0m\n\u001b[1;32m 77\u001b[0m \"\"\"\n\u001b[1;32m 78\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 79\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtable\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfast_slice\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mtable\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_rows\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 80\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mslice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtable\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_rows\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
522
- "\u001b[0;32m~/datasets/src/datasets/table.py\u001b[0m in \u001b[0;36mfast_slice\u001b[0;34m(self, offset, length)\u001b[0m\n\u001b[1;32m 127\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0moffset\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_offsets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlength\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mlength\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 128\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mpa\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTable\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_batches\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mschema\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_schema\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 129\u001b[0;31m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_interpolation_search\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_offsets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moffset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 130\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlength\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mlength\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0moffset\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_offsets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0mbatches\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_batches\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
523
- "\u001b[0;32m~/datasets/src/datasets/table.py\u001b[0m in \u001b[0;36m_interpolation_search\u001b[0;34m(arr, x)\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mj\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0marr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0marr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0mk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0marr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m//\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0marr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0marr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 86\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0marr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0marr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 87\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0marr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
524
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
525
- ]
526
- }
527
- ],
528
- "source": [
529
- "while (example := next(it, None)) is not None:\n",
530
- " if len(example['input_ids']) == 141:\n",
531
- " continue\n",
532
- " else:\n",
533
- " print(example)\n",
534
- " break"
535
- ]
536
- },
537
- {
538
- "cell_type": "code",
539
- "execution_count": null,
540
- "id": "f71a0f6b-3b60-4dd5-a9af-0ef43aadc6a1",
541
- "metadata": {},
542
- "outputs": [],
543
- "source": []
544
- }
545
- ],
546
- "metadata": {
547
- "kernelspec": {
548
- "display_name": "Python 3 (ipykernel)",
549
- "language": "python",
550
- "name": "python3"
551
- },
552
- "language_info": {
553
- "codemirror_mode": {
554
- "name": "ipython",
555
- "version": 3
556
- },
557
- "file_extension": ".py",
558
- "mimetype": "text/x-python",
559
- "name": "python",
560
- "nbconvert_exporter": "python",
561
- "pygments_lexer": "ipython3",
562
- "version": "3.8.10"
563
- }
564
- },
565
- "nbformat": 4,
566
- "nbformat_minor": 5
567
- }
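The chunking step in the notebook above concatenates tokenized documents and cuts them into blocks of `expanded_inputs_length` (141 for an input length of 128 with noise density 0.15 and mean span length 3.0, per the printed output). A toy sketch of that behaviour, using placeholder token lists:

```python
# Toy illustration of the group_texts() chunking from the notebook above:
# concatenate all tokenized examples, then cut into blocks of 141 tokens,
# dropping the remainder. The input lists here are placeholders.
expanded_inputs_length = 141

def group_texts(examples):
    concatenated = {k: sum(examples[k], []) for k in examples}
    total_length = len(concatenated["input_ids"])
    total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
    return {
        k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
        for k, t in concatenated.items()
    }

batch = {"input_ids": [list(range(100)), list(range(100, 300))]}  # two fake documents
print([len(chunk) for chunk in group_texts(batch)["input_ids"]])  # -> [141, 141]
```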
README.md CHANGED
@@ -3,21 +3,34 @@ language:
3
  - dutch
4
  tags:
5
  - seq2seq
6
- - text-generation
7
  datasets:
8
- - mc4
 
 
9
  ---
10
 
11
  # t5-base-dutch
12
 
13
- Created by [Yeb Havinga](https://www.linkedin.com/in/yeb-havinga-86530825/) & [Dat Nguyen](https://www.linkedin.com/in/dat-nguyen-49a641138/) during the [Hugging Face community week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104), organized by [HuggingFace](https://huggingface.co/) and TPU usage sponsored by Google, for the project [Pre-train T5 from scratch in Dutch](https://discuss.huggingface.co/t/pretrain-t5-from-scratch-in-dutch/8109).
 
14
 
15
- See also the fine-tuned [t5-base-dutch-demo](https://huggingface.co/flax-community/t5-base-dutch-demo) model, and the demo application **[Netherformer 📰](https://huggingface.co/spaces/flax-community/netherformer)**, which are based on this model.
16
 
17
  ## Dataset
18
 
19
- This model was trained on a cleaned version of the Dutch part of [mC4](https://huggingface.co/datasets/mc4).
20
- See the `clean` directory for the cleaning script.
21
 
22
  * Documents that contained words from a selection of the Dutch and English [List of Dirty Naughty Obscene and Otherwise Bad Words](https://github.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words) are removed
23
  * Sentences with fewer than 3 words are removed
@@ -26,13 +39,18 @@ See the `clean` directory for the clean script.
26
  * Documents with "javascript", "lorum ipsum", "terms of use", "privacy policy", "cookie policy", "uses cookies",
27
  "use of cookies", "use cookies", "elementen ontbreken", "deze printversie" are removed.
28
 
29
  ## Training
30
 
31
- Training of the model was resumed from an earlier checkpoint several times, as can be seen in the training metrics tab (switch to wall time for a better view).
 
 
32
 
33
- After several hours of training, an error would be raised that we were not able to identify and solve. As a workaround,
34
- the first few resumes started again at step 0 with a differently seeded reshuffling of the data.
35
- In the last two resumes the random seed was fixed and training resumed at the previous step; a try/except around the failing example allowed training to continue when an error was caused by a single example.
36
 
37
- The final model was trained for 63,000 steps with a batch size of 128, ending with an evaluation loss of 1.79 and an accuracy of 0.64.
38
- A triangular learning rate schedule was used, with a peak learning rate of 0.01 for the first few runs and 0.001 for the last two runs.
 
3
  - dutch
4
  tags:
5
  - seq2seq
6
+ - lm-head
7
  datasets:
8
+ - yhavinga/mc4_nl_cleaned
9
+ license: apache-2.0
10
+ inference: false
11
  ---
12
 
13
  # t5-base-dutch
14
 
15
+ Created by [Yeb Havinga](https://www.linkedin.com/in/yeb-havinga-86530825/)
16
+ & [Dat Nguyen](https://www.linkedin.com/in/dat-nguyen-49a641138/) during the [Hugging Face community week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104), organized by [HuggingFace](https://huggingface.co/) and TPU usage sponsored by Google, for the project [Pre-train T5 from scratch in Dutch](https://discuss.huggingface.co/t/pretrain-t5-from-scratch-in-dutch/8109).
17
 
18
+ See also the fine-tuned [t5-base-dutch-demo](https://huggingface.co/flax-community/t5-base-dutch-demo) model,
19
+ and the demo application **[Netherformer 📰](https://huggingface.co/spaces/flax-community/netherformer)**,
20
+ which are based on this model.
21
+
22
+ **5 Jan 2022: Model updated. Evaluation accuracy increased from 0.64 to 0.70.**
23
+
24
+ ## Model
25
+
26
+ * Configuration based on `google/t5-base`
27
+ * 12 layers, 12 heads
28
+ * Dropout set to 0.1
29
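The bullet points above correspond to a configuration along the lines of the following sketch (illustrative only; `vocab_size` is taken from the repository's `config.json`, the remaining values follow `google/t5-base` defaults):

```python
# Illustrative sketch of the model configuration described above.
# vocab_size matches config.json in this repository; dropout_rate is 0.1.
from transformers import T5Config, FlaxT5ForConditionalGeneration

config = T5Config.from_pretrained("google/t5-base", dropout_rate=0.1, vocab_size=32103)
model = FlaxT5ForConditionalGeneration(config, seed=0)  # randomly initialised weights
print(config.num_layers, config.num_heads)              # -> 12 12
```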
 
30
  ## Dataset
31
 
32
+ This model was trained on the `full` configuration of [cleaned Dutch mC4](https://huggingface.co/datasets/mc4_nl_cleaned),
33
+ which is the original mC4, except that:
34
 
35
  * Documents that contained words from a selection of the Dutch and English [List of Dirty Naughty Obscene and Otherwise Bad Words](https://github.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words) are removed
36
  * Sentences with fewer than 3 words are removed
 
39
  * Documents with "javascript", "lorum ipsum", "terms of use", "privacy policy", "cookie policy", "uses cookies",
40
  "use of cookies", "use cookies", "elementen ontbreken", "deze printversie" are removed.
41
 
42
+ ## Tokenization
43
+
44
+ A SentencePiece tokenizer was trained from scratch on this dataset.
45
+ The `full` configuration contains 34B tokens in total.
46
+
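As an illustration only, training a SentencePiece unigram model could look like the sketch below; the input file is a placeholder, the vocabulary size of 32,000 is taken from `run_t5.sh` later in this commit, and this is not necessarily the exact tool invocation that was used:

```python
# Illustrative sketch: train a SentencePiece unigram tokenizer from scratch.
# "dutch_sentences.txt" is a placeholder corpus; vocab_size matches the
# VOCAB_SIZE exported in run_t5.sh below.
import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="dutch_sentences.txt",
    model_prefix="t5-base-dutch-sp",
    vocab_size=32000,
    model_type="unigram",
)
sp = spm.SentencePieceProcessor(model_file="t5-base-dutch-sp.model")
print(sp.encode("Hoe gaat het?", out_type=str))
```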
47
  ## Training
48
 
49
+ The model was trained for one epoch on the `full` mc4_nl_cleaned dataset configuration, consisting of 34B tokens.
50
+ This amounted to 528,482 steps with a batch size of 128 and took 57 hours.
51
+ A triangular learning rate schedule was used, with a peak learning rate of 0.005.
52
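A triangular schedule can be sketched with optax as below (linear warm-up to the peak, then linear decay); the warm-up step count is a placeholder, while the peak rate, total step count and the Adafactor optimizer come from this commit:

```python
# Sketch of a triangular learning-rate schedule in optax: linear ramp up to the
# peak, then linear decay to zero. warmup_steps is a placeholder value.
import optax

total_steps = 528_482
warmup_steps = 10_000   # placeholder, not taken from the actual run
peak_lr = 0.005

lr_schedule = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, peak_lr, transition_steps=warmup_steps),
        optax.linear_schedule(peak_lr, 0.0, transition_steps=total_steps - warmup_steps),
    ],
    boundaries=[warmup_steps],
)
optimizer = optax.adafactor(learning_rate=lr_schedule)  # adafactor, as in run_t5.sh
```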
 
53
+ ## Evaluation
 
 
54
 
55
+ * Loss: 1.38
56
+ * Accuracy: 0.70
config.json CHANGED
@@ -52,7 +52,7 @@
52
  }
53
  },
54
  "torch_dtype": "float32",
55
- "transformers_version": "4.9.0.dev0",
56
  "use_cache": true,
57
  "vocab_size": 32103
58
  }
 
52
  }
53
  },
54
  "torch_dtype": "float32",
55
+ "transformers_version": "4.13.0",
56
  "use_cache": true,
57
  "vocab_size": 32103
58
  }
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7530cff462d75db600d085d83bcc77ac48dde95d396cf714cf51f786dcddd7eb
3
  size 891548548
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be5973ac1f68ec3c5ceb47e10ed848b83ad06e69affa938fc400e3ef368143ea
3
  size 891548548
flax_to_pt.py CHANGED
@@ -1,6 +1,26 @@
1
- from transformers import T5ForConditionalGeneration, TFT5ForConditionalGeneration
2
-
3
- pt_model = T5ForConditionalGeneration.from_pretrained(".", from_flax=True)
4
- pt_model.save_pretrained(".")
5
- tf_model = TFT5ForConditionalGeneration.from_pretrained(".", from_pt=True)
6
- tf_model.save_pretrained(".")
1
+ import torch
2
+ import numpy as np
3
+ import jax.numpy as jnp
4
+ from transformers import AutoTokenizer
5
+ from transformers import FlaxT5ForConditionalGeneration
6
+ from transformers import T5ForConditionalGeneration
7
+ tokenizer = AutoTokenizer.from_pretrained(".")
8
+ model_fx = FlaxT5ForConditionalGeneration.from_pretrained(".")
9
+ model_pt = T5ForConditionalGeneration.from_pretrained(".", from_flax=True)
10
+ model_pt.save_pretrained("./")
11
+ text = "Hoe gaat het?"
12
+ e_input_ids_fx = tokenizer(text, return_tensors="np", padding=True, max_length=128, truncation=True)
13
+ d_input_ids_fx = jnp.ones((e_input_ids_fx.input_ids.shape[0], 1), dtype="i4") * model_fx.config.decoder_start_token_id
14
+ e_input_ids_pt = tokenizer(text, return_tensors="pt", padding=True, max_length=128, truncation=True)
15
+ d_input_ids_pt = np.ones((e_input_ids_pt.input_ids.shape[0], 1), dtype="i4") * model_pt.config.decoder_start_token_id
16
+ print(e_input_ids_fx)
17
+ print(d_input_ids_fx)
18
+ print()
19
+ encoder_pt = model_fx.encode(**e_input_ids_pt)
20
+ decoder_pt = model_fx.decode(d_input_ids_pt, encoder_pt)
21
+ logits_pt = decoder_pt.logits
22
+ print(logits_pt)
23
+ encoder_fx = model_fx.encode(**e_input_ids_fx)
24
+ decoder_fx = model_fx.decode(d_input_ids_fx, encoder_fx)
25
+ logits_fx = decoder_fx.logits
26
+ print(logits_fx)
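As a follow-up to the script above, one way to sanity-check the Flax-to-PyTorch conversion is to compare logits from a full forward pass of both models on the same input; this sketch continues from the variables defined in `flax_to_pt.py` and is illustrative, not part of the commit:

```python
# Illustrative check (not part of the committed script): compare logits of the
# Flax model and the converted PyTorch model on the same encoder/decoder input.
import numpy as np
import torch

pt_out = model_pt(
    input_ids=e_input_ids_pt.input_ids,
    decoder_input_ids=torch.tensor(np.asarray(d_input_ids_pt)).long(),
)
fx_out = model_fx(
    input_ids=e_input_ids_fx.input_ids,
    decoder_input_ids=d_input_ids_fx,
)
print(np.allclose(np.asarray(fx_out.logits), pt_out.logits.detach().numpy(), atol=1e-3))
```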
opt_state.msgpack DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ffae8bd1730e35ebeb0619a7d1b75dab07addff2320d2394eb1af891820ca64f
3
- size 1985609
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d1b04c56abcc3a5bd4d7e871c7d017f44ab5b75af1c4adcc30c205da5fc5ede1
3
  size 891650495
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f102fac4815a8b1b29916b196bfe88a0e5fef76083c6007a5c7966a7fcb9b2d6
3
  size 891650495
run_t5.sh CHANGED
@@ -1,79 +1,37 @@
1
- MODEL="t5-base-dutch"
2
-
3
- MODEL_DIR="${HOME}/${MODEL}"
4
-
5
- mkdir -p "${MODEL_DIR}/runs"
6
-
7
- # T5 paper lr 0.01 with batch size 128
8
- # We have a batch size of 8 devices * 32 = 256, so lr = 0.01/2
9
-
10
- #SEED=9200
11
- #
12
- #./run_t5_mlm_flax_custom_dataset.py \
13
- # --output_dir="${MODEL_DIR}" \
14
- # --model_type="t5" \
15
- # --config_name="flax-community/${MODEL}" \
16
- # --tokenizer_name="${MODEL_DIR}" \
17
- # --seed="${SEED}" \
18
- # --preprocessing_num_workers="96" \
19
- # --do_train --do_eval \
20
- # --adafactor \
21
- # --max_seq_length="512" \
22
- # --per_device_train_batch_size="32" \
23
- # --per_device_eval_batch_size="32" \
24
- # --dtype="bfloat16" \
25
- # --learning_rate="5e-3" \
26
- # --overwrite_output_dir \
27
- # --num_train_epochs="3" \
28
- # --logging_steps="50" \
29
- # --save_steps="100" \
30
- # --eval_steps="5000" \
31
- # --warmup_steps="3413"
32
- #exit
33
-
34
- while true; do
35
-
36
- # Set the seed to random before each run, so date shuffling per epoch is different each run.
37
- # This kills reproducibility, but is required as long as during training ValueError can be raised.
38
- # SEED=$RANDOM
39
- SEED=22384
40
-
41
- ./run_t5_mlm_flax_custom_dataset.py \
42
- --output_dir="${MODEL_DIR}" \
43
- --model_type="t5" \
44
- --config_name="flax-community/${MODEL}" \
45
- --tokenizer_name="${MODEL_DIR}" \
46
- --seed="${SEED}" \
47
- --preprocessing_num_workers="96" \
48
- --do_train --do_eval \
49
- --adafactor \
50
- --max_seq_length="512" \
51
- --per_device_train_batch_size="16" \
52
- --per_device_eval_batch_size="16" \
53
- --dtype="bfloat16" \
54
- --learning_rate="1e-3" \
55
- --overwrite_output_dir \
56
- --num_train_epochs="1" \
57
- --logging_steps="50" \
58
- --save_steps="500" \
59
- --eval_steps="5000" \
60
- --resume_from_checkpoint="${MODEL_DIR}" \
61
- --warmup_steps="6519"
62
-
63
- # \
64
- # --push_to_hub
65
-
66
- echo "RESTARTING"
67
- sleep 20
68
- done
69
- #
70
- # \
71
-
72
-
73
- #git add pytorch_model.bin
74
- #git commit -m "Update pytorch model after training"
75
- #git push origin main
76
-
77
- # --gradient_accumulation_steps="2" \
78
-
79
- # --resume_from_checkpoint="${MODEL_DIR}/ckpt-18000" \
 
1
+ #!/bin/bash
2
+
3
+ export HF_PROJECT="t5-base-dutch"
4
+
5
+ # Variables for training the tokenizer and creating the config
6
+ export VOCAB_SIZE="32000"
7
+ export N_INPUT_SENTENCES="1000000" # Num of sentences to train the tokenizer
8
+ export DATASET="yhavinga/mc4_nl_cleaned" # Name of the dataset in the Huggingface Hub
9
+ export DATASET_CONFIG="full" # Config of the dataset in the Huggingface Hub
10
+ export DATASET_SPLIT="train" # Split to use for training tokenizer and model
11
+ export TEXT_FIELD="text" # Field containing the text to be used for training
12
+ export CONFIG_TYPE="t5-base" # Config that our model will use
13
+ export MODEL_PATH="${HOME}/data/${HF_PROJECT}" # Path to the model, e.g. here inside the mount
14
+
15
+ python run_t5_mlm_flax.py \
16
+ --output_dir="${MODEL_PATH}" \
17
+ --model_type="t5" \
18
+ --config_name="${MODEL_PATH}" \
19
+ --tokenizer_name="${MODEL_PATH}" \
20
+ --preprocessing_num_workers="96" \
21
+ --do_train --do_eval \
22
+ --dataset_name="${DATASET}" \
23
+ --dataset_config_name="${DATASET_CONFIG}" \
24
+ --max_seq_length="512" \
25
+ --per_device_train_batch_size="16" \
26
+ --per_device_eval_batch_size="16" \
27
+ --adafactor \
28
+ --learning_rate="0.005" \
29
+ --overwrite_output_dir \
30
+ --num_train_epochs="1" \
31
+ --logging_steps="500" \
32
+ --save_steps="80000" \
33
+ --eval_steps="2500" \
34
+ --weight_decay="0.01" \
35
+ --warmup_steps="10000" \
36
+ --validation_split_count="15000" \
37
+ --push_to_hub
 
run_t5_mlm_flax_custom_dataset.py → run_t5_mlm_flax.py RENAMED
@@ -18,6 +18,8 @@ Pretraining the library models for T5-like span-masked language modeling on a te
18
 
19
  Here is the full list of checkpoints on the hub that can be pretrained by this script:
20
  https://huggingface.co/models?filter=t5
 
 
21
  """
22
  # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
23
  import logging
@@ -25,13 +27,13 @@ import os
25
  import sys
26
  import time
27
  import json
28
- import shutil
29
  from dataclasses import dataclass, field
 
30
  from pathlib import Path
31
  from typing import Dict, List, Optional
32
 
33
  import numpy as np
34
- from datasets import load_dataset, DatasetDict
35
  from tqdm import tqdm
36
 
37
  import flax
@@ -39,34 +41,31 @@ import jax
39
  import jax.numpy as jnp
40
  import optax
41
  from flax import jax_utils, traverse_util
 
42
  from flax.training import train_state
43
  from flax.training.common_utils import get_metrics, onehot, shard
44
- from flax.serialization import to_bytes, from_bytes
45
  from transformers import (
46
  CONFIG_MAPPING,
47
  FLAX_MODEL_FOR_MASKED_LM_MAPPING,
 
48
  BatchEncoding,
49
  FlaxT5ForConditionalGeneration,
50
- T5ForConditionalGeneration,
51
  HfArgumentParser,
52
  PreTrainedTokenizerBase,
53
  T5Config,
54
- T5TokenizerFast,
55
  TrainingArguments,
56
  is_tensorboard_available,
57
  set_seed,
58
  )
 
59
  from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
60
 
61
  logger = logging.getLogger(__name__)
62
 
63
-
64
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
65
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
66
 
67
- data_files = []
68
-
69
-
70
  @dataclass
71
  class ModelArguments:
72
  """
@@ -103,6 +102,12 @@ class ModelArguments:
103
  "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
104
  },
105
  )
 
 
 
 
 
 
106
 
107
 
108
  @dataclass
@@ -133,10 +138,10 @@ class DataTrainingArguments:
133
  overwrite_cache: bool = field(
134
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
135
  )
136
- validation_split_percentage: Optional[int] = field(
137
- default=5,
138
  metadata={
139
- "help": "The percentage of the train set used as validation set in case there's no validation split"
140
  },
141
  )
142
  max_seq_length: Optional[int] = field(
@@ -156,18 +161,31 @@ class DataTrainingArguments:
156
  default=3.0,
157
  metadata={"help": "Mean span length of masked tokens"},
158
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  def __post_init__(self):
161
- return
162
- # if self.dataset_name is None and self.train_file is None and self.validation_file is None:
163
- # raise ValueError("Need either a dataset name or a training/validation file.")
164
- # else:
165
- # if self.train_file is not None:
166
- # extension = self.train_file.split(".")[-1]
167
- # assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
168
- # if self.validation_file is not None:
169
- # extension = self.validation_file.split(".")[-1]
170
- # assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
171
 
172
 
173
  def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
@@ -297,7 +315,7 @@ class FlaxDataCollatorForT5MLM:
297
  start_indices[:, 0] = mask_indices[:, 0]
298
 
299
  sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
300
- sentinel_ids = np.where(sentinel_ids != 0, (sentinel_ids + self.tokenizer.vocab_size - 1), 0)
301
  sentinel_ids -= mask_indices - start_indices
302
 
303
  return sentinel_ids
@@ -362,7 +380,8 @@ class FlaxDataCollatorForT5MLM:
362
  np.random.shuffle(mask_indices)
363
  first_in_segment = np.pad(mask_indices, [[1, 0]])
364
  segment_id = np.cumsum(first_in_segment)
365
- segment_length = np.asarray(jax.ops.segment_sum(np.ones_like(segment_id), segment_id))
 
366
  return segment_length
367
 
368
  noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
@@ -405,70 +424,40 @@ def write_eval_metric(summary_writer, eval_metrics, step):
405
  for metric_name, value in eval_metrics.items():
406
  summary_writer.scalar(f"eval_{metric_name}", value, step)
407
 
408
- # utils
409
  def mb_item(x):
410
  return x.item() if hasattr(x, "item") else x
411
 
412
 
413
- # checkpoint functions
414
- def save_checkpoint(model, save_dir, state, with_opt: bool = True):
415
  state = jax_utils.unreplicate(state)
416
- logger.info(f"SAVING CHECKPOINT IN {save_dir}")
417
- save_dir = f"{save_dir}/ckpt-{mb_item(state.step) - 1}"
418
- model.save_pretrained(
419
- save_dir,
420
- params=state.params,
421
- push_to_hub=False
422
- )
423
  if with_opt:
 
424
  with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
425
  f.write(to_bytes(state.opt_state))
426
  with open(os.path.join(save_dir, "training_state.json"), "w") as f:
427
  json.dump({"step": state.step.item()}, f)
428
- logger.info(f"Updating model on the hub")
429
  model.save_pretrained(
430
- training_args.output_dir,
431
  params=state.params,
432
- push_to_hub=training_args.push_to_hub,
433
  commit_message=f"Saving weights and logs of step {cur_step}",
434
  )
435
- if with_opt:
436
- with open(os.path.join(training_args.output_dir, "opt_state.msgpack"), "wb") as f:
437
- f.write(to_bytes(state.opt_state))
438
- with open(os.path.join(training_args.output_dir, "training_state.json"), "w") as f:
439
- json.dump({"step": state.step.item()}, f)
440
- logger.info("checkpoint saved")
441
 
442
-
443
- def restore_checkpoint(save_dir, state):
444
- logger.info(f"RESTORING CHECKPOINT FROM {save_dir}")
445
- with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
446
  params = from_bytes(state.params, f.read())
447
-
448
- with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
449
  opt_state = from_bytes(state.opt_state, f.read())
450
-
451
- with open(os.path.join(save_dir, "training_state.json"), "r") as f:
452
  training_state = json.load(f)
453
  step = training_state["step"]
454
-
455
- logger.info("checkpoint restored")
456
  return state.replace(step=step, params=params, opt_state=opt_state), step
457
 
458
 
459
- def rotate_checkpoints(ckpt_dir: str, save_total_limit: int):
460
- "Removes older checkpoints so that `save_total_limit` checkpoints are kept"
461
- # TODO: what to remove is decided using step number only, we might want to improve that
462
- ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt-*")]
463
- # sort checkpoints by step
464
- ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split('-')[-1]))
465
- ckpts_to_delete = ckpts_sorted[:-save_total_limit]
466
- for ckpt in ckpts_to_delete:
467
- logger.info(f"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})")
468
- shutil.rmtree(ckpt)
469
-
470
-
471
-
472
  if __name__ == "__main__":
473
  # See all possible arguments in src/transformers/training_args.py
474
  # or by passing the --help flag to this script.
@@ -509,6 +498,16 @@ if __name__ == "__main__":
509
  # Set seed before initializing model.
510
  set_seed(training_args.seed)
511
 
 
 
 
 
 
 
 
 
 
 
512
  # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
513
  # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
514
  # (the dataset will be downloaded automatically from the datasets Hub).
@@ -523,82 +522,38 @@ if __name__ == "__main__":
523
  datasets["validation"] = load_dataset(
524
  data_args.dataset_name,
525
  data_args.dataset_config_name,
526
- split=f"train[:{data_args.validation_split_percentage}%]",
527
  cache_dir=model_args.cache_dir,
528
  )
529
  datasets["train"] = load_dataset(
530
  data_args.dataset_name,
531
  data_args.dataset_config_name,
532
- split=f"train[{data_args.validation_split_percentage}%:]",
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  cache_dir=model_args.cache_dir,
534
  )
535
  else:
536
- data_dir = "/home/yeb"
537
- # data_dir = "/home/yeb/Developer/data"
538
-
539
- def train_val_files():
540
- import glob
541
- import random
542
- SEED = 12345
543
-
544
- def add_jsonlines_dir(path, filespec):
545
- global data_files
546
- data_files += glob.glob(f"{path}/{filespec}")
547
- data_files = list(set(data_files))
548
- print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")
549
-
550
- # add_jsonlines_dir(f"{data_dir}/oscar_nl_cleaned")
551
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*73*.gz")
552
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*47*.gz")
553
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*12*.gz")
554
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*29*.gz")
555
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*74*.gz")
556
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*26*.gz")
557
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*54*.gz")
558
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*68*.gz")
559
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*57*.gz")
560
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*46*.gz")
561
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*35*.gz")
562
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*13*.gz")
563
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*41*.gz")
564
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*52*.gz")
565
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*63*.gz")
566
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*85*.gz")
567
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*81*.gz")
568
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*96*.gz")
569
- add_jsonlines_dir(f"{data_dir}/nrc_uniq_cleaned_20210223", "*.gz")
570
- add_jsonlines_dir(f"{data_dir}/nu_uniq_cleaned_20210225", "*.gz")
571
- random.Random(SEED).shuffle(data_files)
572
-
573
- total = len(data_files)
574
- print(total)
575
- perc = 0.05
576
- val_size = int(perc * total)
577
- train_size = total - val_size
578
- train = data_files[:train_size]
579
- val = data_files[train_size:]
580
- print(f"Got {len(train)} training files and {perc * 100} % {len(val)} validation files")
581
-
582
- assert list(set(train) & set(val)) == [], "Train overlaps with test"
583
-
584
- return train, val
585
-
586
- # train, val = train_val_files()
587
-
588
- load_grouped = True
589
-
590
- if not load_grouped:
591
- datasets = load_dataset('json', data_files={'train': train, 'validation': val})
592
-
593
- # data_files = {}
594
- # if data_args.train_file is not None:
595
- # data_files["train"] = data_args.train_file
596
- # if data_args.validation_file is not None:
597
- # data_files["validation"] = data_args.validation_file
598
- # extension = data_args.train_file.split(".")[-1]
599
- # if extension == "txt":
600
- # extension = "text"
601
- # datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
602
 
603
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
604
  # https://huggingface.co/docs/datasets/loading_datasets.html.
@@ -606,12 +561,18 @@ if __name__ == "__main__":
606
  # Load pretrained model and tokenizer
607
 
608
  if model_args.tokenizer_name:
609
- tokenizer = T5TokenizerFast.from_pretrained(
610
- model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
 
 
 
611
  )
612
  elif model_args.model_name_or_path:
613
- tokenizer = T5TokenizerFast.from_pretrained(
614
- model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
 
 
 
615
  )
616
  else:
617
  raise ValueError(
@@ -631,8 +592,30 @@ if __name__ == "__main__":
631
  config = CONFIG_MAPPING[model_args.model_type]()
632
  logger.warning("You are instantiating a new config instance from scratch.")
633
 
 
 
 
 
 
 
 
 
634
  max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
635
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
636
  # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
637
  # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
638
  # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
@@ -643,64 +626,36 @@ if __name__ == "__main__":
643
  )
644
  logger.info(f"Max seq length: {max_seq_length}, expanded_inputs_length: {expanded_inputs_length}, targets_length: {targets_length}")
645
 
646
- # Preprocessing the datasets.
647
- # First we tokenize all the texts.
648
- if load_grouped:
649
- logger.info("Loading tokenized and grouped dataset")
650
- tokenized_datasets = DatasetDict.load_from_disk("/home/yeb/grouped_datasets")
651
- logger.info("Setting max validation examples to 500")
652
- tokenized_datasets['validation'] = tokenized_datasets['validation'].select(range(1000))
653
- else:
654
- if training_args.do_train:
655
- column_names = datasets["train"].column_names
656
- else:
657
- column_names = datasets["validation"].column_names
658
- text_column_name = "text" if "text" in column_names else column_names[0]
659
-
660
- # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
661
- # Since we make sure that all sequences are of the same length, no attention_mask is needed.
662
- def tokenize_function(examples):
663
- return tokenizer(examples[text_column_name], return_attention_mask=False)
664
-
665
- logger.info(f"Start tokenization, remove_column_names = {column_names}")
666
- tokenized_datasets = datasets.map(
667
- tokenize_function,
668
- batched=True,
669
- num_proc=data_args.preprocessing_num_workers,
670
- remove_columns=column_names,
671
- load_from_cache_file=not data_args.overwrite_cache,
672
- )
673
 
674
- # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
675
- def group_texts(examples):
676
- # Concatenate all texts.
677
- concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
678
- total_length = len(concatenated_examples[list(examples.keys())[0]])
679
- # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
680
- # customize this part to your needs.
681
- if total_length >= expanded_inputs_length:
682
- total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
683
- # Split by chunks of max_len.
684
- result = {
685
- k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
686
- for k, t in concatenated_examples.items()
687
- }
688
- return result
689
-
690
- # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
691
- # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
692
- # might be slower to preprocess.
693
- #
694
- # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
695
- # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
696
- logger.info(f"Start group_texts")
697
- tokenized_datasets = tokenized_datasets.map(
698
- group_texts,
699
- batched=True,
700
- batch_size=200,
701
- num_proc=data_args.preprocessing_num_workers,
702
- load_from_cache_file=not data_args.overwrite_cache,
703
- )
704
 
705
  # Enable tensorboard only on the master node
706
  has_tensorboard = is_tensorboard_available()
@@ -729,15 +684,9 @@ if __name__ == "__main__":
729
  model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
730
  )
731
  else:
 
732
  model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
733
 
734
-
735
- # def to_bf16(t):
736
- # return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
737
- #
738
- #
739
- # model.params = to_bf16(model.params)
740
-
741
  # Data collator
742
  # This one will take care of randomly masking the tokens.
743
  data_collator = FlaxDataCollatorForT5MLM(
@@ -752,16 +701,13 @@ if __name__ == "__main__":
752
 
753
  # Store some constant
754
  num_epochs = int(training_args.num_train_epochs)
755
- train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() * training_args.gradient_accumulation_steps
756
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
757
 
758
- num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
759
-
760
  steps_per_epoch = len(tokenized_datasets['train']) // train_batch_size
761
- total_train_steps = steps_per_epoch * num_epochs
762
 
763
  # Create learning rate schedule
764
-
765
  if training_args.warmup_steps:
766
  warmup_steps = training_args.warmup_steps
767
  elif training_args.warmup_ratio:
@@ -770,7 +716,6 @@ if __name__ == "__main__":
770
  logging.info(f"Warmup steps set to {100*training_args.warmup_ratio}% = {warmup_steps} of total train steps {num_train_steps}")
771
  else:
772
  raise Exception("Need either --warmup_steps or --warmup_ratio")
773
-
774
  warmup_fn = optax.linear_schedule(
775
  init_value=0.0, end_value=training_args.learning_rate, transition_steps=warmup_steps
776
  )
@@ -823,8 +768,6 @@ if __name__ == "__main__":
823
  else:
824
  resume_step = 0
825
 
826
- logger.info("")
827
-
828
  # Define gradient update step fn
829
  def train_step(state, batch, dropout_rng):
830
  dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
@@ -845,7 +788,8 @@ if __name__ == "__main__":
845
  new_state = state.apply_gradients(grads=grad)
846
 
847
  metrics = jax.lax.pmean(
848
- {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step // grad_accum_steps)}, axis_name="batch"
 
849
  )
850
 
851
  return new_state, metrics, new_dropout_rng
@@ -875,17 +819,20 @@ if __name__ == "__main__":
875
 
876
  logger.info("Replicate the train state on each device")
877
 
 
 
 
 
878
  # Replicate the train state on each device
879
  state = jax_utils.replicate(state)
880
 
881
  logger.info("***** Running training *****")
882
- if not load_grouped:
883
- logger.info(f" Num examples = {len(datasets['train'])}")
884
  logger.info(f" Num tokenized group examples {len(tokenized_datasets['train'])}")
885
  logger.info(f" Num Epochs = {num_epochs}")
886
  logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
887
  logger.info(f" Total train batch size (w. parallel, distributed and grad_accum) = {train_batch_size}")
888
- logger.info(f" Total optimization steps = {total_train_steps}")
889
 
890
  train_time = 0
891
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
@@ -899,16 +846,26 @@ if __name__ == "__main__":
899
 
900
  # Generate an epoch by shuffling sampling indices from the train dataset
901
  num_train_samples = len(tokenized_datasets["train"])
902
- train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
903
- train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size // grad_accum_steps)
 
 
 
 
 
 
 
904
 
905
  # Gather the indexes for creating the batch and do a training step
906
- for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
 
 
907
  cur_step = epoch * (num_train_samples // train_batch_size) + step
908
  # skip to the step from which we are resuming
909
  if cur_step < resume_step:
910
  continue
911
 
 
912
  samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
913
  try:
914
  model_inputs = data_collator(samples)
@@ -922,7 +879,6 @@ if __name__ == "__main__":
922
  state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
923
  train_metrics.append(train_metric)
924
 
925
-
926
  if cur_step % training_args.logging_steps * grad_accum_steps == 0 and cur_step > 0:
927
  # Save metrics
928
  train_metric = jax_utils.unreplicate(train_metric)
@@ -931,7 +887,7 @@ if __name__ == "__main__":
931
  write_train_metric(summary_writer, train_metrics, train_time, cur_step)
932
 
933
  epochs.write(
934
- f"Step... ({cur_step} ({cur_step+resume_step}| Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
935
  )
936
 
937
  train_metrics = []
@@ -961,39 +917,50 @@ if __name__ == "__main__":
961
 
962
  # Save metrics
963
  if has_tensorboard and jax.process_index() == 0:
964
- cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
965
  write_eval_metric(summary_writer, eval_metrics, cur_step)
966
 
967
  if cur_step % training_args.save_steps * grad_accum_steps == 0 and cur_step > 0:
968
- logger.info(f"We should save the model here after {cur_step} steps")
969
  # save checkpoint after each epoch and push checkpoint to the hub
970
  if jax.process_index() == 0:
971
- save_checkpoint(model, training_args.output_dir, state)
972
- if training_args.save_total_limit is not None:
973
- rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
974
  # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
975
- #
976
- # logger.info(f"Saving model after {cur_step} steps")
977
- # model.save_pretrained(
978
- # training_args.output_dir,
979
- # params=params,
980
- # push_to_hub=training_args.push_to_hub,
981
- # commit_message=f"Saving weights and logs of step {cur_step}",
982
- # )
 
 
 
 
 
 
 
 
983
 
 
 
 
 
 
 
 
 
984
 
985
- # Save model at end
986
  if jax.process_index() == 0:
987
- save_checkpoint(model, training_args.output_dir, state, with_opt=False)
988
-
989
- # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
990
- # logger.info(f"Saving model at end")
991
- # model.save_pretrained(
992
- # training_args.output_dir,
993
- # params=params,
994
- # push_to_hub=training_args.push_to_hub,
995
- # commit_message=f"Saving weights and logs at end of run (step {cur_step})",
996
- # )
997
- # pt_model = T5ForConditionalGeneration.from_pretrained(training_args.output_dir, from_flax=True)
998
- # pt_model.save_pretrained(training_args.output_dir,
999
- # params=params)
 
 
18
 
19
  Here is the full list of checkpoints on the hub that can be pretrained by this script:
20
  https://huggingface.co/models?filter=t5
21
+
22
+ Adapted from the original version to support gradient accumulation and restarting.
23
  """
24
  # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
25
  import logging
 
27
  import sys
28
  import time
29
  import json
 
30
  from dataclasses import dataclass, field
31
+ from itertools import chain
32
  from pathlib import Path
33
  from typing import Dict, List, Optional
34
 
35
  import numpy as np
36
+ from datasets import load_dataset
37
  from tqdm import tqdm
38
 
39
  import flax
 
41
  import jax.numpy as jnp
42
  import optax
43
  from flax import jax_utils, traverse_util
44
+ from flax.serialization import to_bytes, from_bytes
45
  from flax.training import train_state
46
  from flax.training.common_utils import get_metrics, onehot, shard
47
+ # from huggingface_hub import Repository
48
  from transformers import (
49
  CONFIG_MAPPING,
50
  FLAX_MODEL_FOR_MASKED_LM_MAPPING,
51
+ AutoTokenizer,
52
  BatchEncoding,
53
  FlaxT5ForConditionalGeneration,
 
54
  HfArgumentParser,
55
  PreTrainedTokenizerBase,
56
  T5Config,
 
57
  TrainingArguments,
58
  is_tensorboard_available,
59
  set_seed,
60
  )
61
+ # from transformers.file_utils import get_full_repo_name
62
  from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
63
 
64
  logger = logging.getLogger(__name__)
65
 
 
66
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
67
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
68
 
 
 
 
69
  @dataclass
70
  class ModelArguments:
71
  """
 
102
  "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
103
  },
104
  )
105
+ auth_token: Optional[str] = field(
106
+ default=None,
107
+ metadata={
108
+ "help": "Auth token for private repositories on the Huggingface Hub"
109
+ }
110
+ )
111
 
112
 
113
  @dataclass
 
138
  overwrite_cache: bool = field(
139
  default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
140
  )
141
+ validation_split_count: Optional[int] = field(
142
+ default=10000,
143
  metadata={
144
+ "help": "The count of the train set used as validation set in case there's no validation split"
145
  },
146
  )
147
  max_seq_length: Optional[int] = field(
 
161
  default=3.0,
162
  metadata={"help": "Mean span length of masked tokens"},
163
  )
164
+ max_train_samples: Optional[int] = field(
165
+ default=None,
166
+ metadata={
167
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
168
+ "value if set."
169
+ },
170
+ )
171
+ max_eval_samples: Optional[int] = field(
172
+ default=None,
173
+ metadata={
174
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
175
+ "value if set."
176
+ },
177
+ )
178
 
179
  def __post_init__(self):
180
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
181
+ raise ValueError("Need either a dataset name or a training/validation file.")
182
+ else:
183
+ if self.train_file is not None:
184
+ extension = self.train_file.split(".")[-1]
185
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
186
+ if self.validation_file is not None:
187
+ extension = self.validation_file.split(".")[-1]
188
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
 
189
 
190
 
191
  def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
 
315
  start_indices[:, 0] = mask_indices[:, 0]
316
 
317
  sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
318
+ sentinel_ids = np.where(sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0)
319
  sentinel_ids -= mask_indices - start_indices
320
 
321
  return sentinel_ids
 
380
  np.random.shuffle(mask_indices)
381
  first_in_segment = np.pad(mask_indices, [[1, 0]])
382
  segment_id = np.cumsum(first_in_segment)
383
+ # count length of sub segments assuming that list is sorted
384
+ _, segment_length = np.unique(segment_id, return_counts=True)
385
  return segment_length
386
 
387
  noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
 
424
  for metric_name, value in eval_metrics.items():
425
  summary_writer.scalar(f"eval_{metric_name}", value, step)
426
 
427
+
428
  def mb_item(x):
429
  return x.item() if hasattr(x, "item") else x
430
 
431
 
432
+ def save_checkpoint(model, save_dir, state, cur_step: int, with_opt: bool = True, push_to_hub: bool = False):
 
433
  state = jax_utils.unreplicate(state)
 
 
 
 
 
 
 
434
  if with_opt:
435
+ logger.info(f'Saving optimizer and training state in {save_dir}...')
436
  with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
437
  f.write(to_bytes(state.opt_state))
438
  with open(os.path.join(save_dir, "training_state.json"), "w") as f:
439
  json.dump({"step": state.step.item()}, f)
440
+ logger.info(f'Saving model in {save_dir} {"and pushing it to HF Hub" if push_to_hub else ""}')
441
  model.save_pretrained(
442
+ save_dir,
443
  params=state.params,
444
+ push_to_hub=push_to_hub,
445
  commit_message=f"Saving weights and logs of step {cur_step}",
446
  )
 
 
 
 
 
 
447
 
448
+ def restore_checkpoint(load_dir, state):
449
+ logger.info(f"Restoring checkpoint from {load_dir}")
450
+ with open(os.path.join(load_dir, "flax_model.msgpack"), "rb") as f:
 
451
  params = from_bytes(state.params, f.read())
452
+ with open(os.path.join(load_dir, "opt_state.msgpack"), "rb") as f:
 
453
  opt_state = from_bytes(state.opt_state, f.read())
454
+ with open(os.path.join(load_dir, "training_state.json"), "r") as f:
 
455
  training_state = json.load(f)
456
  step = training_state["step"]
457
+ logger.info(f"Checkpoint restored at step {step}")
 
458
  return state.replace(step=step, params=params, opt_state=opt_state), step
459
 
460
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  if __name__ == "__main__":
462
  # See all possible arguments in src/transformers/training_args.py
463
  # or by passing the --help flag to this script.
 
498
  # Set seed before initializing model.
499
  set_seed(training_args.seed)
500
 
501
+ # Handle the repository creation
502
+ # if training_args.push_to_hub:
503
+ # if training_args.hub_model_id is None:
504
+ # repo_name = get_full_repo_name(
505
+ # Path(training_args.output_dir).absolute().name, token=training_args.hub_token
506
+ # )
507
+ # else:
508
+ # repo_name = training_args.hub_model_id
509
+ # repo = Repository(training_args.output_dir, clone_from=repo_name)
510
+
511
  # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
512
  # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
513
  # (the dataset will be downloaded automatically from the datasets Hub).
 
522
  datasets["validation"] = load_dataset(
523
  data_args.dataset_name,
524
  data_args.dataset_config_name,
525
+ split=f"train[:{data_args.validation_split_count}]",
526
  cache_dir=model_args.cache_dir,
527
  )
528
  datasets["train"] = load_dataset(
529
  data_args.dataset_name,
530
  data_args.dataset_config_name,
531
+ split=f"train[{data_args.validation_split_count}:]",
532
+ cache_dir=model_args.cache_dir,
533
+ )
534
+ else:
535
+ datasets["validation"] = load_dataset(
536
+ data_args.dataset_name,
537
+ data_args.dataset_config_name,
538
+ split=f"validation[:{data_args.validation_split_count}]",
539
+ cache_dir=model_args.cache_dir,
540
+ )
541
+ datasets["train"] = load_dataset(
542
+ data_args.dataset_name,
543
+ data_args.dataset_config_name,
544
+ split="train",
545
  cache_dir=model_args.cache_dir,
546
  )
547
  else:
548
+ data_files = {}
549
+ if data_args.train_file is not None:
550
+ data_files["train"] = data_args.train_file
551
+ if data_args.validation_file is not None:
552
+ data_files["validation"] = data_args.validation_file
553
+ extension = data_args.train_file.split(".")[-1]
554
+ if extension == "txt":
555
+ extension = "text"
556
+ datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
 
558
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
559
  # https://huggingface.co/docs/datasets/loading_datasets.html.
 
561
  # Load pretrained model and tokenizer
562
 
563
  if model_args.tokenizer_name:
564
+ tokenizer = AutoTokenizer.from_pretrained(
565
+ model_args.tokenizer_name,
566
+ cache_dir=model_args.cache_dir,
567
+ use_fast=model_args.use_fast_tokenizer,
568
+ use_auth_token=model_args.auth_token
569
  )
570
  elif model_args.model_name_or_path:
571
+ tokenizer = AutoTokenizer.from_pretrained(
572
+ model_args.model_name_or_path,
573
+ cache_dir=model_args.cache_dir,
574
+ use_fast=model_args.use_fast_tokenizer,
575
+ use_auth_token=model_args.auth_token
576
  )
577
  else:
578
  raise ValueError(
 
592
  config = CONFIG_MAPPING[model_args.model_type]()
593
  logger.warning("You are instantiating a new config instance from scratch.")
594
 
595
+ # Preprocessing the datasets.
596
+ # First we tokenize all the texts.
597
+ if training_args.do_train:
598
+ column_names = datasets["train"].column_names
599
+ else:
600
+ column_names = datasets["validation"].column_names
601
+ text_column_name = "text" if "text" in column_names else column_names[0]
602
+
603
  max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
604
 
605
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
606
+ # Since we make sure that all sequences are of the same length, no attention_mask is needed.
607
+ def tokenize_function(examples):
608
+ return tokenizer(examples[text_column_name], return_attention_mask=False)
609
+
610
+ logger.info(f"Start tokenization, remove_column_names = {column_names}")
611
+ tokenized_datasets = datasets.map(
612
+ tokenize_function,
613
+ batched=True,
614
+ num_proc=data_args.preprocessing_num_workers,
615
+ remove_columns=column_names,
616
+ load_from_cache_file=not data_args.overwrite_cache,
617
+ )
618
+
619
  # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
620
  # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
621
  # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
 
626
  )
627
  logger.info(f"Max seq length: {max_seq_length}, expanded_inputs_length: {expanded_inputs_length}, targets_length: {targets_length}")
628
 
629
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
630
+ def group_texts(examples):
631
+ # Concatenate all texts.
632
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
633
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
634
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
635
+ # customize this part to your needs.
636
+ if total_length >= expanded_inputs_length:
637
+ total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
638
+ # Split by chunks of max_len.
639
+ result = {
640
+ k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
641
+ for k, t in concatenated_examples.items()
642
+ }
643
+ return result
 
 
 
 
 
 
 
 
 
 
 
 
644
 
645
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
646
+ # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
647
+ # might be slower to preprocess.
648
+ #
649
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
650
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
651
+ logger.info(f"Start group_texts")
652
+ tokenized_datasets = tokenized_datasets.map(
653
+ group_texts,
654
+ batched=True,
655
+ batch_size=200,
656
+ num_proc=data_args.preprocessing_num_workers,
657
+ load_from_cache_file=not data_args.overwrite_cache,
658
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659
 
660
  # Enable tensorboard only on the master node
661
  has_tensorboard = is_tensorboard_available()
 
684
  model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
685
  )
686
  else:
687
+ config.vocab_size = len(tokenizer)
688
  model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
689
 
 
 
 
 
 
 
 
690
  # Data collator
691
  # This one will take care of randomly masking the tokens.
692
  data_collator = FlaxDataCollatorForT5MLM(
 
701
 
702
  # Store some constant
703
  num_epochs = int(training_args.num_train_epochs)
704
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
705
  eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
706
 
 
 
707
  steps_per_epoch = len(tokenized_datasets['train']) // train_batch_size
708
+ num_train_steps = steps_per_epoch * num_epochs
709
 
710
  # Create learning rate schedule
 
711
  if training_args.warmup_steps:
712
  warmup_steps = training_args.warmup_steps
713
  elif training_args.warmup_ratio:
 
716
  logging.info(f"Warmup steps set to {100*training_args.warmup_ratio}% = {warmup_steps} of total train steps {num_train_steps}")
717
  else:
718
  raise Exception("Need either --warmup_steps or --warmup_ratio")
 
719
  warmup_fn = optax.linear_schedule(
720
  init_value=0.0, end_value=training_args.learning_rate, transition_steps=warmup_steps
721
  )
 
768
  else:
769
  resume_step = 0
770
 
 
 
771
  # Define gradient update step fn
772
  def train_step(state, batch, dropout_rng):
773
  dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
 
788
  new_state = state.apply_gradients(grads=grad)
789
 
790
  metrics = jax.lax.pmean(
791
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step // grad_accum_steps)},
792
+ axis_name="batch"
793
  )
794
 
795
  return new_state, metrics, new_dropout_rng
 
819
 
820
  logger.info("Replicate the train state on each device")
821
 
822
+ # import pydevd_pycharm
823
+ #
824
+ # pydevd_pycharm.settrace('localhost', port=12345, stdoutToServer=True, stderrToServer=True)
825
+
826
  # Replicate the train state on each device
827
  state = jax_utils.replicate(state)
828
 
829
  logger.info("***** Running training *****")
830
+ logger.info(f" Num examples = {len(datasets['train'])}")
 
831
  logger.info(f" Num tokenized group examples {len(tokenized_datasets['train'])}")
832
  logger.info(f" Num Epochs = {num_epochs}")
833
  logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
834
  logger.info(f" Total train batch size (w. parallel, distributed and grad_accum) = {train_batch_size}")
835
+ logger.info(f" Total optimization steps = {num_train_steps}")
836
 
837
  train_time = 0
838
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
 
846
 
847
  # Generate an epoch by shuffling sampling indices from the train dataset
848
  num_train_samples = len(tokenized_datasets["train"])
849
+ # train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
850
+ # train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
851
+
852
+ # If the dataset is too large, we only proceed sequentially without shuffling.
853
+ samples_to_remove = num_train_samples % (train_batch_size // grad_accum_steps)
854
+ samples_idx = np.arange(num_train_samples)
855
+ if samples_to_remove != 0:
856
+ samples_idx = samples_idx[:-samples_to_remove]
857
+ steps = num_train_samples // (train_batch_size // grad_accum_steps)
858
 
859
  # Gather the indexes for creating the batch and do a training step
860
+ # for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
861
+ # samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
862
+ for step in tqdm(range(steps), desc="Training...", position=1):
863
  cur_step = epoch * (num_train_samples // train_batch_size) + step
864
  # skip to the step from which we are resuming
865
  if cur_step < resume_step:
866
  continue
867
 
868
+ batch_idx = [x for x in range(step * train_batch_size, (step + 1) * train_batch_size)]
869
  samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
870
  try:
871
  model_inputs = data_collator(samples)
 
879
  state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
880
  train_metrics.append(train_metric)
881
 
 
882
  if cur_step % training_args.logging_steps * grad_accum_steps == 0 and cur_step > 0:
883
  # Save metrics
884
  train_metric = jax_utils.unreplicate(train_metric)
 
887
  write_train_metric(summary_writer, train_metrics, train_time, cur_step)
888
 
889
  epochs.write(
890
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
891
  )
892
 
893
  train_metrics = []
 
917
 
918
  # Save metrics
919
  if has_tensorboard and jax.process_index() == 0:
 
920
  write_eval_metric(summary_writer, eval_metrics, cur_step)
921
 
922
  if cur_step % training_args.save_steps * grad_accum_steps == 0 and cur_step > 0:
 
923
  # save checkpoint after each epoch and push checkpoint to the hub
924
  if jax.process_index() == 0:
 
 
 
925
  # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
926
+ # model.save_pretrained(training_args.output_dir, params=params)
927
+ # tokenizer.save_pretrained(training_args.output_dir)
928
+ # if training_args.push_to_hub:
929
+ # repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
930
+ save_checkpoint(model, training_args.output_dir, state, cur_step, with_opt=False, push_to_hub=True)
931
+
932
+ # Eval after training
933
+ if training_args.do_eval:
934
+ num_eval_samples = len(tokenized_datasets["validation"])
935
+ eval_samples_idx = jnp.arange(num_eval_samples)
936
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
937
+
938
+ eval_metrics = []
939
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
940
+ samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
941
+ model_inputs = data_collator(samples)
942
 
943
+ # Model forward
944
+ model_inputs = shard(model_inputs.data)
945
+ metrics = p_eval_step(state.params, model_inputs)
946
+ eval_metrics.append(metrics)
947
+
948
+ # get eval metrics
949
+ eval_metrics = get_metrics(eval_metrics)
950
+ eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)
951
 
 
952
  if jax.process_index() == 0:
953
+ eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
954
+ path = os.path.join(training_args.output_dir, "eval_results.json")
955
+ with open(path, "w") as f:
956
+ json.dump(eval_metrics, f, indent=4, sort_keys=True)
957
+
958
+ # Save model at end
959
+ if jax.process_index() == 0:
960
+ # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
961
+ # model.save_pretrained(training_args.output_dir, params=params)
962
+ # tokenizer.save_pretrained(training_args.output_dir)
963
+ # if training_args.push_to_hub:
964
+ # repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
965
+ #
966
+ save_checkpoint(model, training_args.output_dir, state, cur_step, with_opt=False, push_to_hub=True)
streaming_dataset_filter_test.py DELETED
@@ -1,93 +0,0 @@
1
- from clean import clean_text
2
-
3
- from datasets import load_dataset
4
-
5
- dataset_v0 = load_dataset('oscar', "unshuffled_deduplicated_nl", split='train', streaming=True)
6
-
7
- # data_dir = "/home/yeb"
8
- data_dir = "/home/yeb/Developer/data"
9
- data_files = []
10
-
11
- def train_val_files():
12
- import glob
13
- import random
14
- SEED = 12345
15
-
16
- def add_jsonlines_dir(path, filespec):
17
- global data_files
18
- data_files += glob.glob(f"{path}/{filespec}")
19
- data_files = list(set(data_files))
20
- print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")
21
-
22
- # add_jsonlines_dir(f"{data_dir}/oscar_nl_cleaned")
23
- add_jsonlines_dir(f"{data_dir}/c4_cleaned2", "*73*.gz")
24
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*47*.gz")
25
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*12*.gz")
26
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*29*.gz")
27
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*74*.gz")
28
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*26*.gz")
29
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*54*.gz")
30
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*68*.gz")
31
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*57*.gz")
32
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*46*.gz")
33
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*35*.gz")
34
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*13*.gz")
35
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*41*.gz")
36
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*52*.gz")
37
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*63*.gz")
38
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*85*.gz")
39
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*81*.gz")
40
- # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*96*.gz")
41
- # add_jsonlines_dir(f"{data_dir}/nrc_uniq_cleaned_20210223", "*.gz")
42
- # add_jsonlines_dir(f"{data_dir}/nu_uniq_cleaned_20210225", "*.gz")
43
- random.Random(SEED).shuffle(data_files)
44
-
45
- total = len(data_files)
46
- print(total)
47
- perc = 0.05
48
- val_size = int(perc * total)
49
- train_size = total - val_size
50
- train = data_files[:train_size]
51
- val = data_files[train_size:]
52
- print(f"Got {len(train)} training files and {perc * 100} % {len(val)} validation files")
53
-
54
- assert list(set(train) & set(val)) == [], "Train overlaps with test"
55
-
56
- return train, val
57
-
58
- train, val = train_val_files()
59
- dataset_v0 = load_dataset('json', data_files={'train': train, 'validation': val})
60
-
61
-
62
- dataset_v0 = load_dataset('oscar', "unshuffled_deduplicated_nl")
63
-
64
- def f(obj):
65
- obj["text"] = clean_text(obj["text"])
66
- return obj
67
-
68
-
69
- dataset_v1 = dataset_v0.map(
70
- f,
71
- batched=False,
72
- num_proc=10,
73
- )
74
-
75
- datasets = dataset_v1.filter(
76
- lambda obj: obj['text'] is not None,
77
- num_proc=10,
78
- )
79
-
80
- it = iter(dataset_v0['train'])
81
- print(next(it))
82
- print(next(it))
83
- print(next(it))
84
-
85
- it = iter(dataset_v1['train'])
86
- print(next(it))
87
- print(next(it))
88
- print(next(it))
89
-
90
- # it = iter(dataset_v2)
91
- # print(next(it))
92
- # print(next(it))
93
- # print(next(it))
tf_model.h5 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7ca091f719f88d0c460cb709fead1521082e46ac9b1d9873a06e65bb0ca2d94c
3
- size 892067416
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3083c65d23d0521977a9739022c8e48f3ee1094d43317b150cf044f1451cfd9c
3
+ size 892068248
train_tokenizer.py DELETED
@@ -1,66 +0,0 @@
1
- from datasets import load_dataset
2
- from t5_tokenizer_model import SentencePieceUnigramTokenizer
3
-
4
- # from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
5
-
6
- data_dir = "/home/yeb"
7
- data_files = []
8
-
9
-
10
- def train_val_files():
11
- import glob
12
- import random
13
- SEED = 12345
14
-
15
- def add_jsonlines_dir(path, filespec):
16
- global data_files
17
- data_files += glob.glob(f"{path}/{filespec}")
18
- print(f"Number of files {len(data_files)} after adding {path}")
19
-
20
- # add_jsonlines_dir(f"{data_dir}/oscar_nl_cleaned")
21
- add_jsonlines_dir(f"{data_dir}/c4_cleaned2", "*47*.gz")
22
- add_jsonlines_dir(f"{data_dir}/nrc_uniq_cleaned_20210223", "*.gz")
23
- add_jsonlines_dir(f"{data_dir}/nu_uniq_cleaned_20210225", "*.gz")
24
- random.Random(SEED).shuffle(data_files)
25
-
26
- print(data_files)
27
- total = len(data_files)
28
- print(total)
29
- perc = 0.01
30
- val_size = int(perc * total)
31
- train_size = total - val_size
32
- train = data_files[:train_size]
33
- val = data_files[train_size:]
34
- print(f"Got {len(train)} training files and {perc * 100} % {len(val)} validation files")
35
-
36
- assert list(set(train) & set(val)) == [], "Train overlaps with test"
37
-
38
- return train, val
39
-
40
-
41
- train, val = train_val_files()
42
-
43
- dataset = load_dataset('json', data_files={'train': train, 'validation': val}, split='train')
44
-
45
- vocab_size = 32000
46
- input_sentence_size = None
47
- tokenizer = SentencePieceUnigramTokenizer(unk_token="<unk>", eos_token="</s>", pad_token="<pad>")
48
-
49
-
50
- # Build an iterator over this dataset
51
- def batch_iterator(input_sentence_size=None):
52
- if input_sentence_size is None:
53
- input_sentence_size = len(dataset)
54
- batch_length = 100
55
- for i in range(0, input_sentence_size, batch_length):
56
- yield dataset[i: i + batch_length]["text"]
57
-
58
- # Train tokenizer
59
- tokenizer.train_from_iterator(
60
- iterator=batch_iterator(input_sentence_size=input_sentence_size),
61
- vocab_size=vocab_size,
62
- show_progress=True,
63
- )
64
-
65
- # Save files to disk
66
- tokenizer.save("./tokenizer.json")
training_state.json DELETED
@@ -1 +0,0 @@
1
- {"step": 62500}