yhavinga committed
Commit 55bd44c
1 Parent(s): 4c4fca4

Saving weights and logs of step 16000

Determine_batch_size.ipynb ADDED
@@ -0,0 +1,289 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "kqOqTZuKeJoa",
+ "outputId": "9f63819c-9bc1-4c15-e9cd-9c1121edd2a6"
+ },
+ "outputs": [],
+ "source": [
+ "#!pip install \"jax[tpu]>=0.2.16\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "SQ-lhEVFeY4d",
+ "outputId": "7346c6b8-1848-4755-c114-94d6de50b50d"
+ },
+ "outputs": [],
+ "source": [
+ "#!git clone https://github.com/huggingface/transformers.git"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "9qSTMLvFfBVs",
+ "outputId": "40659f61-86d4-4ae5-9262-501557737705"
+ },
+ "outputs": [],
+ "source": [
+ "#!pip install ./transformers"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "7Og7zRTrfm08"
+ },
+ "outputs": [],
+ "source": [
+ "#!pip install jaxlib>=0.2.9"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "id": "nmQv7VMaf1L8"
+ },
+ "outputs": [],
+ "source": [
+ "#!pip install flax>=0.3.4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "MT6jpop-f4dc"
+ },
+ "outputs": [],
+ "source": [
+ "#!pip install optax>=0.0.9"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# %%capture\n",
+ "# !pip install jupyterlab_widgets\n",
+ "# !pip install ipywidgets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "id": "-F5NIqDmfDLb"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-07-08 10:11:37.310929: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n"
+ ]
+ }
+ ],
+ "source": [
+ "from transformers import (\n",
+ " CONFIG_MAPPING,\n",
+ " FLAX_MODEL_FOR_MASKED_LM_MAPPING,\n",
+ " BatchEncoding,\n",
+ " FlaxT5ForConditionalGeneration,\n",
+ " T5ForConditionalGeneration,\n",
+ " HfArgumentParser,\n",
+ " PreTrainedTokenizerBase,\n",
+ " T5Config,\n",
+ " T5TokenizerFast,\n",
+ " TrainingArguments,\n",
+ " is_tensorboard_available,\n",
+ " set_seed,\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "id": "aInICxY6gREQ"
+ },
+ "outputs": [],
+ "source": [
+ "import flax\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "import optax\n",
+ "from flax import jax_utils, traverse_util\n",
+ "from flax.training import train_state\n",
+ "from flax.training.common_utils import get_metrics, onehot, shard\n",
+ "from transformers import (\n",
+ " CONFIG_MAPPING,\n",
+ " FLAX_MODEL_FOR_MASKED_LM_MAPPING,\n",
+ " BatchEncoding,\n",
+ " FlaxT5ForConditionalGeneration,\n",
+ " T5ForConditionalGeneration,\n",
+ " HfArgumentParser,\n",
+ " PreTrainedTokenizerBase,\n",
+ " T5Config,\n",
+ " T5TokenizerFast,\n",
+ " TrainingArguments,\n",
+ " is_tensorboard_available,\n",
+ " set_seed,\n",
+ ")\n",
+ "from transformers.models.t5.modeling_flax_t5 import shift_tokens_right\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "id": "iEqVlHptfOCT"
+ },
+ "outputs": [],
+ "source": [
+ "tokenizer = T5TokenizerFast.from_pretrained(\"t5-small\")\n",
+ "config = T5Config.from_pretrained(\"t5-small\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "LNETw3cWfjbr",
+ "outputId": "95c0e750-c087-46dd-92fa-39f8ff0238f2"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:absl:Starting the local TPU driver.\n",
+ "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
+ "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: TPU Interpreter Host\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = FlaxT5ForConditionalGeneration(config)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "id": "T5F3BEA2f6xE"
+ },
+ "outputs": [],
+ "source": [
+ "input_ids = np.asarray(208 * [512 * [1]], dtype=np.int32)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def run_forward(input_ids, params):\n",
+ " return model(input_ids, decoder_input_ids=input_ids).logits"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "jitted_forward = jax.jit(run_forward)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "logits = jitted_forward(input_ids, model.params)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "accelerator": "TPU",
+ "colab": {
+ "name": "Untitled1.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
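
Note on Determine_batch_size.ipynb: the notebook checks whether a candidate batch fits in TPU memory by building a dummy (208, 512) batch of token ids and jitting a single forward pass. A minimal sketch of the same probe, assuming the `model` built in the notebook and standard XLA error behaviour; the helper name and the candidate sizes below are illustrative, not part of the notebook:

import numpy as np
import jax

def fits_in_memory(batch_size, seq_len=512):
    # Dummy token ids; the values do not matter for a memory probe.
    input_ids = np.ones((batch_size, seq_len), dtype=np.int32)
    try:
        logits = jax.jit(
            lambda ids, params: model(ids, decoder_input_ids=ids, params=params).logits
        )(input_ids, model.params)
        logits.block_until_ready()  # force execution rather than just tracing
        return True
    except RuntimeError:
        # XLA surfaces "resource exhausted" as a RuntimeError when the batch does not fit.
        return False

for bs in (32, 64, 128, 208, 256):
    print(bs, "fits" if fits_in_memory(bs) else "does not fit")
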
Load_preprocessed_dataset.ipynb ADDED
@@ -0,0 +1,147 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "cf148030-7287-4c9e-ae32-8d1e1c47be30",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import Dataset, DatasetDict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "5161b4ba-e8cf-43e1-b67e-503c29aa4271",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "datasets = DatasetDict.load_from_disk(\"./grouped_dataset\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "15f9d047-ac35-43d7-ab55-9f9afe96dd07",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['input_ids'],\n",
+ " num_rows: 86438919\n",
+ " })\n",
+ " validation: Dataset({\n",
+ " features: ['input_ids'],\n",
+ " num_rows: 4735324\n",
+ " })\n",
+ "})"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "datasets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "d1d1218e-142e-441a-b20d-d300b13b172a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train = datasets['train']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9eaddfb1-242f-4a25-8789-efe97b2a5712",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "8aabb26f-19ca-467a-b383-3a693be70cac",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "86438919\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(len(train))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f3176986-5b34-4ed6-a643-e342db9a2ce8",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "1205bbef-ba9d-4ddc-af2e-602d56b7dd64",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'input_ids': [256, 3, 20, 18452, 6690, 7757, 1286, 43, 10, 4942, 1286, 80, 12, 4782, 5442, 39, 5385, 33, 4, 5, 3, 2924, 117, 5669, 228, 21, 193, 9030, 511, 24, 11, 5, 665, 165, 4218, 7, 26, 264, 1528, 35, 105, 3, 19653, 12, 9661, 17156, 13955, 4, 132, 5, 611, 959, 961, 146, 6522, 7757, 1286, 89, 7500, 9716, 11, 5, 4868, 107, 13604, 12, 12836, 13368, 11, 611, 959, 4, 3, 69, 99, 12, 13132, 6690, 590, 5, 1803, 1867, 69, 7, 924, 10, 1762, 4, 3, 69, 538, 489, 14, 1149, 16, 3, 11384, 199, 116, 399, 4782, 291, 3, 6, 237, 13, 2629, 3, 8987, 291, 4, 69, 5, 3, 27, 72, 20, 325, 3, 2924, 133, 21, 105, 9030, 10, 1149, 242, 16, 144, 13572, 11, 9, 13401, 20, 7951, 8, 165, 4218, 4, 5, 1910]}\n"
+ ]
+ }
+ ],
+ "source": [
+ "it = iter(train)\n",
+ "\n",
+ "print(next(it))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f5d4e8de-419c-4c70-896e-fbd640bb7321",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
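
Note on Load_preprocessed_dataset.ipynb: the notebook reloads a DatasetDict that has already been tokenized and grouped, so the expensive preprocessing does not have to be repeated. A short sketch of the save/load round trip, assuming `tokenized_datasets` is the DatasetDict produced by the tokenize and group_texts steps in run_t5_mlm_flax_custom_dataset.py (the save step itself is not shown in this commit):

from datasets import DatasetDict

# After tokenization and grouping, persist the result once:
tokenized_datasets.save_to_disk("./grouped_dataset")

# Later runs can skip straight to loading it:
datasets = DatasetDict.load_from_disk("./grouped_dataset")
print(datasets["train"].num_rows)  # 86438919 in this run
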
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
- oid sha256:68189daaf4d05f88ba3305865c8d542080aac997060714ef57b8b116f56010c0
+ oid sha256:ed68bf4bf2ba245a90ae31d71a56c0f85d1ef7665f0748bd10826c688e5de825
 size 891548548
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
- oid sha256:d4162e34e834b3cf52825caa4f4d2cbec358c35405a212052af6c977b2561680
+ oid sha256:1a8f60fdc3ad43a82bab7ec3dcaf1138179d7508798267becb15426d86b9385f
 size 891650495
run_t5.sh CHANGED
@@ -6,7 +6,6 @@ mkdir -p "${MODEL_DIR}/runs"

# T5 paper lr 0.01 with batch size 128
# We have a batch size of 8 devices * 32 = 256, so lr = 0.01/2
- # Warmup steps is set to 4% of the training steps

./run_t5_mlm_flax_custom_dataset.py \
--output_dir="${MODEL_DIR}" \
@@ -23,12 +22,13 @@ mkdir -p "${MODEL_DIR}/runs"
--dtype="bfloat16" \
--overwrite_output_dir \
--num_train_epochs="1" \
- --logging_steps="50" \
+ --logging_steps="20" \
--save_steps="2000" \
- --eval_steps="1000000" \
+ --eval_steps="10000000" \
+ --resume_from_checkpoint="${MODEL_DIR}/ckpt-14000" \
+ --warmup_steps="3413" \
--push_to_hub

- # --resume_from_checkpoint="${MODEL_DIR}/ckpt-1500" \


#git add pytorch_model.bin
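
For reference, the effective batch size and the steps per epoch implied by these flags can be worked out as below (a sketch: the 8 devices * 32 per device figure comes from the comment in the script and the 86,438,919 training examples from the dataset notebook above, the rest is plain arithmetic):

devices = 8
per_device_batch_size = 32
train_examples = 86_438_919

total_batch_size = devices * per_device_batch_size    # 256, matching the script comment
steps_per_epoch = train_examples // total_batch_size  # ~337,652 optimizer steps for one epoch
print(total_batch_size, steps_per_epoch)
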
run_t5_mlm_flax_custom_dataset.py CHANGED
@@ -31,7 +31,7 @@ from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
- from datasets import load_dataset
+ from datasets import load_dataset, DatasetDict
from tqdm import tqdm

import flax
@@ -552,15 +552,15 @@ if __name__ == "__main__":
add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*54*.gz")
add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*68*.gz")
add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*57*.gz")
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*46*.gz")
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*35*.gz")
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*13*.gz")
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*41*.gz")
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*52*.gz")
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*63*.gz")
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*85*.gz")
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*81*.gz")
- add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*96*.gz")
+ # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*46*.gz")
+ # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*35*.gz")
+ # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*13*.gz")
+ # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*41*.gz")
+ # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*52*.gz")
+ # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*63*.gz")
+ # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*85*.gz")
+ # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*81*.gz")
+ # add_jsonlines_dir(f"{data_dir}/c4_cleaned", "*96*.gz")
add_jsonlines_dir(f"{data_dir}/nrc_uniq_cleaned_20210223", "*.gz")
add_jsonlines_dir(f"{data_dir}/nu_uniq_cleaned_20210225", "*.gz")
random.Random(SEED).shuffle(data_files)
@@ -580,7 +580,10 @@ if __name__ == "__main__":

train, val = train_val_files()

- datasets = load_dataset('json', data_files={'train': train, 'validation': val})
+ load_grouped = False
+
+ if not load_grouped:
+     datasets = load_dataset('json', data_files={'train': train, 'validation': val})

# data_files = {}
# if data_args.train_file is not None:
@@ -623,31 +626,8 @@ if __name__ == "__main__":
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")

- # Preprocessing the datasets.
- # First we tokenize all the texts.
- if training_args.do_train:
-     column_names = datasets["train"].column_names
- else:
-     column_names = datasets["validation"].column_names
- text_column_name = "text" if "text" in column_names else column_names[0]
-
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

- # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
- # Since we make sure that all sequences are of the same length, no attention_mask is needed.
- def tokenize_function(examples):
-     return tokenizer(examples[text_column_name], return_attention_mask=False)
-
- logger.info(f"Start tokenization, remove_column_names = {column_names}")
-
- tokenized_datasets = datasets.map(
-     tokenize_function,
-     batched=True,
-     num_proc=data_args.preprocessing_num_workers,
-     remove_columns=column_names,
-     load_from_cache_file=not data_args.overwrite_cache,
- )
-
# T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
# To ensure that the input length is `max_seq_length`, we need to increase the maximum length
# according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
@@ -656,40 +636,64 @@ if __name__ == "__main__":
noise_density=data_args.mlm_probability,
mean_noise_span_length=data_args.mean_noise_span_length,
)
+ logger.info(f"Max seq length: {max_seq_length}, expanded_inputs_length: {expanded_inputs_length}, targets_length: {targets_length}")

- logger.info(f"Expanded_inputs_length: {expanded_inputs_length}, targets_length: {targets_length}")
-
- logger.info(f"Start group_texts")
-
- # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
- def group_texts(examples):
-     # Concatenate all texts.
-     concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
-     total_length = len(concatenated_examples[list(examples.keys())[0]])
-     # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
-     # customize this part to your needs.
-     if total_length >= expanded_inputs_length:
-         total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
-     # Split by chunks of max_len.
-     result = {
-         k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
-         for k, t in concatenated_examples.items()
-     }
-     return result
+ # Preprocessing the datasets.
+ # First we tokenize all the texts.
+ if not load_grouped:
+     if training_args.do_train:
+         column_names = datasets["train"].column_names
+     else:
+         column_names = datasets["validation"].column_names
+     text_column_name = "text" if "text" in column_names else column_names[0]
+
+     # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
+     # Since we make sure that all sequences are of the same length, no attention_mask is needed.
+     def tokenize_function(examples):
+         return tokenizer(examples[text_column_name], return_attention_mask=False)
+
+     logger.info(f"Start tokenization, remove_column_names = {column_names}")
+     tokenized_datasets = datasets.map(
+         tokenize_function,
+         batched=True,
+         num_proc=data_args.preprocessing_num_workers,
+         remove_columns=column_names,
+         load_from_cache_file=not data_args.overwrite_cache,
+     )

- # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
- # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
- # might be slower to preprocess.
- #
- # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
- # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
- tokenized_datasets = tokenized_datasets.map(
-     group_texts,
-     batched=True,
-     batch_size=200,
-     num_proc=data_args.preprocessing_num_workers,
-     load_from_cache_file=not data_args.overwrite_cache,
- )
+     # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
+     def group_texts(examples):
+         # Concatenate all texts.
+         concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
+         total_length = len(concatenated_examples[list(examples.keys())[0]])
+         # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
+         # customize this part to your needs.
+         if total_length >= expanded_inputs_length:
+             total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
+         # Split by chunks of max_len.
+         result = {
+             k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
+             for k, t in concatenated_examples.items()
+         }
+         return result
+
+     # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
+     # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
+     # might be slower to preprocess.
+     #
+     # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+     # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+     logger.info(f"Start group_texts")
+     tokenized_datasets = tokenized_datasets.map(
+         group_texts,
+         batched=True,
+         batch_size=200,
+         num_proc=data_args.preprocessing_num_workers,
+         load_from_cache_file=not data_args.overwrite_cache,
+     )
+ else:
+     logger.info("Loading tokenized and grouped dataset")
+     tokenized_datasets = DatasetDict.load_from_disk("/home/yeb/grouped_datasets")

# Enable tensorboard only on the master node
has_tensorboard = is_tensorboard_available()
@@ -751,9 +755,14 @@ if __name__ == "__main__":

# Create learning rate schedule

- # See https://arxiv.org/pdf/2104.07705.pdf for rationale of choosing the peak at 4% of training steps
- warmup_steps = int(0.04 * num_train_steps)
- logging.info(f"Warmup steps set to 4% = {warmup_steps} of total train steps {num_train_steps}")
+ if training_args.warmup_steps:
+     warmup_steps = training_args.warmup_steps
+ elif training_args.warmup_ratio:
+     # See https://arxiv.org/pdf/2104.07705.pdf for rationale of choosing the peak at % of training steps
+     warmup_steps = int(training_args.warmup_ratio * num_train_steps)
+     logging.info(f"Warmup steps set to {100*training_args.warmup_ratio}% = {warmup_steps} of total train steps {num_train_steps}")
+ else:
+     raise Exception("Need either --warmup_steps or --warmup_ratio")

warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=training_args.learning_rate, transition_steps=warmup_steps
@@ -863,7 +872,8 @@ if __name__ == "__main__":
state = jax_utils.replicate(state)

logger.info("***** Running training *****")
- logger.info(f" Num examples = {len(datasets['train'])}")
+ if not load_grouped:
+     logger.info(f" Num examples = {len(datasets['train'])}")
logger.info(f" Num tokenized group examples {len(tokenized_datasets['train'])}")
logger.info(f" Num Epochs = {num_epochs}")
logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
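
The new warmup handling above only selects the number of warmup steps; the schedule itself is built with optax.linear_schedule. A minimal sketch of how such a warmup-then-decay schedule is typically joined together, assuming the values from run_t5.sh and the step count worked out earlier (the linear decay phase is an assumption mirroring the upstream Flax T5 example, not code shown in this diff):

import optax

warmup_steps = 3413        # --warmup_steps from run_t5.sh
learning_rate = 0.005      # lr = 0.01 / 2 per the script comment
num_train_steps = 337_652  # approximate steps for one epoch, see the arithmetic above

warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps
)
# Assumed decay phase: linear ramp back to zero over the remaining steps.
decay_fn = optax.linear_schedule(
    init_value=learning_rate,
    end_value=0.0,
    transition_steps=num_train_steps - warmup_steps,
)
linear_decay_lr_schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps]
)
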
runs/Jul10_12-03-45_t1v-n-0e7426e8-w-0/events.out.tfevents.1625920526.t1v-n-0e7426e8-w-0.48005.3.v2 DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:ac6e43e7a7661df39374a0364897650b948e816446aa7cfb526ff2f0f51b9e1e
- size 40
runs/Jul10_12-39-58_t1v-n-0e7426e8-w-0/events.out.tfevents.1625922498.t1v-n-0e7426e8-w-0.52901.3.v2 CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
- oid sha256:d9f295ae46710e76c6932be7c90ac6db5f1f58dbce55b34870ab3d43248fdaee
- size 2077245
+ oid sha256:62ea16c934f55451bfb14bd666567f0c8837fead4ad2e1f6a8adbd8d11fd25a6
+ size 2359167
runs/{Jul10_07-37-20_t1v-n-0e7426e8-w-0/events.out.tfevents.1625902752.t1v-n-0e7426e8-w-0.18397.3.v2 → Jul11_09-15-07_t1v-n-0e7426e8-w-0/events.out.tfevents.1625995853.t1v-n-0e7426e8-w-0.145718.3.v2} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
- oid sha256:1aa4fd14ba6d0007ac2b4c7ad5f7b03ab486b3899ece3eba1fefe852923f2366
- size 40
+ oid sha256:ecdd317adb51d2b44773888aaa52793f97b5af475a8f35560774d02bd6ae20a2
+ size 300940