Yeb Havinga committed
Commit 49e8767 • 1 Parent(s): 2c7b7d9
Replace scripts and model with improved version
Browse files
- Load_preprocessed_dataset.ipynb  +0 -165
- Load_token_group_dataset.ipynb  +0 -567
- README.md  +30 -12
- config.json  +1 -1
- flax_model.msgpack  +1 -1
- flax_to_pt.py  +26 -6
- opt_state.msgpack  +0 -3
- pytorch_model.bin  +1 -1
- run_t5.sh  +37 -79
- run_t5_mlm_flax_custom_dataset.py → run_t5_mlm_flax.py  +213 -246
- streaming_dataset_filter_test.py  +0 -93
- tf_model.h5  +2 -2
- train_tokenizer.py  +0 -66
- training_state.json  +0 -1
Load_preprocessed_dataset.ipynb
DELETED (165 lines)

The deleted notebook (Python 3.8.10 kernel) loaded the preprocessed, grouped dataset from disk and inspected it. Its cells and saved outputs:

from datasets import Dataset, DatasetDict

datasets = DatasetDict.load_from_disk("/home/yeb/grouped_dataset")
# Output at last save: FileNotFoundError: [Errno 2] No such file or directory:
# '/home/yeb/grouped_dataset/dataset_dict.json'

datasets
# DatasetDict({
#     train: Dataset({features: ['input_ids'], num_rows: 86438919})
#     validation: Dataset({features: ['input_ids'], num_rows: 4735324})
# })

train = datasets['train']

print(len(train))
# 86438919

it = iter(train)

print(next(it))
# {'input_ids': [256, 3, 20, 18452, 6690, 7757, 1286, 43, 10, 4942, 1286, 80, 12, 4782, 5442, ...]}
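The FileNotFoundError in the last saved output only means that the /home/yeb/grouped_dataset directory (with its dataset_dict.json) was no longer present when the notebook was re-run. For context, a minimal sketch of how such a DatasetDict is written and read back with the datasets library; the tiny example data and the "grouped_dataset_demo" path are made up for illustration:

from datasets import Dataset, DatasetDict

# Hypothetical miniature stand-in for the real grouped dataset.
dd = DatasetDict({
    "train": Dataset.from_dict({"input_ids": [[256, 3, 20], [18452, 6690, 7757]]}),
    "validation": Dataset.from_dict({"input_ids": [[1286, 43, 10]]}),
})
dd.save_to_disk("grouped_dataset_demo")           # writes dataset_dict.json plus one directory per split
reloaded = DatasetDict.load_from_disk("grouped_dataset_demo")
print(reloaded["train"].num_rows)                 # 2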
Load_token_group_dataset.ipynb
DELETED (567 lines)

The deleted notebook (Python 3.8.10 kernel) built the grouped dataset: it collected and split the cleaned jsonlines files, loaded them with the datasets library, tokenized them, grouped the tokens into fixed-length chunks, and spot-checked the result. Its cells and saved outputs:

data_files = []
data_dir = "."
def train_val_files():
    import glob
    import random
    SEED = 12345

    def add_jsonlines_dir(path, filespec):
        global data_files
        data_files += glob.glob(f"{path}/{filespec}")
        data_files = list(set(data_files))
        print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")

    # add_jsonlines_dir(f"{data_dir}/oscar_nl_cleaned")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*73*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*47*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*12*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*29*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*74*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*26*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*54*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*68*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*57*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*46*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*35*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*13*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*41*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*52*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*63*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*85*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*81*.gz")
    add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*96*.gz")
    add_jsonlines_dir(f"{data_dir}/nrc_uniq_cleaned_20210223", "*.gz")
    add_jsonlines_dir(f"{data_dir}/nu_uniq_cleaned_20210225", "*.gz")
    random.Random(SEED).shuffle(data_files)

    total = len(data_files)
    print(total)
    perc = 0.05
    val_size = int(perc * total)
    train_size = total - val_size
    train = data_files[:train_size]
    val = data_files[train_size:]
    print(f"Got {len(train)} training files and {perc*100} % {len(val)} validation files")

    assert list(set(train) & set(val)) == [], "Train overlaps with test"

    return train, val

train, val = train_val_files()
# Number of files 20 after adding ./c4_cleaned glob *73*.gz
# Number of files 39 after adding ./c4_cleaned glob *47*.gz
# Number of files 60 after adding ./c4_cleaned glob *12*.gz
# Number of files 79 after adding ./c4_cleaned glob *29*.gz
# Number of files 97 after adding ./c4_cleaned glob *74*.gz
# Number of files 116 after adding ./c4_cleaned glob *26*.gz
# Number of files 135 after adding ./c4_cleaned glob *54*.gz
# Number of files 154 after adding ./c4_cleaned glob *68*.gz
# Number of files 172 after adding ./c4_cleaned glob *57*.gz
# Number of files 189 after adding ./c4_cleaned glob *46*.gz
# Number of files 206 after adding ./c4_cleaned glob *35*.gz
# Number of files 226 after adding ./c4_cleaned glob *13*.gz
# Number of files 242 after adding ./c4_cleaned glob *41*.gz
# Number of files 259 after adding ./c4_cleaned glob *52*.gz
# Number of files 276 after adding ./c4_cleaned glob *63*.gz
# Number of files 292 after adding ./c4_cleaned glob *85*.gz
# Number of files 309 after adding ./c4_cleaned glob *81*.gz
# Number of files 326 after adding ./c4_cleaned glob *96*.gz
# Number of files 526 after adding ./nrc_uniq_cleaned_20210223 glob *.gz
# Number of files 726 after adding ./nu_uniq_cleaned_20210225 glob *.gz
# 726
# Got 690 training files and 5.0 % 36 validation files

from datasets import load_dataset
datasets = load_dataset('json', data_files={'train': train, 'validation': val})
# Using custom data configuration default-ce92ec7dc3732df4
# Downloading and preparing dataset json/default to
# /home/yeb/.cache/huggingface/datasets/json/default-ce92ec7dc3732df4/0.0.0/793d004298099bd3c4e61eb7878475bcf1dc212bf2e34437d85126758720d7f9...
# Dataset json downloaded and prepared. Subsequent calls will reuse this data.

print(f"Num examples = {len(datasets['train'])}")
# Num examples = 21153916

from transformers import (
    CONFIG_MAPPING,
    FLAX_MODEL_FOR_MASKED_LM_MAPPING,
    BatchEncoding,
    FlaxT5ForConditionalGeneration,
    T5ForConditionalGeneration,
    HfArgumentParser,
    PreTrainedTokenizerBase,
    T5Config,
    T5TokenizerFast,
    TrainingArguments,
    is_tensorboard_available,
    set_seed,
)

tokenizer = T5TokenizerFast.from_pretrained("./t5-base-dutch")

def tokenize_function(examples):
    return tokenizer(examples['text'], return_attention_mask=False)

column_names = datasets["train"].column_names
print(f"Start tokenization, remove_column_names = {column_names}")

tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=96,
    remove_columns=column_names,
    load_from_cache_file=True,
)
# Special tokens have been added in the vocabulary, make sure the associated word
# embeddings are fine-tuned or trained.
# Start tokenization, remove_column_names = ['url', 'timestamp', 'text']

len(tokenized_datasets["train"])
# 21153916

def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
    """This function is a copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .

    Training parameters to avoid padding with random_spans_noise_mask.
    When training a model with random_spans_noise_mask, we would like to set the other
    training hyperparameters in a way that avoids padding.
    This function helps us compute these hyperparameters.
    We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens,
    and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens.
    This function tells us the required number of tokens in the raw example (for split_tokens())
    as well as the length of the encoded targets. Note that this function assumes
    the inputs and targets will have EOS appended and includes that in the reported length.

    Args:
        inputs_length: an integer - desired length of the tokenized inputs sequence
        noise_density: a float
        mean_noise_span_length: a float
    Returns:
        tokens_length: length of original text in tokens
        targets_length: an integer - length in tokens of encoded targets sequence
    """

    def _tokens_length_to_inputs_length_targets_length(tokens_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
        # inputs contain all nonnoise tokens, sentinels for all noise spans
        # and one EOS token.
        _input_length = num_nonnoise_tokens + num_noise_spans + 1
        _output_length = num_noise_tokens + num_noise_spans + 1
        return _input_length, _output_length

    tokens_length = inputs_length

    while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
        tokens_length += 1

    inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)

    # minor hack to get the targets length to be equal to inputs length
    # which is more likely to have been set to a nice round number.
    if noise_density == 0.5 and targets_length > inputs_length:
        tokens_length -= 1
        targets_length -= 1
    return tokens_length, targets_length

# T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
# To ensure that the input length is `max_seq_length`, we need to increase the maximum length
# according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
expanded_inputs_length, targets_length = compute_input_and_target_lengths(
    inputs_length=128,
    noise_density=0.15,
    mean_noise_span_length=3.0,
)

print(f"Expanded_inputs_length: {expanded_inputs_length}, targets_length: {targets_length}")
print(f"Start group_texts")

# Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= expanded_inputs_length:
        total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
        for k, t in concatenated_examples.items()
    }
    return result

# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
# remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
# might be slower to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
grouped_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=200,
    num_proc=96,
    load_from_cache_file=True,
)
# Expanded_inputs_length: 141, targets_length: 29
# Start group_texts
# (stderr: SIGTERM stack traces and coredump_hook messages from several worker PIDs
#  were logged while the multiprocess map ran)

examples = len(grouped_datasets["train"])
examples
# 86438919

it = iter(grouped_datasets["train"])

print(next(it))
# {'input_ids': [256, 3, 20, 18452, 6690, 7757, 1286, 43, 10, 4942, 1286, 80, 12, 4782, 5442, ...]}

tokens = next(it)['input_ids']

len(tokens)
# 141

tokenizer.decode(tokens)
# "werden volgens getuigen vergezeld door een boomlange bodyguard. ook hing er een gordijntje
#  om de tafel, zodat beyoncé in alle rust van de show kon genieten. volgens de bron verliet
#  knowles pas om 03.30 uur's ochtends de hippe club.</s> utrecht - in de schouwburg van utrecht
#  gaat vrijdagavond de musical 'joseph and the amazing technicolor dreamcoat' in première. [...]"

while (example := next(it, None)) is not None:
    if len(example['input_ids']) == 141:
        continue
    else:
        print(example)
        break
# Interrupted manually (KeyboardInterrupt) while scanning the dataset.
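A quick arithmetic check of the Expanded_inputs_length: 141, targets_length: 29 figures printed above: at tokens_length = 141 with noise_density = 0.15 and mean_noise_span_length = 3.0, the helper computes round(141 * 0.15) = 21 noise tokens, 141 - 21 = 120 non-noise tokens and round(21 / 3) = 7 noise spans, so the model inputs become 120 + 7 + 1 = 128 tokens (the requested inputs_length) and the targets 21 + 7 + 1 = 29 tokens.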
README.md
CHANGED
@@ -3,21 +3,34 @@ language:
 - dutch
 tags:
 - seq2seq
+- lm-head
 datasets:
+- yhavinga/mc4_nl_cleaned
+license: apache-2.0
+inference: false
 ---
 
 # t5-base-dutch
 
-Created by [Yeb Havinga](https://www.linkedin.com/in/yeb-havinga-86530825/) …
+Created by [Yeb Havinga](https://www.linkedin.com/in/yeb-havinga-86530825/)
+& [Dat Nguyen](https://www.linkedin.com/in/dat-nguyen-49a641138/) during the [Hugging Face community week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104), organized by [HuggingFace](https://huggingface.co/) and TPU usage sponsored by Google, for the project [Pre-train T5 from scratch in Dutch](https://discuss.huggingface.co/t/pretrain-t5-from-scratch-in-dutch/8109).
 
-See also the fine-tuned [t5-base-dutch-demo](https://huggingface.co/flax-community/t5-base-dutch-demo) model, …
+See also the fine-tuned [t5-base-dutch-demo](https://huggingface.co/flax-community/t5-base-dutch-demo) model,
+and the demo application **[Netherformer 📰](https://huggingface.co/spaces/flax-community/netherformer)**,
+that are based on this model.
+
+**5 jan 2022: Model updated. Evaluation accuracy increased from 0.64 to 0.70.**
+
+## Model
+
+* Configuration based on `google/t5-base`
+* 12 layers, 12 heads
+* Dropout set to 0.1
 
 ## Dataset
 
-This model was trained on …
+This model was trained on the `full` configuration of [cleaned Dutch mC4](https://huggingface.co/datasets/mc4_nl_cleaned),
+which is the original mC4, except
 
 * Documents that contained words from a selection of the Dutch and English [List of Dirty Naught Obscene and Otherwise Bad Words](https://github.com/LDNOOBW/List-of-Dirty-Naughty-Obscene-and-Otherwise-Bad-Words) are removed
 * Sentences with less than 3 words are removed
@@ -26,13 +39,18 @@ See the `clean` directory for the clean script.
 * Documents with "javascript", "lorum ipsum", "terms of use", "privacy policy", "cookie policy", "uses cookies",
   "use of cookies", "use cookies", "elementen ontbreken", "deze printversie" are removed.
 
+## Tokenization
+
+A SentencePiece tokenizer was trained from scratch on this dataset.
+The total tokens of the `full` configuration is 34B
+
 ## Training
 
-… the first few resumes would start again at step 0 with a different seeded reshuffling of the data.
-In the last two resumes the random seed was fixed, and training would resume at the previous step, since a try/except around the failing example would allow training to continue in the case of errors caused by a single example.
+The model was trained on the `full` mc4_nl_cleaned dataset configuration for 1 epoch, consisting of 34B tokens,
+for 528 482 steps with a batch size of 128 and took 57 hours.
+A triangle learning rate schedule was used, with peak learning rate 0.005.
 
+## Evaluation
+
+* Loss: 1.38
+* Accuracy: 0.70

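The "triangle learning rate schedule" mentioned in the Training section above is a linear warmup followed by a linear decay. Below is a minimal sketch with optax, assuming the peak learning rate and step count from the model card and the 10 000 warmup steps passed in run_t5.sh; it illustrates the schedule shape, not the exact schedule object constructed by run_t5_mlm_flax.py.

import optax

# Sketch of a triangular schedule: linear warmup to the peak, then linear decay to zero.
# peak_lr and total_steps are taken from the model card above; warmup_steps from run_t5.sh.
def triangle_schedule(peak_lr=0.005, warmup_steps=10_000, total_steps=528_482):
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps)
    decay_fn = optax.linear_schedule(init_value=peak_lr, end_value=0.0, transition_steps=total_steps - warmup_steps)
    return optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

schedule = triangle_schedule()
print(schedule(0), schedule(10_000), schedule(528_482))  # 0.0 at the start, 0.005 after warmup, ~0.0 at the end
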
config.json
CHANGED
@@ -52,7 +52,7 @@
     }
   },
   "torch_dtype": "float32",
-  "transformers_version": "4.…",
+  "transformers_version": "4.13.0",
   "use_cache": true,
   "vocab_size": 32103
 }

flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
+oid sha256:be5973ac1f68ec3c5ceb47e10ed848b83ad06e69affa938fc400e3ef368143ea
 size 891548548

flax_to_pt.py
CHANGED
@@ -1,6 +1,26 @@
+import torch
+import numpy as np
+import jax.numpy as jnp
+from transformers import AutoTokenizer
+from transformers import FlaxT5ForConditionalGeneration
+from transformers import T5ForConditionalGeneration
+tokenizer = AutoTokenizer.from_pretrained(".")
+model_fx = FlaxT5ForConditionalGeneration.from_pretrained(".")
+model_pt = T5ForConditionalGeneration.from_pretrained(".", from_flax=True)
+model_pt.save_pretrained("./")
+text = "Hoe gaat het?"
+e_input_ids_fx = tokenizer(text, return_tensors="np", padding=True, max_length=128, truncation=True)
+d_input_ids_fx = jnp.ones((e_input_ids_fx.input_ids.shape[0], 1), dtype="i4") * model_fx.config.decoder_start_token_id
+e_input_ids_pt = tokenizer(text, return_tensors="pt", padding=True, max_length=128, truncation=True)
+d_input_ids_pt = np.ones((e_input_ids_pt.input_ids.shape[0], 1), dtype="i4") * model_pt.config.decoder_start_token_id
+print(e_input_ids_fx)
+print(d_input_ids_fx)
+print()
+encoder_pt = model_fx.encode(**e_input_ids_pt)
+decoder_pt = model_fx.decode(d_input_ids_pt, encoder_pt)
+logits_pt = decoder_pt.logits
+print(logits_pt)
+encoder_fx = model_fx.encode(**e_input_ids_fx)
+decoder_fx = model_fx.decode(d_input_ids_fx, encoder_fx)
+logits_fx = decoder_fx.logits
+print(logits_fx)

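The committed flax_to_pt.py feeds both the NumPy and the PyTorch inputs through the Flax model and prints the logits. The following is a hedged sketch of a direct numerical comparison between the Flax model and the converted PyTorch checkpoint; the "." model path mirrors the script, and the atol tolerance is an assumption.

import numpy as np
import torch
from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration, T5ForConditionalGeneration

# Sketch: check that the converted PyTorch weights give (nearly) the same logits as the Flax model.
tokenizer = AutoTokenizer.from_pretrained(".")
model_fx = FlaxT5ForConditionalGeneration.from_pretrained(".")
model_pt = T5ForConditionalGeneration.from_pretrained(".", from_flax=True)

text = "Hoe gaat het?"
start = model_fx.config.decoder_start_token_id
inputs_fx = tokenizer(text, return_tensors="np")
inputs_pt = tokenizer(text, return_tensors="pt")

logits_fx = model_fx(**inputs_fx, decoder_input_ids=np.full((1, 1), start, dtype="i4")).logits
with torch.no_grad():
    logits_pt = model_pt(**inputs_pt, decoder_input_ids=torch.full((1, 1), start, dtype=torch.long)).logits

print(np.allclose(np.asarray(logits_fx), logits_pt.numpy(), atol=1e-3))
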
opt_state.msgpack
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ffae8bd1730e35ebeb0619a7d1b75dab07addff2320d2394eb1af891820ca64f
-size 1985609

pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
+oid sha256:f102fac4815a8b1b29916b196bfe88a0e5fef76083c6007a5c7966a7fcb9b2d6
 size 891650495

run_t5.sh
CHANGED
@@ -1,79 +1,37 @@
-# SEED=$RANDOM
-SEED=22384
-
-./run_t5_mlm_flax_custom_dataset.py \
-    --output_dir="${MODEL_DIR}" \
-    --model_type="t5" \
-    --config_name="flax-community/${MODEL}" \
-    --tokenizer_name="${MODEL_DIR}" \
-    --seed="${SEED}" \
-    --preprocessing_num_workers="96" \
-    --do_train --do_eval \
-    --adafactor \
-    --max_seq_length="512" \
-    --per_device_train_batch_size="16" \
-    --per_device_eval_batch_size="16" \
-    --dtype="bfloat16" \
-    --learning_rate="1e-3" \
-    --overwrite_output_dir \
-    --num_train_epochs="1" \
-    --logging_steps="50" \
-    --save_steps="500" \
-    --eval_steps="5000" \
-    --resume_from_checkpoint="${MODEL_DIR}" \
-    --warmup_steps="6519"
-
-# \
-# --push_to_hub
-
-echo "RESTARTING"
-sleep 20
-done
-
-#git add pytorch_model.bin
-#git commit -m "Update pytorch model after training"
-#git push origin main
-
-# --gradient_accumulation_steps="2" \
-
-# --resume_from_checkpoint="${MODEL_DIR}/ckpt-18000" \
+#!/bin/bash
+
+export HF_PROJECT="t5-base-dutch"
+
+# Variables for training the tokenizer and creating the config
+export VOCAB_SIZE="32000"
+export N_INPUT_SENTENCES="1000000" # Num of sentences to train the tokenizer
+export DATASET="yhavinga/mc4_nl_cleaned" # Name of the dataset in the Huggingface Hub
+export DATASET_CONFIG="full" # Config of the dataset in the Huggingface Hub
+export DATASET_SPLIT="train" # Split to use for training tokenizer and model
+export TEXT_FIELD="text" # Field containing the text to be used for training
+export CONFIG_TYPE="t5-base" # Config that our model will use
+export MODEL_PATH="${HOME}/data/${HF_PROJECT}" # Path to the model, e.g. here inside the mount
+
+python run_t5_mlm_flax.py \
+    --output_dir="${MODEL_PATH}" \
+    --model_type="t5" \
+    --config_name="${MODEL_PATH}" \
+    --tokenizer_name="${MODEL_PATH}" \
+    --preprocessing_num_workers="96" \
+    --do_train --do_eval \
+    --dataset_name="${DATASET}" \
+    --dataset_config_name="${DATASET_CONFIG}" \
+    --max_seq_length="512" \
+    --per_device_train_batch_size="16" \
+    --per_device_eval_batch_size="16" \
+    --adafactor \
+    --learning_rate="0.005" \
+    --overwrite_output_dir \
+    --num_train_epochs="1" \
+    --logging_steps="500" \
+    --save_steps="80000" \
+    --eval_steps="2500" \
+    --weight_decay="0.01" \
+    --warmup_steps="10000" \
+    --validation_split_count="15000" \
+    --push_to_hub

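With --per_device_train_batch_size="16" on an 8-device TPU host (an assumption; the exact hardware is not stated in this commit), the effective batch size is 128, which matches the model card. A rough sanity check of the reported step count against the ~34B token figure:

# Back-of-the-envelope check; all numbers except the device count come from run_t5.sh and the README.
devices = 8                              # assumed TPU v3-8 host
per_device_batch = 16                    # --per_device_train_batch_size
batch_size = devices * per_device_batch  # 128, as in the model card
seq_len = 512                            # --max_seq_length
steps = 528_482                          # from the README Training section

tokens_per_step = batch_size * seq_len
print(f"{steps * tokens_per_step / 1e9:.1f}B tokens")  # ~34.6B, consistent with the ~34B tokens of the `full` config
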
run_t5_mlm_flax_custom_dataset.py → run_t5_mlm_flax.py
RENAMED
@@ -18,6 +18,8 @@ Pretraining the library models for T5-like span-masked language modeling on a te
 
 Here is the full list of checkpoints on the hub that can be pretrained by this script:
 https://huggingface.co/models?filter=t5
+
+Adapted from the original version to support gradient accumulation and restarting.
 """
 # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
 import logging
@@ -25,13 +27,13 @@ import os
 import sys
 import time
 import json
-import shutil
 from dataclasses import dataclass, field
+from itertools import chain
 from pathlib import Path
 from typing import Dict, List, Optional
 
 import numpy as np
+from datasets import load_dataset
 from tqdm import tqdm
 
 import flax
@@ -39,34 +41,31 @@ import jax
 import jax.numpy as jnp
 import optax
 from flax import jax_utils, traverse_util
+from flax.serialization import to_bytes, from_bytes
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, onehot, shard
+# from huggingface_hub import Repository
 from transformers import (
     CONFIG_MAPPING,
     FLAX_MODEL_FOR_MASKED_LM_MAPPING,
+    AutoTokenizer,
     BatchEncoding,
     FlaxT5ForConditionalGeneration,
-    T5ForConditionalGeneration,
     HfArgumentParser,
     PreTrainedTokenizerBase,
     T5Config,
-    T5TokenizerFast,
     TrainingArguments,
     is_tensorboard_available,
     set_seed,
 )
+# from transformers.file_utils import get_full_repo_name
 from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
 
 logger = logging.getLogger(__name__)
 
 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 
-data_files = []
-
 @dataclass
 class ModelArguments:
     """
@@ -103,6 +102,12 @@ class ModelArguments:
             "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
         },
     )
+    auth_token: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Auth token for private repositories on the Huggingface Hub"
+        }
+    )
 
 
 @dataclass
@@ -133,10 +138,10 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    validation_split_count: Optional[int] = field(
+        default=10000,
+        metadata={
+            "help": "The count of the train set used as validation set in case there's no validation split"
+        },
+    )
     max_seq_length: Optional[int] = field(
@@ -156,18 +161,31 @@ class DataTrainingArguments:
         default=3.0,
         metadata={"help": "Mean span length of masked tokens"},
     )
+    max_train_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+            "value if set."
+        },
+    )
+    max_eval_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+            "value if set."
+        },
+    )
 
     def __post_init__(self):
-        # assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
+        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
+            raise ValueError("Need either a dataset name or a training/validation file.")
+        else:
+            if self.train_file is not None:
+                extension = self.train_file.split(".")[-1]
+                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
+            if self.validation_file is not None:
+                extension = self.validation_file.split(".")[-1]
+                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
 
 
 def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length):
@@ -297,7 +315,7 @@ class FlaxDataCollatorForT5MLM:
         start_indices[:, 0] = mask_indices[:, 0]
 
         sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
+        sentinel_ids = np.where(sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0)
         sentinel_ids -= mask_indices - start_indices
 
         return sentinel_ids
@@ -362,7 +380,8 @@ class FlaxDataCollatorForT5MLM:
             np.random.shuffle(mask_indices)
             first_in_segment = np.pad(mask_indices, [[1, 0]])
             segment_id = np.cumsum(first_in_segment)
+            # count length of sub segments assuming that list is sorted
+            _, segment_length = np.unique(segment_id, return_counts=True)
             return segment_length
 
         noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
@@ -405,70 +424,40 @@ def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)
 
 
 def mb_item(x):
     return x.item() if hasattr(x, "item") else x
 
 
-def save_checkpoint(model, save_dir, state, with_opt: bool = True):
+def save_checkpoint(model, save_dir, state, cur_step: int, with_opt: bool = True, push_to_hub: bool = False):
     state = jax_utils.unreplicate(state)
-    logger.info(f"SAVING CHECKPOINT IN {save_dir}")
-    save_dir = f"{save_dir}/ckpt-{mb_item(state.step) - 1}"
-    model.save_pretrained(
-        save_dir,
-        params=state.params,
-        push_to_hub=False
-    )
     if with_opt:
+        logger.info(f'Saving optimizer and training state in {save_dir}...')
         with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
             f.write(to_bytes(state.opt_state))
         with open(os.path.join(save_dir, "training_state.json"), "w") as f:
             json.dump({"step": state.step.item()}, f)
+    logger.info(f'Saving model in {save_dir} {"and pushing it to HF Hub" if push_to_hub else ""}')
     model.save_pretrained(
+        save_dir,
         params=state.params,
+        push_to_hub=push_to_hub,
         commit_message=f"Saving weights and logs of step {cur_step}",
     )
-    if with_opt:
-        with open(os.path.join(training_args.output_dir, "opt_state.msgpack"), "wb") as f:
-            f.write(to_bytes(state.opt_state))
-        with open(os.path.join(training_args.output_dir, "training_state.json"), "w") as f:
-            json.dump({"step": state.step.item()}, f)
-    logger.info("checkpoint saved")
 
+def restore_checkpoint(load_dir, state):
+    logger.info(f"Restoring checkpoint from {load_dir}")
+    with open(os.path.join(load_dir, "flax_model.msgpack"), "rb") as f:
         params = from_bytes(state.params, f.read())
+    with open(os.path.join(load_dir, "opt_state.msgpack"), "rb") as f:
         opt_state = from_bytes(state.opt_state, f.read())
+    with open(os.path.join(load_dir, "training_state.json"), "r") as f:
         training_state = json.load(f)
     step = training_state["step"]
+    logger.info(f"Checkpoint restored at step {step}")
     return state.replace(step=step, params=params, opt_state=opt_state), step
 
 
-def rotate_checkpoints(ckpt_dir: str, save_total_limit: int):
-    "Removes older checkpoints so that `save_total_limit` checkpoints are kept"
-    # TODO: what to remove is decided using step number only, we might want to improve that
-    ckpts = [str(x) for x in Path(ckpt_dir).glob("ckpt-*")]
-    # sort checkpoints by step
-    ckpts_sorted = sorted(ckpts, key=lambda x: int(x.split('-')[-1]))
-    ckpts_to_delete = ckpts_sorted[:-save_total_limit]
-    for ckpt in ckpts_to_delete:
-        logger.info(f"Deleting older checkpoint [{ckpt}] due to save_total_limit ({save_total_limit})")
-        shutil.rmtree(ckpt)
-
-
 if __name__ == "__main__":
     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
@@ -509,6 +498,16 @@ if __name__ == "__main__":
     # Set seed before initializing model.
     set_seed(training_args.seed)
 
+    # Handle the repository creation
+    # if training_args.push_to_hub:
+    #     if training_args.hub_model_id is None:
+    #         repo_name = get_full_repo_name(
+    #             Path(training_args.output_dir).absolute().name, token=training_args.hub_token
+    #         )
+    #     else:
+    #         repo_name = training_args.hub_model_id
+    #     repo = Repository(training_args.output_dir, clone_from=repo_name)
+
     # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
     # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
     # (the dataset will be downloaded automatically from the datasets Hub).
@@ -523,82 +522,38 @@ if __name__ == "__main__":
         datasets["validation"] = load_dataset(
             data_args.dataset_name,
             data_args.dataset_config_name,
+            split=f"train[:{data_args.validation_split_count}]",
             cache_dir=model_args.cache_dir,
         )
         datasets["train"] = load_dataset(
             data_args.dataset_name,
             data_args.dataset_config_name,
+            split=f"train[{data_args.validation_split_count}:]",
+            cache_dir=model_args.cache_dir,
+        )
+    else:
+        datasets["validation"] = load_dataset(
+            data_args.dataset_name,
+            data_args.dataset_config_name,
+            split=f"validation[:{data_args.validation_split_count}]",
+            cache_dir=model_args.cache_dir,
+        )
+        datasets["train"] = load_dataset(
+            data_args.dataset_name,
+            data_args.dataset_config_name,
+            split="train",
             cache_dir=model_args.cache_dir,
         )
     else:
-            global data_files
-            data_files += glob.glob(f"{path}/{filespec}")
-            data_files = list(set(data_files))
-            print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")
-
-        # add_jsonlines_dir(f"{data_dir}/oscar_nl_cleaned")
-        add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*73*.gz")
-        add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*47*.gz")
-        add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*12*.gz")
-        add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*29*.gz")
-        add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*74*.gz")
-        add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*26*.gz")
-        add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*54*.gz")
-        add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*68*.gz")
-        add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*57*.gz")
-        # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*46*.gz")
-        # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*35*.gz")
-        # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*13*.gz")
-        # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*41*.gz")
-        # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*52*.gz")
-        # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*63*.gz")
-        # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*85*.gz")
-        # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*81*.gz")
-        # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*96*.gz")
-        add_jsonlines_dir(f"{data_dir}/nrc_uniq_cleaned_20210223", "*.gz")
-        add_jsonlines_dir(f"{data_dir}/nu_uniq_cleaned_20210225", "*.gz")
-        random.Random(SEED).shuffle(data_files)
-
-        total = len(data_files)
-        print(total)
-        perc = 0.05
-        val_size = int(perc * total)
-        train_size = total - val_size
-        train = data_files[:train_size]
-        val = data_files[train_size:]
-        print(f"Got {len(train)} training files and {perc * 100} % {len(val)} validation files")
-
-        assert list(set(train) & set(val)) == [], "Train overlaps with test"
-
-        return train, val
-
-    # train, val = train_val_files()
-
-    load_grouped = True
-
-    if not load_grouped:
-        datasets = load_dataset('json', data_files={'train': train, 'validation': val})
-
-    # data_files = {}
-    # if data_args.train_file is not None:
-    #     data_files["train"] = data_args.train_file
-    # if data_args.validation_file is not None:
-    #     data_files["validation"] = data_args.validation_file
-    # extension = data_args.train_file.split(".")[-1]
-    # if extension == "txt":
-    #     extension = "text"
-    # datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+        data_files = {}
+        if data_args.train_file is not None:
+            data_files["train"] = data_args.train_file
+        if data_args.validation_file is not None:
+            data_files["validation"] = data_args.validation_file
+        extension = data_args.train_file.split(".")[-1]
+        if extension == "txt":
+            extension = "text"
+        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
 
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
@@ -606,12 +561,18 @@ if __name__ == "__main__":
     # Load pretrained model and tokenizer
 
     if model_args.tokenizer_name:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.tokenizer_name,
+            cache_dir=model_args.cache_dir,
+            use_fast=model_args.use_fast_tokenizer,
+            use_auth_token=model_args.auth_token
+        )
     elif model_args.model_name_or_path:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.model_name_or_path,
+            cache_dir=model_args.cache_dir,
+            use_fast=model_args.use_fast_tokenizer,
+            use_auth_token=model_args.auth_token
+        )
     else:
         raise ValueError(
@@ -631,8 +592,30 @@ if __name__ == "__main__":
         config = CONFIG_MAPPING[model_args.model_type]()
         logger.warning("You are instantiating a new config instance from scratch.")
 
+    # Preprocessing the datasets.
+    # First we tokenize all the texts.
+    if training_args.do_train:
+        column_names = datasets["train"].column_names
+    else:
+        column_names = datasets["validation"].column_names
+    text_column_name = "text" if "text" in column_names else column_names[0]
+
     max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
 
+    # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
+    # Since we make sure that all sequences are of the same length, no attention_mask is needed.
+    def tokenize_function(examples):
+        return tokenizer(examples[text_column_name], return_attention_mask=False)
+
+    logger.info(f"Start tokenization, remove_column_names = {column_names}")
+    tokenized_datasets = datasets.map(
+        tokenize_function,
+        batched=True,
+        num_proc=data_args.preprocessing_num_workers,
+        remove_columns=column_names,
+        load_from_cache_file=not data_args.overwrite_cache,
+    )
+
     # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
     # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
     # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
@@ -643,64 +626,36 @@ if __name__ == "__main__":
     )
     logger.info(f"Max seq length: {max_seq_length}, expanded_inputs_length: {expanded_inputs_length}, targets_length: {targets_length}")
 
-    # Since we make sure that all sequences are of the same length, no attention_mask is needed.
-    def tokenize_function(examples):
-        return tokenizer(examples[text_column_name], return_attention_mask=False)
-
-    logger.info(f"Start tokenization, remove_column_names = {column_names}")
-    tokenized_datasets = datasets.map(
-        tokenize_function,
-        batched=True,
-        num_proc=data_args.preprocessing_num_workers,
-        remove_columns=column_names,
-        load_from_cache_file=not data_args.overwrite_cache,
-    )
-
+    # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
+    def group_texts(examples):
+        # Concatenate all texts.
+        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
+        total_length = len(concatenated_examples[list(examples.keys())[0]])
+        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
+        # customize this part to your needs.
+        if total_length >= expanded_inputs_length:
+            total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
+        # Split by chunks of max_len.
+        result = {
+            k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
+            for k, t in concatenated_examples.items()
+        }
+        return result
 
     # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
     # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
     # might be slower to preprocess.
     #
     # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
     # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
     logger.info(f"Start group_texts")
     tokenized_datasets = tokenized_datasets.map(
         group_texts,
         batched=True,
         batch_size=200,
         num_proc=data_args.preprocessing_num_workers,
         load_from_cache_file=not data_args.overwrite_cache,
     )
 
     # Enable tensorboard only on the master node
     has_tensorboard = is_tensorboard_available()
@@ -729,15 +684,9 @@ if __name__ == "__main__":
             model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
         )
     else:
+        config.vocab_size = len(tokenizer)
        model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
 
-    # def to_bf16(t):
-    #     return jax.tree_map(lambda x: x.astype(jnp.bfloat16) if x.dtype == jnp.float32 else x, t)
-    #
-    #
-    # model.params = to_bf16(model.params)
-
     # Data collator
     # This one will take care of randomly masking the tokens.
     data_collator = FlaxDataCollatorForT5MLM(
@@ -752,16 +701,13 @@ if __name__ == "__main__":
 
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
+    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
 
-    num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs
-
     steps_per_epoch = len(tokenized_datasets['train']) // train_batch_size
+    num_train_steps = steps_per_epoch * num_epochs
 
     # Create learning rate schedule
     if training_args.warmup_steps:
         warmup_steps = training_args.warmup_steps
     elif training_args.warmup_ratio:
@@ -770,7 +716,6 @@ if __name__ == "__main__":
         logging.info(f"Warmup steps set to {100*training_args.warmup_ratio}% = {warmup_steps} of total train steps {num_train_steps}")
     else:
         raise Exception("Need either --warmup_steps or --warmup_ratio")
-
     warmup_fn = optax.linear_schedule(
         init_value=0.0, end_value=training_args.learning_rate, transition_steps=warmup_steps
     )
@@ -823,8 +768,6 @@ if __name__ == "__main__":
     else:
         resume_step = 0
 
-    logger.info("")
-
     # Define gradient update step fn
     def train_step(state, batch, dropout_rng):
         dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
@@ -845,7 +788,8 @@ if __name__ == "__main__":
         new_state = state.apply_gradients(grads=grad)
 
         metrics = jax.lax.pmean(
+            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step // grad_accum_steps)},
+            axis_name="batch"
         )
 
         return new_state, metrics, new_dropout_rng
@@ -875,17 +819,20 @@ if __name__ == "__main__":
 
     logger.info("Replicate the train state on each device")
 
+    # import pydevd_pycharm
+    #
+    # pydevd_pycharm.settrace('localhost', port=12345, stdoutToServer=True, stderrToServer=True)
+
     # Replicate the train state on each device
     state = jax_utils.replicate(state)
 
     logger.info("***** Running training *****")
+    logger.info(f" Num examples = {len(datasets['train'])}")
     logger.info(f" Num tokenized group examples {len(tokenized_datasets['train'])}")
     logger.info(f" Num Epochs = {num_epochs}")
     logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
     logger.info(f" Total train batch size (w. parallel, distributed and grad_accum) = {train_batch_size}")
+    logger.info(f" Total optimization steps = {num_train_steps}")
 
     train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
@@ -899,16 +846,26 @@ if __name__ == "__main__":
 
         # Generate an epoch by shuffling sampling indices from the train dataset
         num_train_samples = len(tokenized_datasets["train"])
-        train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
-        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
+        # train_samples_idx = jax.random.permutation(input_rng, jnp.arange(num_train_samples))
+        # train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
+
+        ## IF THE DATASET IS TOO LONG, WE ONLY PROCEED SEQUENTIALLY WITHOUT SHUFFLING
+        samples_to_remove = num_train_samples % (train_batch_size // grad_accum_steps)
+        samples_idx = np.arange(num_train_samples)
+        if samples_to_remove != 0:
+            samples_idx = samples_idx[:-samples_to_remove]
+        steps = num_train_samples // (train_batch_size // grad_accum_steps)
 
         # Gather the indexes for creating the batch and do a training step
-        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
+        # for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
+        #     samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
+        for step in tqdm(range(steps), desc="Training...", position=1):
             cur_step = epoch * (num_train_samples // train_batch_size) + step
             # skip to the step from which we are resuming
            if cur_step < resume_step:
                continue
 
+            batch_idx = [x for x in range(step * train_batch_size, (step + 1) * train_batch_size)]
            samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
            try:
                model_inputs = data_collator(samples)
@@ -922,7 +879,6 @@ if __name__ == "__main__":
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)
 
-
             if cur_step % training_args.logging_steps * grad_accum_steps == 0 and cur_step > 0:
                 # Save metrics
                 train_metric = jax_utils.unreplicate(train_metric)
@@ -931,7 +887,7 @@ if __name__ == "__main__":
                     write_train_metric(summary_writer, train_metrics, train_time, cur_step)
 
                 epochs.write(
+                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                 )
 
                 train_metrics = []
@@ -961,39 +917,50 @@ if __name__ == "__main__":
 
                 # Save metrics
                 if has_tensorboard and jax.process_index() == 0:
-                    cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
                     write_eval_metric(summary_writer, eval_metrics, cur_step)
 
             if cur_step % training_args.save_steps * grad_accum_steps == 0 and cur_step > 0:
-                logger.info(f"We should save the model here after {cur_step} steps")
                 # save checkpoint after each epoch and push checkpoint to the hub
                 if jax.process_index() == 0:
-                    save_checkpoint(model, training_args.output_dir, state)
-                    if training_args.save_total_limit is not None:
-                        rotate_checkpoints(training_args.output_dir, training_args.save_total_limit)
                     # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                    # model.save_pretrained(training_args.output_dir, params=params)
+                    # tokenizer.save_pretrained(training_args.output_dir)
+                    # if training_args.push_to_hub:
+                    #     repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
+                    save_checkpoint(model, training_args.output_dir, state, cur_step, with_opt=False, push_to_hub=True)
+
+    # Eval after training
+    if training_args.do_eval:
+        num_eval_samples = len(tokenized_datasets["validation"])
+        eval_samples_idx = jnp.arange(num_eval_samples)
+        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
+
+        eval_metrics = []
+        for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
+            samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
+            model_inputs = data_collator(samples)
+
+            # Model forward
+            model_inputs = shard(model_inputs.data)
+            metrics = p_eval_step(state.params, model_inputs)
+            eval_metrics.append(metrics)
+
+        # get eval metrics
+        eval_metrics = get_metrics(eval_metrics)
+        eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)
 
         if jax.process_index() == 0:
+            eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
+            path = os.path.join(training_args.output_dir, "eval_results.json")
+            with open(path, "w") as f:
+                json.dump(eval_metrics, f, indent=4, sort_keys=True)
+
-    # Save model at end
+    # Save model at end
+    if jax.process_index() == 0:
+        # params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+        # model.save_pretrained(training_args.output_dir, params=params)
+        # tokenizer.save_pretrained(training_args.output_dir)
+        # if training_args.push_to_hub:
+        #     repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
+        #
+        save_checkpoint(model, training_args.output_dir, state, cur_step, with_opt=False, push_to_hub=True)

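The group_texts function introduced above concatenates the tokenized documents of a batch and slices them into fixed-length chunks of expanded_inputs_length. A small self-contained illustration with toy data; the chunk length 8 is arbitrary, the real script derives it from max_seq_length and the masking settings.

from itertools import chain

expanded_inputs_length = 8  # toy value for illustration only

def group_texts(examples):
    # Concatenate all texts, drop the remainder, and split into equally sized chunks,
    # mirroring the map function in run_t5_mlm_flax.py.
    concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    if total_length >= expanded_inputs_length:
        total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
    return {
        k: [t[i: i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
        for k, t in concatenated.items()
    }

batch = {"input_ids": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]]}
print(group_texts(batch))
# {'input_ids': [[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]]}
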
streaming_dataset_filter_test.py
DELETED
@@ -1,93 +0,0 @@
-from clean import clean_text
-
-from datasets import load_dataset
-
-dataset_v0 = load_dataset('oscar', "unshuffled_deduplicated_nl", split='train', streaming=True)
-
-# data_dir = "/home/yeb"
-data_dir = "/home/yeb/Developer/data"
-data_files = []
-
-def train_val_files():
-    import glob
-    import random
-    SEED = 12345
-
-    def add_jsonlines_dir(path, filespec):
-        global data_files
-        data_files += glob.glob(f"{path}/{filespec}")
-        data_files = list(set(data_files))
-        print(f"Number of files {len(data_files)} after adding {path} glob {filespec}")
-
-    # add_jsonlines_dir(f"{data_dir}/oscar_nl_cleaned")
-    add_jsonlines_dir(f"{data_dir}/c4_cleaned2", "*73*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*47*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*12*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*29*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*74*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*26*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*54*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*68*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*57*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*46*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*35*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*13*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*41*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*52*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*63*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*85*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*81*.gz")
-    # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*96*.gz")
-    # add_jsonlines_dir(f"{data_dir}/nrc_uniq_cleaned_20210223", "*.gz")
-    # add_jsonlines_dir(f"{data_dir}/nu_uniq_cleaned_20210225", "*.gz")
-    random.Random(SEED).shuffle(data_files)
-
-    total = len(data_files)
-    print(total)
-    perc = 0.05
-    val_size = int(perc * total)
-    train_size = total - val_size
-    train = data_files[:train_size]
-    val = data_files[train_size:]
-    print(f"Got {len(train)} training files and {perc * 100} % {len(val)} validation files")
-
-    assert list(set(train) & set(val)) == [], "Train overlaps with test"
-
-    return train, val
-
-train, val = train_val_files()
-dataset_v0 = load_dataset('json', data_files={'train': train, 'validation': val})
-
-
-dataset_v0 = load_dataset('oscar', "unshuffled_deduplicated_nl")
-
-def f(obj):
-    obj["text"] = clean_text(obj["text"])
-    return obj
-
-
-dataset_v1 = dataset_v0.map(
-    f,
-    batched=False,
-    num_proc=10,
-)
-
-datasets = dataset_v1.filter(
-    lambda obj: obj['text'] is not None,
-    num_proc=10,
-)
-
-it = iter(dataset_v0['train'])
-print(next(it))
-print(next(it))
-print(next(it))
-
-it = iter(dataset_v1['train'])
-print(next(it))
-print(next(it))
-print(next(it))
-
-# it = iter(dataset_v2)
-# print(next(it))
-# print(next(it))
-# print(next(it))

tf_model.h5
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
-size …
+oid sha256:3083c65d23d0521977a9739022c8e48f3ee1094d43317b150cf044f1451cfd9c
+size 892068248

train_tokenizer.py
DELETED
@@ -1,66 +0,0 @@
-from datasets import load_dataset
-from t5_tokenizer_model import SentencePieceUnigramTokenizer
-
-# from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
-
-data_dir = "/home/yeb"
-data_files = []
-
-
-def train_val_files():
-    import glob
-    import random
-    SEED = 12345
-
-    def add_jsonlines_dir(path, filespec):
-        global data_files
-        data_files += glob.glob(f"{path}/{filespec}")
-        print(f"Number of files {len(data_files)} after adding {path}")
-
-    # add_jsonlines_dir(f"{data_dir}/oscar_nl_cleaned")
-    add_jsonlines_dir(f"{data_dir}/c4_cleaned2", "*47*.gz")
-    add_jsonlines_dir(f"{data_dir}/nrc_uniq_cleaned_20210223", "*.gz")
-    add_jsonlines_dir(f"{data_dir}/nu_uniq_cleaned_20210225", "*.gz")
-    random.Random(SEED).shuffle(data_files)
-
-    print(data_files)
-    total = len(data_files)
-    print(total)
-    perc = 0.01
-    val_size = int(perc * total)
-    train_size = total - val_size
-    train = data_files[:train_size]
-    val = data_files[train_size:]
-    print(f"Got {len(train)} training files and {perc * 100} % {len(val)} validation files")
-
-    assert list(set(train) & set(val)) == [], "Train overlaps with test"
-
-    return train, val
-
-
-train, val = train_val_files()
-
-dataset = load_dataset('json', data_files={'train': train, 'validation': val}, split='train')
-
-vocab_size = 32000
-input_sentence_size = None
-tokenizer = SentencePieceUnigramTokenizer(unk_token="<unk>", eos_token="</s>", pad_token="<pad>")
-
-
-# Build an iterator over this dataset
-def batch_iterator(input_sentence_size=None):
-    if input_sentence_size is None:
-        input_sentence_size = len(dataset)
-    batch_length = 100
-    for i in range(0, input_sentence_size, batch_length):
-        yield dataset[i: i + batch_length]["text"]
-
-# Train tokenizer
-tokenizer.train_from_iterator(
-    iterator=batch_iterator(input_sentence_size=input_sentence_size),
-    vocab_size=vocab_size,
-    show_progress=True,
-)
-
-# Save files to disk
-tokenizer.save("./tokenizer.json")

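The deleted train_tokenizer.py wrote its result to tokenizer.json. For reference, a tokenizer file produced this way can be loaded back through the fast-tokenizer wrapper; this is a sketch, with the special-token names mirroring the ones passed to SentencePieceUnigramTokenizer above.

from transformers import PreTrainedTokenizerFast

# Load the tokenizer.json written by the (now removed) train_tokenizer.py.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="./tokenizer.json",
    unk_token="<unk>",
    eos_token="</s>",
    pad_token="<pad>",
)
print(tokenizer("Hoe gaat het?").input_ids)
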
training_state.json
DELETED
@@ -1 +0,0 @@
-{"step": 62500}

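training_state.json held only the global step counter that restore_checkpoint reads back when resuming; a one-line sketch of that read, with an illustrative path:

import json

# The removed file contained just {"step": 62500}; restore_checkpoint uses this value to skip already-trained steps.
with open("training_state.json") as f:
    resume_step = json.load(f)["step"]
print(resume_step)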