Abinaya Mahendiran commited on
Commit
3d74ff6
·
1 Parent(s): 2cc2a38

Updated baseline

Browse files
Files changed (42) hide show
  1. .gitattributes +3 -2
  2. gpt-2-tamil/config.json +36 -0
  3. gpt-2-tamil/events.out.tfevents.1626064970.t1v-n-ebe36c53-w-0.400773.3.v2 +3 -0
  4. gpt-2-tamil/events.out.tfevents.1626108088.t1v-n-ebe36c53-w-0.483452.3.v2 +3 -0
  5. gpt-2-tamil/events.out.tfevents.1626108395.t1v-n-ebe36c53-w-0.486342.3.v2 +3 -0
  6. gpt-2-tamil/flax_model.msgpack +3 -0
  7. gpt-2-tamil/tokenizer.json +0 -0
  8. scripts/run.log +0 -0
  9. scripts/train_gpt2-oscar-tamil.sh +11 -3
  10. scripts/wandb/debug-internal.log +1 -0
  11. scripts/wandb/debug.log +1 -0
  12. scripts/wandb/latest-run +1 -0
  13. scripts/wandb/run-20210712_044248-12kjsz9i/files/config.yaml +301 -0
  14. scripts/wandb/run-20210712_044248-12kjsz9i/files/events.out.tfevents.1626064970.t1v-n-ebe36c53-w-0.400773.3.v2 +1 -0
  15. scripts/wandb/run-20210712_044248-12kjsz9i/files/output.log +3 -0
  16. scripts/wandb/run-20210712_044248-12kjsz9i/files/requirements.txt +123 -0
  17. scripts/wandb/run-20210712_044248-12kjsz9i/files/wandb-metadata.json +45 -0
  18. scripts/wandb/run-20210712_044248-12kjsz9i/files/wandb-summary.json +1 -0
  19. scripts/wandb/run-20210712_044248-12kjsz9i/logs/debug-internal.log +3 -0
  20. scripts/wandb/run-20210712_044248-12kjsz9i/logs/debug.log +3 -0
  21. scripts/wandb/run-20210712_044248-12kjsz9i/run-12kjsz9i.wandb +3 -0
  22. scripts/wandb/run-20210712_164126-1cgtoi5r/files/config.yaml +305 -0
  23. scripts/wandb/run-20210712_164126-1cgtoi5r/files/events.out.tfevents.1626108088.t1v-n-ebe36c53-w-0.483452.3.v2 +1 -0
  24. scripts/wandb/run-20210712_164126-1cgtoi5r/files/output.log +3 -0
  25. scripts/wandb/run-20210712_164126-1cgtoi5r/files/requirements.txt +123 -0
  26. scripts/wandb/run-20210712_164126-1cgtoi5r/files/wandb-metadata.json +49 -0
  27. scripts/wandb/run-20210712_164126-1cgtoi5r/files/wandb-summary.json +1 -0
  28. scripts/wandb/run-20210712_164126-1cgtoi5r/logs/debug-internal.log +3 -0
  29. scripts/wandb/run-20210712_164126-1cgtoi5r/logs/debug.log +3 -0
  30. scripts/wandb/run-20210712_164126-1cgtoi5r/run-1cgtoi5r.wandb +3 -0
  31. scripts/wandb/run-20210712_164633-1ddv4131/files/config.yaml +305 -0
  32. scripts/wandb/run-20210712_164633-1ddv4131/files/events.out.tfevents.1626108395.t1v-n-ebe36c53-w-0.486342.3.v2 +1 -0
  33. scripts/wandb/run-20210712_164633-1ddv4131/files/output.log +3 -0
  34. scripts/wandb/run-20210712_164633-1ddv4131/files/requirements.txt +123 -0
  35. scripts/wandb/run-20210712_164633-1ddv4131/files/wandb-metadata.json +49 -0
  36. scripts/wandb/run-20210712_164633-1ddv4131/files/wandb-summary.json +1 -0
  37. scripts/wandb/run-20210712_164633-1ddv4131/logs/debug-internal.log +3 -0
  38. scripts/wandb/run-20210712_164633-1ddv4131/logs/debug.log +3 -0
  39. scripts/wandb/run-20210712_164633-1ddv4131/run-1ddv4131.wandb +3 -0
  40. src/create_config.py +1 -1
  41. src/run_clm_flax.py +147 -232
  42. src/train_tokenizer.py +1 -1
.gitattributes CHANGED
@@ -12,6 +12,7 @@
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
  *.pb filter=lfs diff=lfs merge=lfs -text
15
- *.pt filter=lfs diff=lfs merge=lfs -text
 
16
  *.pth filter=lfs diff=lfs merge=lfs -text
17
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
  *.pb filter=lfs diff=lfs merge=lfs -text
15
+ *.log filter=lfs diff=lfs merge=lfs -text
16
+ *.wandb filter=lfs diff=lfs merge=lfs -text
17
  *.pth filter=lfs diff=lfs merge=lfs -text
18
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
gpt-2-tamil/config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_function": "gelu_new",
3
+ "architectures": [
4
+ "GPT2LMHeadModel"
5
+ ],
6
+ "attn_pdrop": 0.0,
7
+ "bos_token_id": 50256,
8
+ "embd_pdrop": 0.0,
9
+ "eos_token_id": 50256,
10
+ "gradient_checkpointing": false,
11
+ "initializer_range": 0.02,
12
+ "layer_norm_epsilon": 1e-05,
13
+ "model_type": "gpt2",
14
+ "n_ctx": 1024,
15
+ "n_embd": 768,
16
+ "n_head": 12,
17
+ "n_inner": null,
18
+ "n_layer": 12,
19
+ "n_positions": 1024,
20
+ "resid_pdrop": 0.0,
21
+ "scale_attn_weights": true,
22
+ "summary_activation": null,
23
+ "summary_first_dropout": 0.1,
24
+ "summary_proj_to_labels": true,
25
+ "summary_type": "cls_index",
26
+ "summary_use_proj": true,
27
+ "task_specific_params": {
28
+ "text-generation": {
29
+ "do_sample": true,
30
+ "max_length": 50
31
+ }
32
+ },
33
+ "transformers_version": "4.9.0.dev0",
34
+ "use_cache": true,
35
+ "vocab_size": 50257
36
+ }
gpt-2-tamil/events.out.tfevents.1626064970.t1v-n-ebe36c53-w-0.400773.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4cc79262fd103f58e2c2bb461dc3db699613de0d444116b20f5644759ebfbe6e
3
+ size 40
gpt-2-tamil/events.out.tfevents.1626108088.t1v-n-ebe36c53-w-0.483452.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d27b640fe5e66ecd1bf4a3667a35e4243bc4afc19dd7a2247a6e3d0a56211f6
3
+ size 40
gpt-2-tamil/events.out.tfevents.1626108395.t1v-n-ebe36c53-w-0.486342.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f98c1e1d0d88519bc875d97549b8ceb6d03f7c5d0aca79c15a10749f91c28362
3
+ size 19735799
gpt-2-tamil/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f15aa88a1b0381444c39e9e70f17a82751f7c317d7be7e22cc9707527f9a8c27
3
+ size 497764120
gpt-2-tamil/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
scripts/run.log ADDED
File without changes
scripts/train_gpt2-oscar-tamil.sh CHANGED
@@ -1,5 +1,5 @@
1
  #!/usr/bin/env bash
2
- ./run_clm_flax.py \
3
  --output_dir="${MODEL_DIR}" \
4
  --model_type="gpt2" \
5
  --config_name="${MODEL_DIR}" \
@@ -10,8 +10,16 @@
10
  --block_size="512" \
11
  --per_device_train_batch_size="64" \
12
  --per_device_eval_batch_size="64" \
13
- --learning_rate="5e-3" --warmup_steps="1000" \
 
14
  --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
15
  --overwrite_output_dir \
16
- --num_train_epochs="20" \
 
 
 
 
 
 
17
  #--push_to_hub
 
 
1
  #!/usr/bin/env bash
2
+ python ../src/run_clm_flax.py \
3
  --output_dir="${MODEL_DIR}" \
4
  --model_type="gpt2" \
5
  --config_name="${MODEL_DIR}" \
 
10
  --block_size="512" \
11
  --per_device_train_batch_size="64" \
12
  --per_device_eval_batch_size="64" \
13
+ --learning_rate="3e-5" \
14
+ --warmup_steps="1000" \
15
  --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
16
  --overwrite_output_dir \
17
+ --num_train_epochs="25" \
18
+ --report_to wandb \
19
+ --run_name trial \
20
+ --logging_steps="500" \
21
+ --save_steps="2500" \
22
+ --eval_steps="2500" \
23
+ --preprocessing_num_workers="90" \
24
  #--push_to_hub
25
+ 2>&1 | tee run.log
scripts/wandb/debug-internal.log ADDED
@@ -0,0 +1 @@
 
 
1
+ run-20210712_164633-1ddv4131/logs/debug-internal.log
scripts/wandb/debug.log ADDED
@@ -0,0 +1 @@
 
 
1
+ run-20210712_164633-1ddv4131/logs/debug.log
scripts/wandb/latest-run ADDED
@@ -0,0 +1 @@
 
 
1
+ run-20210712_164633-1ddv4131
scripts/wandb/run-20210712_044248-12kjsz9i/files/config.yaml ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ __cached__setup_devices:
4
+ desc: null
5
+ value: cpu
6
+ _n_gpu:
7
+ desc: null
8
+ value: 0
9
+ _wandb:
10
+ desc: null
11
+ value:
12
+ cli_version: 0.10.33
13
+ framework: huggingface
14
+ huggingface_version: 4.9.0.dev0
15
+ is_jupyter_run: false
16
+ is_kaggle_kernel: false
17
+ python_version: 3.8.10
18
+ t:
19
+ 1:
20
+ - 1
21
+ - 3
22
+ - 11
23
+ 4: 3.8.10
24
+ 5: 0.10.33
25
+ 6: 4.9.0.dev0
26
+ 8:
27
+ - 5
28
+ adafactor:
29
+ desc: null
30
+ value: false
31
+ adam_beta1:
32
+ desc: null
33
+ value: 0.9
34
+ adam_beta2:
35
+ desc: null
36
+ value: 0.98
37
+ adam_epsilon:
38
+ desc: null
39
+ value: 1.0e-08
40
+ block_size:
41
+ desc: null
42
+ value: 512
43
+ cache_dir:
44
+ desc: null
45
+ value: null
46
+ config_name:
47
+ desc: null
48
+ value: ../gpt-2-tamil/
49
+ dataloader_drop_last:
50
+ desc: null
51
+ value: false
52
+ dataloader_num_workers:
53
+ desc: null
54
+ value: 0
55
+ dataloader_pin_memory:
56
+ desc: null
57
+ value: true
58
+ dataset_config_name:
59
+ desc: null
60
+ value: unshuffled_deduplicated_ta
61
+ dataset_name:
62
+ desc: null
63
+ value: oscar
64
+ ddp_find_unused_parameters:
65
+ desc: null
66
+ value: null
67
+ debug:
68
+ desc: null
69
+ value: []
70
+ deepspeed:
71
+ desc: null
72
+ value: null
73
+ disable_tqdm:
74
+ desc: null
75
+ value: false
76
+ do_eval:
77
+ desc: null
78
+ value: true
79
+ do_predict:
80
+ desc: null
81
+ value: false
82
+ do_train:
83
+ desc: null
84
+ value: true
85
+ dtype:
86
+ desc: null
87
+ value: float32
88
+ eval_accumulation_steps:
89
+ desc: null
90
+ value: null
91
+ eval_steps:
92
+ desc: null
93
+ value: 500
94
+ evaluation_strategy:
95
+ desc: null
96
+ value: IntervalStrategy.NO
97
+ fp16:
98
+ desc: null
99
+ value: false
100
+ fp16_backend:
101
+ desc: null
102
+ value: auto
103
+ fp16_full_eval:
104
+ desc: null
105
+ value: false
106
+ fp16_opt_level:
107
+ desc: null
108
+ value: O1
109
+ gradient_accumulation_steps:
110
+ desc: null
111
+ value: 1
112
+ greater_is_better:
113
+ desc: null
114
+ value: null
115
+ group_by_length:
116
+ desc: null
117
+ value: false
118
+ ignore_data_skip:
119
+ desc: null
120
+ value: false
121
+ label_names:
122
+ desc: null
123
+ value: null
124
+ label_smoothing_factor:
125
+ desc: null
126
+ value: 0.0
127
+ learning_rate:
128
+ desc: null
129
+ value: 3.0e-05
130
+ length_column_name:
131
+ desc: null
132
+ value: length
133
+ load_best_model_at_end:
134
+ desc: null
135
+ value: false
136
+ local_rank:
137
+ desc: null
138
+ value: -1
139
+ log_level:
140
+ desc: null
141
+ value: -1
142
+ log_level_replica:
143
+ desc: null
144
+ value: -1
145
+ log_on_each_node:
146
+ desc: null
147
+ value: true
148
+ logging_dir:
149
+ desc: null
150
+ value: ../tmp/../gpt-2-tamil/runs/Jul11_17-18-14_t1v-n-ebe36c53-w-0
151
+ logging_first_step:
152
+ desc: null
153
+ value: false
154
+ logging_steps:
155
+ desc: null
156
+ value: 500
157
+ logging_strategy:
158
+ desc: null
159
+ value: IntervalStrategy.STEPS
160
+ lr_scheduler_type:
161
+ desc: null
162
+ value: SchedulerType.LINEAR
163
+ max_eval_samples:
164
+ desc: null
165
+ value: null
166
+ max_grad_norm:
167
+ desc: null
168
+ value: 1.0
169
+ max_steps:
170
+ desc: null
171
+ value: -1
172
+ max_train_samples:
173
+ desc: null
174
+ value: null
175
+ metric_for_best_model:
176
+ desc: null
177
+ value: null
178
+ model_name_or_path:
179
+ desc: null
180
+ value: null
181
+ model_type:
182
+ desc: null
183
+ value: gpt2
184
+ mp_parameters:
185
+ desc: null
186
+ value: ''
187
+ no_cuda:
188
+ desc: null
189
+ value: false
190
+ num_train_epochs:
191
+ desc: null
192
+ value: 1.0
193
+ output_dir:
194
+ desc: null
195
+ value: ../tmp/../gpt-2-tamil/
196
+ overwrite_cache:
197
+ desc: null
198
+ value: false
199
+ overwrite_output_dir:
200
+ desc: null
201
+ value: true
202
+ past_index:
203
+ desc: null
204
+ value: -1
205
+ per_device_eval_batch_size:
206
+ desc: null
207
+ value: 64
208
+ per_device_train_batch_size:
209
+ desc: null
210
+ value: 64
211
+ per_gpu_eval_batch_size:
212
+ desc: null
213
+ value: null
214
+ per_gpu_train_batch_size:
215
+ desc: null
216
+ value: null
217
+ prediction_loss_only:
218
+ desc: null
219
+ value: false
220
+ preprocessing_num_workers:
221
+ desc: null
222
+ value: null
223
+ push_to_hub:
224
+ desc: null
225
+ value: false
226
+ push_to_hub_model_id:
227
+ desc: null
228
+ value: gpt-2-tamil
229
+ push_to_hub_organization:
230
+ desc: null
231
+ value: null
232
+ push_to_hub_token:
233
+ desc: null
234
+ value: null
235
+ remove_unused_columns:
236
+ desc: null
237
+ value: true
238
+ report_to:
239
+ desc: null
240
+ value:
241
+ - wandb
242
+ resume_from_checkpoint:
243
+ desc: null
244
+ value: null
245
+ run_name:
246
+ desc: null
247
+ value: trial
248
+ save_on_each_node:
249
+ desc: null
250
+ value: false
251
+ save_steps:
252
+ desc: null
253
+ value: 500
254
+ save_strategy:
255
+ desc: null
256
+ value: IntervalStrategy.STEPS
257
+ save_total_limit:
258
+ desc: null
259
+ value: null
260
+ seed:
261
+ desc: null
262
+ value: 42
263
+ sharded_ddp:
264
+ desc: null
265
+ value: []
266
+ skip_memory_metrics:
267
+ desc: null
268
+ value: true
269
+ tokenizer_name:
270
+ desc: null
271
+ value: ../gpt-2-tamil/
272
+ tpu_metrics_debug:
273
+ desc: null
274
+ value: false
275
+ tpu_num_cores:
276
+ desc: null
277
+ value: null
278
+ train_file:
279
+ desc: null
280
+ value: null
281
+ use_fast_tokenizer:
282
+ desc: null
283
+ value: true
284
+ use_legacy_prediction_loop:
285
+ desc: null
286
+ value: false
287
+ validation_file:
288
+ desc: null
289
+ value: null
290
+ validation_split_percentage:
291
+ desc: null
292
+ value: 5
293
+ warmup_ratio:
294
+ desc: null
295
+ value: 0.0
296
+ warmup_steps:
297
+ desc: null
298
+ value: 1000
299
+ weight_decay:
300
+ desc: null
301
+ value: 0.01
scripts/wandb/run-20210712_044248-12kjsz9i/files/events.out.tfevents.1626064970.t1v-n-ebe36c53-w-0.400773.3.v2 ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/tweety_abi/GPT2-Tamil/gpt-2-tamil/events.out.tfevents.1626064970.t1v-n-ebe36c53-w-0.400773.3.v2
scripts/wandb/run-20210712_044248-12kjsz9i/files/output.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83c77cb8fdf96d6479ff9c389029839beb48a924dec227d80f708b2d1f1dd66f
3
+ size 107953
scripts/wandb/run-20210712_044248-12kjsz9i/files/requirements.txt ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==0.13.0
2
+ aiohttp==3.7.4.post0
3
+ appdirs==1.4.4
4
+ astunparse==1.6.3
5
+ async-timeout==3.0.1
6
+ attrs==21.2.0
7
+ backcall==0.2.0
8
+ black==21.6b0
9
+ cachetools==4.2.2
10
+ certifi==2021.5.30
11
+ cfgv==3.3.0
12
+ chardet==4.0.0
13
+ chex==0.0.7
14
+ click==8.0.1
15
+ configparser==5.0.2
16
+ cycler==0.10.0
17
+ datasets==1.8.1.dev0
18
+ decorator==5.0.9
19
+ dill==0.3.4
20
+ distlib==0.3.2
21
+ dm-tree==0.1.6
22
+ docker-pycreds==0.4.0
23
+ filelock==3.0.12
24
+ flake8==3.9.2
25
+ flatbuffers==1.12
26
+ flax==0.3.4
27
+ fsspec==2021.6.1
28
+ gast==0.4.0
29
+ gitdb==4.0.7
30
+ gitpython==3.1.18
31
+ google-auth-oauthlib==0.4.4
32
+ google-auth==1.32.1
33
+ google-pasta==0.2.0
34
+ grpcio==1.34.1
35
+ h5py==3.1.0
36
+ huggingface-hub==0.0.12
37
+ identify==2.2.10
38
+ idna==2.10
39
+ ipython-genutils==0.2.0
40
+ ipython==7.25.0
41
+ isort==5.9.1
42
+ jax==0.2.16
43
+ jaxlib==0.1.68
44
+ jedi==0.18.0
45
+ joblib==1.0.1
46
+ keras-nightly==2.5.0.dev2021032900
47
+ keras-preprocessing==1.1.2
48
+ kiwisolver==1.3.1
49
+ libtpu-nightly==0.1.dev20210615
50
+ markdown==3.3.4
51
+ matplotlib-inline==0.1.2
52
+ matplotlib==3.4.2
53
+ mccabe==0.6.1
54
+ msgpack==1.0.2
55
+ multidict==5.1.0
56
+ multiprocess==0.70.12.2
57
+ mypy-extensions==0.4.3
58
+ nodeenv==1.6.0
59
+ numpy==1.19.5
60
+ oauthlib==3.1.1
61
+ opt-einsum==3.3.0
62
+ optax==0.0.8
63
+ packaging==20.9
64
+ pandas==1.2.5
65
+ parso==0.8.2
66
+ pathspec==0.8.1
67
+ pathtools==0.1.2
68
+ pexpect==4.8.0
69
+ pickleshare==0.7.5
70
+ pillow==8.3.0
71
+ pip==20.0.2
72
+ pkg-resources==0.0.0
73
+ pre-commit==2.13.0
74
+ promise==2.3
75
+ prompt-toolkit==3.0.19
76
+ protobuf==3.17.3
77
+ psutil==5.8.0
78
+ ptyprocess==0.7.0
79
+ pyarrow==4.0.1
80
+ pyasn1-modules==0.2.8
81
+ pyasn1==0.4.8
82
+ pycodestyle==2.7.0
83
+ pyflakes==2.3.1
84
+ pygments==2.9.0
85
+ pyparsing==2.4.7
86
+ python-dateutil==2.8.1
87
+ pytz==2021.1
88
+ pyyaml==5.4.1
89
+ regex==2021.7.1
90
+ requests-oauthlib==1.3.0
91
+ requests==2.25.1
92
+ rsa==4.7.2
93
+ sacremoses==0.0.45
94
+ scipy==1.7.0
95
+ sentry-sdk==1.3.0
96
+ setuptools==44.0.0
97
+ shortuuid==1.0.1
98
+ six==1.15.0
99
+ smmap==4.0.0
100
+ subprocess32==3.5.4
101
+ tensorboard-data-server==0.6.1
102
+ tensorboard-plugin-wit==1.8.0
103
+ tensorboard==2.5.0
104
+ tensorflow-estimator==2.5.0
105
+ tensorflow==2.5.0
106
+ termcolor==1.1.0
107
+ tokenizers==0.10.3
108
+ toml==0.10.2
109
+ toolz==0.11.1
110
+ torch==1.9.0
111
+ tqdm==4.61.1
112
+ traitlets==5.0.5
113
+ transformers==4.9.0.dev0
114
+ typing-extensions==3.7.4.3
115
+ urllib3==1.26.6
116
+ virtualenv==20.4.7
117
+ wandb==0.10.33
118
+ wcwidth==0.2.5
119
+ werkzeug==2.0.1
120
+ wheel==0.36.2
121
+ wrapt==1.12.1
122
+ xxhash==2.0.2
123
+ yarl==1.6.3
scripts/wandb/run-20210712_044248-12kjsz9i/files/wandb-metadata.json ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2021-07-12T04:42:50.208592",
5
+ "startedAt": "2021-07-12T04:42:48.264668",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--output_dir=../tmp/../gpt-2-tamil/",
11
+ "--model_type=gpt2",
12
+ "--config_name=../gpt-2-tamil/",
13
+ "--tokenizer_name=../gpt-2-tamil/",
14
+ "--dataset_name=oscar",
15
+ "--dataset_config_name=unshuffled_deduplicated_ta",
16
+ "--do_train",
17
+ "--do_eval",
18
+ "--block_size=512",
19
+ "--per_device_train_batch_size=64",
20
+ "--per_device_eval_batch_size=64",
21
+ "--learning_rate=3e-5",
22
+ "--warmup_steps=1000",
23
+ "--adam_beta1=0.9",
24
+ "--adam_beta2=0.98",
25
+ "--weight_decay=0.01",
26
+ "--overwrite_output_dir",
27
+ "--num_train_epochs=1",
28
+ "--report_to",
29
+ "wandb",
30
+ "--run_name",
31
+ "trial"
32
+ ],
33
+ "state": "running",
34
+ "program": "../src/run_clm_flax.py",
35
+ "codePath": "src/run_clm_flax.py",
36
+ "git": {
37
+ "remote": "https://github.com/AbinayaM02/GPT2-Tamil.git",
38
+ "commit": "a828229d00c071e9ced919095290b80e4781210e"
39
+ },
40
+ "email": "abinaya.m02@mphasis.com",
41
+ "root": "/home/tweety_abi/GPT2-Tamil",
42
+ "host": "t1v-n-ebe36c53-w-0",
43
+ "username": "tweety_abi",
44
+ "executable": "/home/tweety_abi/gpt2_env/bin/python"
45
+ }
scripts/wandb/run-20210712_044248-12kjsz9i/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
scripts/wandb/run-20210712_044248-12kjsz9i/logs/debug-internal.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5fcfc08cf0afaa70518eae77306ccb85256b1c75804bb875c43e69ead82eecee
3
+ size 351322
scripts/wandb/run-20210712_044248-12kjsz9i/logs/debug.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:629872cc9a122d0b701da9fb5b2d152a25a06d9c316ff22a307c2d5297a5f684
3
+ size 5672
scripts/wandb/run-20210712_044248-12kjsz9i/run-12kjsz9i.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77335af984a7fc18e2e903ce57acfa4f5ee5a60613ffb55a60c3e1996007f5e9
3
+ size 327917
scripts/wandb/run-20210712_164126-1cgtoi5r/files/config.yaml ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ __cached__setup_devices:
4
+ desc: null
5
+ value: cpu
6
+ _n_gpu:
7
+ desc: null
8
+ value: 0
9
+ _wandb:
10
+ desc: null
11
+ value:
12
+ cli_version: 0.10.33
13
+ framework: huggingface
14
+ huggingface_version: 4.9.0.dev0
15
+ is_jupyter_run: false
16
+ is_kaggle_kernel: false
17
+ python_version: 3.8.10
18
+ t:
19
+ 1:
20
+ - 1
21
+ - 3
22
+ - 11
23
+ 2:
24
+ - 1
25
+ - 3
26
+ - 11
27
+ 4: 3.8.10
28
+ 5: 0.10.33
29
+ 6: 4.9.0.dev0
30
+ 8:
31
+ - 5
32
+ adafactor:
33
+ desc: null
34
+ value: false
35
+ adam_beta1:
36
+ desc: null
37
+ value: 0.9
38
+ adam_beta2:
39
+ desc: null
40
+ value: 0.98
41
+ adam_epsilon:
42
+ desc: null
43
+ value: 1.0e-08
44
+ block_size:
45
+ desc: null
46
+ value: 512
47
+ cache_dir:
48
+ desc: null
49
+ value: null
50
+ config_name:
51
+ desc: null
52
+ value: ../gpt-2-tamil/
53
+ dataloader_drop_last:
54
+ desc: null
55
+ value: false
56
+ dataloader_num_workers:
57
+ desc: null
58
+ value: 0
59
+ dataloader_pin_memory:
60
+ desc: null
61
+ value: true
62
+ dataset_config_name:
63
+ desc: null
64
+ value: unshuffled_deduplicated_ta
65
+ dataset_name:
66
+ desc: null
67
+ value: oscar
68
+ ddp_find_unused_parameters:
69
+ desc: null
70
+ value: null
71
+ debug:
72
+ desc: null
73
+ value: []
74
+ deepspeed:
75
+ desc: null
76
+ value: null
77
+ disable_tqdm:
78
+ desc: null
79
+ value: false
80
+ do_eval:
81
+ desc: null
82
+ value: true
83
+ do_predict:
84
+ desc: null
85
+ value: false
86
+ do_train:
87
+ desc: null
88
+ value: true
89
+ dtype:
90
+ desc: null
91
+ value: float32
92
+ eval_accumulation_steps:
93
+ desc: null
94
+ value: null
95
+ eval_steps:
96
+ desc: null
97
+ value: 2500
98
+ evaluation_strategy:
99
+ desc: null
100
+ value: IntervalStrategy.NO
101
+ fp16:
102
+ desc: null
103
+ value: false
104
+ fp16_backend:
105
+ desc: null
106
+ value: auto
107
+ fp16_full_eval:
108
+ desc: null
109
+ value: false
110
+ fp16_opt_level:
111
+ desc: null
112
+ value: O1
113
+ gradient_accumulation_steps:
114
+ desc: null
115
+ value: 1
116
+ greater_is_better:
117
+ desc: null
118
+ value: null
119
+ group_by_length:
120
+ desc: null
121
+ value: false
122
+ ignore_data_skip:
123
+ desc: null
124
+ value: false
125
+ label_names:
126
+ desc: null
127
+ value: null
128
+ label_smoothing_factor:
129
+ desc: null
130
+ value: 0.0
131
+ learning_rate:
132
+ desc: null
133
+ value: 3.0e-05
134
+ length_column_name:
135
+ desc: null
136
+ value: length
137
+ load_best_model_at_end:
138
+ desc: null
139
+ value: false
140
+ local_rank:
141
+ desc: null
142
+ value: -1
143
+ log_level:
144
+ desc: null
145
+ value: -1
146
+ log_level_replica:
147
+ desc: null
148
+ value: -1
149
+ log_on_each_node:
150
+ desc: null
151
+ value: true
152
+ logging_dir:
153
+ desc: null
154
+ value: ../gpt-2-tamil/runs/Jul12_16-26-59_t1v-n-ebe36c53-w-0
155
+ logging_first_step:
156
+ desc: null
157
+ value: false
158
+ logging_steps:
159
+ desc: null
160
+ value: 500
161
+ logging_strategy:
162
+ desc: null
163
+ value: IntervalStrategy.STEPS
164
+ lr_scheduler_type:
165
+ desc: null
166
+ value: SchedulerType.LINEAR
167
+ max_eval_samples:
168
+ desc: null
169
+ value: null
170
+ max_grad_norm:
171
+ desc: null
172
+ value: 1.0
173
+ max_steps:
174
+ desc: null
175
+ value: -1
176
+ max_train_samples:
177
+ desc: null
178
+ value: null
179
+ metric_for_best_model:
180
+ desc: null
181
+ value: null
182
+ model_name_or_path:
183
+ desc: null
184
+ value: null
185
+ model_type:
186
+ desc: null
187
+ value: gpt2
188
+ mp_parameters:
189
+ desc: null
190
+ value: ''
191
+ no_cuda:
192
+ desc: null
193
+ value: false
194
+ num_train_epochs:
195
+ desc: null
196
+ value: 1.0
197
+ output_dir:
198
+ desc: null
199
+ value: ../gpt-2-tamil/
200
+ overwrite_cache:
201
+ desc: null
202
+ value: false
203
+ overwrite_output_dir:
204
+ desc: null
205
+ value: true
206
+ past_index:
207
+ desc: null
208
+ value: -1
209
+ per_device_eval_batch_size:
210
+ desc: null
211
+ value: 64
212
+ per_device_train_batch_size:
213
+ desc: null
214
+ value: 64
215
+ per_gpu_eval_batch_size:
216
+ desc: null
217
+ value: null
218
+ per_gpu_train_batch_size:
219
+ desc: null
220
+ value: null
221
+ prediction_loss_only:
222
+ desc: null
223
+ value: false
224
+ preprocessing_num_workers:
225
+ desc: null
226
+ value: 90
227
+ push_to_hub:
228
+ desc: null
229
+ value: false
230
+ push_to_hub_model_id:
231
+ desc: null
232
+ value: gpt-2-tamil
233
+ push_to_hub_organization:
234
+ desc: null
235
+ value: null
236
+ push_to_hub_token:
237
+ desc: null
238
+ value: null
239
+ remove_unused_columns:
240
+ desc: null
241
+ value: true
242
+ report_to:
243
+ desc: null
244
+ value:
245
+ - wandb
246
+ resume_from_checkpoint:
247
+ desc: null
248
+ value: null
249
+ run_name:
250
+ desc: null
251
+ value: trial
252
+ save_on_each_node:
253
+ desc: null
254
+ value: false
255
+ save_steps:
256
+ desc: null
257
+ value: 2500
258
+ save_strategy:
259
+ desc: null
260
+ value: IntervalStrategy.STEPS
261
+ save_total_limit:
262
+ desc: null
263
+ value: null
264
+ seed:
265
+ desc: null
266
+ value: 42
267
+ sharded_ddp:
268
+ desc: null
269
+ value: []
270
+ skip_memory_metrics:
271
+ desc: null
272
+ value: true
273
+ tokenizer_name:
274
+ desc: null
275
+ value: ../gpt-2-tamil/
276
+ tpu_metrics_debug:
277
+ desc: null
278
+ value: false
279
+ tpu_num_cores:
280
+ desc: null
281
+ value: null
282
+ train_file:
283
+ desc: null
284
+ value: null
285
+ use_fast_tokenizer:
286
+ desc: null
287
+ value: true
288
+ use_legacy_prediction_loop:
289
+ desc: null
290
+ value: false
291
+ validation_file:
292
+ desc: null
293
+ value: null
294
+ validation_split_percentage:
295
+ desc: null
296
+ value: 5
297
+ warmup_ratio:
298
+ desc: null
299
+ value: 0.0
300
+ warmup_steps:
301
+ desc: null
302
+ value: 1000
303
+ weight_decay:
304
+ desc: null
305
+ value: 0.01
scripts/wandb/run-20210712_164126-1cgtoi5r/files/events.out.tfevents.1626108088.t1v-n-ebe36c53-w-0.483452.3.v2 ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/tweety_abi/GPT2-Tamil/gpt-2-tamil/events.out.tfevents.1626108088.t1v-n-ebe36c53-w-0.483452.3.v2
scripts/wandb/run-20210712_164126-1cgtoi5r/files/output.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0919562e4d7c8bbdebb2074976c553f66727900bc770c5dd9d8041e3a008931
3
+ size 2408
scripts/wandb/run-20210712_164126-1cgtoi5r/files/requirements.txt ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==0.13.0
2
+ aiohttp==3.7.4.post0
3
+ appdirs==1.4.4
4
+ astunparse==1.6.3
5
+ async-timeout==3.0.1
6
+ attrs==21.2.0
7
+ backcall==0.2.0
8
+ black==21.6b0
9
+ cachetools==4.2.2
10
+ certifi==2021.5.30
11
+ cfgv==3.3.0
12
+ chardet==4.0.0
13
+ chex==0.0.7
14
+ click==8.0.1
15
+ configparser==5.0.2
16
+ cycler==0.10.0
17
+ datasets==1.8.1.dev0
18
+ decorator==5.0.9
19
+ dill==0.3.4
20
+ distlib==0.3.2
21
+ dm-tree==0.1.6
22
+ docker-pycreds==0.4.0
23
+ filelock==3.0.12
24
+ flake8==3.9.2
25
+ flatbuffers==1.12
26
+ flax==0.3.4
27
+ fsspec==2021.6.1
28
+ gast==0.4.0
29
+ gitdb==4.0.7
30
+ gitpython==3.1.18
31
+ google-auth-oauthlib==0.4.4
32
+ google-auth==1.32.1
33
+ google-pasta==0.2.0
34
+ grpcio==1.34.1
35
+ h5py==3.1.0
36
+ huggingface-hub==0.0.12
37
+ identify==2.2.10
38
+ idna==2.10
39
+ ipython-genutils==0.2.0
40
+ ipython==7.25.0
41
+ isort==5.9.1
42
+ jax==0.2.16
43
+ jaxlib==0.1.68
44
+ jedi==0.18.0
45
+ joblib==1.0.1
46
+ keras-nightly==2.5.0.dev2021032900
47
+ keras-preprocessing==1.1.2
48
+ kiwisolver==1.3.1
49
+ libtpu-nightly==0.1.dev20210615
50
+ markdown==3.3.4
51
+ matplotlib-inline==0.1.2
52
+ matplotlib==3.4.2
53
+ mccabe==0.6.1
54
+ msgpack==1.0.2
55
+ multidict==5.1.0
56
+ multiprocess==0.70.12.2
57
+ mypy-extensions==0.4.3
58
+ nodeenv==1.6.0
59
+ numpy==1.19.5
60
+ oauthlib==3.1.1
61
+ opt-einsum==3.3.0
62
+ optax==0.0.8
63
+ packaging==20.9
64
+ pandas==1.2.5
65
+ parso==0.8.2
66
+ pathspec==0.8.1
67
+ pathtools==0.1.2
68
+ pexpect==4.8.0
69
+ pickleshare==0.7.5
70
+ pillow==8.3.0
71
+ pip==20.0.2
72
+ pkg-resources==0.0.0
73
+ pre-commit==2.13.0
74
+ promise==2.3
75
+ prompt-toolkit==3.0.19
76
+ protobuf==3.17.3
77
+ psutil==5.8.0
78
+ ptyprocess==0.7.0
79
+ pyarrow==4.0.1
80
+ pyasn1-modules==0.2.8
81
+ pyasn1==0.4.8
82
+ pycodestyle==2.7.0
83
+ pyflakes==2.3.1
84
+ pygments==2.9.0
85
+ pyparsing==2.4.7
86
+ python-dateutil==2.8.1
87
+ pytz==2021.1
88
+ pyyaml==5.4.1
89
+ regex==2021.7.1
90
+ requests-oauthlib==1.3.0
91
+ requests==2.25.1
92
+ rsa==4.7.2
93
+ sacremoses==0.0.45
94
+ scipy==1.7.0
95
+ sentry-sdk==1.3.0
96
+ setuptools==44.0.0
97
+ shortuuid==1.0.1
98
+ six==1.15.0
99
+ smmap==4.0.0
100
+ subprocess32==3.5.4
101
+ tensorboard-data-server==0.6.1
102
+ tensorboard-plugin-wit==1.8.0
103
+ tensorboard==2.5.0
104
+ tensorflow-estimator==2.5.0
105
+ tensorflow==2.5.0
106
+ termcolor==1.1.0
107
+ tokenizers==0.10.3
108
+ toml==0.10.2
109
+ toolz==0.11.1
110
+ torch==1.9.0
111
+ tqdm==4.61.1
112
+ traitlets==5.0.5
113
+ transformers==4.9.0.dev0
114
+ typing-extensions==3.7.4.3
115
+ urllib3==1.26.6
116
+ virtualenv==20.4.7
117
+ wandb==0.10.33
118
+ wcwidth==0.2.5
119
+ werkzeug==2.0.1
120
+ wheel==0.36.2
121
+ wrapt==1.12.1
122
+ xxhash==2.0.2
123
+ yarl==1.6.3
scripts/wandb/run-20210712_164126-1cgtoi5r/files/wandb-metadata.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2021-07-12T16:41:28.249908",
5
+ "startedAt": "2021-07-12T16:41:26.246514",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--output_dir=../gpt-2-tamil/",
11
+ "--model_type=gpt2",
12
+ "--config_name=../gpt-2-tamil/",
13
+ "--tokenizer_name=../gpt-2-tamil/",
14
+ "--dataset_name=oscar",
15
+ "--dataset_config_name=unshuffled_deduplicated_ta",
16
+ "--do_train",
17
+ "--do_eval",
18
+ "--block_size=512",
19
+ "--per_device_train_batch_size=64",
20
+ "--per_device_eval_batch_size=64",
21
+ "--learning_rate=3e-5",
22
+ "--warmup_steps=1000",
23
+ "--adam_beta1=0.9",
24
+ "--adam_beta2=0.98",
25
+ "--weight_decay=0.01",
26
+ "--overwrite_output_dir",
27
+ "--num_train_epochs=1",
28
+ "--report_to",
29
+ "wandb",
30
+ "--run_name",
31
+ "trial",
32
+ "--logging_steps=500",
33
+ "--save_steps=2500",
34
+ "--eval_steps=2500",
35
+ "--preprocessing_num_workers=90"
36
+ ],
37
+ "state": "running",
38
+ "program": "../src/run_clm_flax.py",
39
+ "codePath": "src/run_clm_flax.py",
40
+ "git": {
41
+ "remote": "https://github.com/AbinayaM02/GPT2-Tamil.git",
42
+ "commit": "a828229d00c071e9ced919095290b80e4781210e"
43
+ },
44
+ "email": "abinaya.m02@mphasis.com",
45
+ "root": "/home/tweety_abi/GPT2-Tamil",
46
+ "host": "t1v-n-ebe36c53-w-0",
47
+ "username": "tweety_abi",
48
+ "executable": "/home/tweety_abi/gpt2_env/bin/python"
49
+ }
scripts/wandb/run-20210712_164126-1cgtoi5r/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
scripts/wandb/run-20210712_164126-1cgtoi5r/logs/debug-internal.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5042f20446f607f9de1da5328a848218a444a160203dbbe620844d138a5f041
3
+ size 28874
scripts/wandb/run-20210712_164126-1cgtoi5r/logs/debug.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b642e86c06ccde2fecec5b035f7154289468f3126e152b9170e2e8d35b655328
3
+ size 7649
scripts/wandb/run-20210712_164126-1cgtoi5r/run-1cgtoi5r.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1302a66f7718be000865df304f0b21d1471bf7bb1cc55e4aba7e50d6f051725d
3
+ size 15809
scripts/wandb/run-20210712_164633-1ddv4131/files/config.yaml ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ __cached__setup_devices:
4
+ desc: null
5
+ value: cpu
6
+ _n_gpu:
7
+ desc: null
8
+ value: 0
9
+ _wandb:
10
+ desc: null
11
+ value:
12
+ cli_version: 0.10.33
13
+ framework: huggingface
14
+ huggingface_version: 4.9.0.dev0
15
+ is_jupyter_run: false
16
+ is_kaggle_kernel: false
17
+ python_version: 3.8.10
18
+ t:
19
+ 1:
20
+ - 1
21
+ - 3
22
+ - 11
23
+ 2:
24
+ - 1
25
+ - 3
26
+ - 11
27
+ 4: 3.8.10
28
+ 5: 0.10.33
29
+ 6: 4.9.0.dev0
30
+ 8:
31
+ - 5
32
+ adafactor:
33
+ desc: null
34
+ value: false
35
+ adam_beta1:
36
+ desc: null
37
+ value: 0.9
38
+ adam_beta2:
39
+ desc: null
40
+ value: 0.98
41
+ adam_epsilon:
42
+ desc: null
43
+ value: 1.0e-08
44
+ block_size:
45
+ desc: null
46
+ value: 512
47
+ cache_dir:
48
+ desc: null
49
+ value: null
50
+ config_name:
51
+ desc: null
52
+ value: ../gpt-2-tamil/
53
+ dataloader_drop_last:
54
+ desc: null
55
+ value: false
56
+ dataloader_num_workers:
57
+ desc: null
58
+ value: 0
59
+ dataloader_pin_memory:
60
+ desc: null
61
+ value: true
62
+ dataset_config_name:
63
+ desc: null
64
+ value: unshuffled_deduplicated_ta
65
+ dataset_name:
66
+ desc: null
67
+ value: oscar
68
+ ddp_find_unused_parameters:
69
+ desc: null
70
+ value: null
71
+ debug:
72
+ desc: null
73
+ value: []
74
+ deepspeed:
75
+ desc: null
76
+ value: null
77
+ disable_tqdm:
78
+ desc: null
79
+ value: false
80
+ do_eval:
81
+ desc: null
82
+ value: true
83
+ do_predict:
84
+ desc: null
85
+ value: false
86
+ do_train:
87
+ desc: null
88
+ value: true
89
+ dtype:
90
+ desc: null
91
+ value: float32
92
+ eval_accumulation_steps:
93
+ desc: null
94
+ value: null
95
+ eval_steps:
96
+ desc: null
97
+ value: 2500
98
+ evaluation_strategy:
99
+ desc: null
100
+ value: IntervalStrategy.NO
101
+ fp16:
102
+ desc: null
103
+ value: false
104
+ fp16_backend:
105
+ desc: null
106
+ value: auto
107
+ fp16_full_eval:
108
+ desc: null
109
+ value: false
110
+ fp16_opt_level:
111
+ desc: null
112
+ value: O1
113
+ gradient_accumulation_steps:
114
+ desc: null
115
+ value: 1
116
+ greater_is_better:
117
+ desc: null
118
+ value: null
119
+ group_by_length:
120
+ desc: null
121
+ value: false
122
+ ignore_data_skip:
123
+ desc: null
124
+ value: false
125
+ label_names:
126
+ desc: null
127
+ value: null
128
+ label_smoothing_factor:
129
+ desc: null
130
+ value: 0.0
131
+ learning_rate:
132
+ desc: null
133
+ value: 3.0e-05
134
+ length_column_name:
135
+ desc: null
136
+ value: length
137
+ load_best_model_at_end:
138
+ desc: null
139
+ value: false
140
+ local_rank:
141
+ desc: null
142
+ value: -1
143
+ log_level:
144
+ desc: null
145
+ value: -1
146
+ log_level_replica:
147
+ desc: null
148
+ value: -1
149
+ log_on_each_node:
150
+ desc: null
151
+ value: true
152
+ logging_dir:
153
+ desc: null
154
+ value: ../gpt-2-tamil/runs/Jul12_16-45-48_t1v-n-ebe36c53-w-0
155
+ logging_first_step:
156
+ desc: null
157
+ value: false
158
+ logging_steps:
159
+ desc: null
160
+ value: 500
161
+ logging_strategy:
162
+ desc: null
163
+ value: IntervalStrategy.STEPS
164
+ lr_scheduler_type:
165
+ desc: null
166
+ value: SchedulerType.LINEAR
167
+ max_eval_samples:
168
+ desc: null
169
+ value: null
170
+ max_grad_norm:
171
+ desc: null
172
+ value: 1.0
173
+ max_steps:
174
+ desc: null
175
+ value: -1
176
+ max_train_samples:
177
+ desc: null
178
+ value: null
179
+ metric_for_best_model:
180
+ desc: null
181
+ value: null
182
+ model_name_or_path:
183
+ desc: null
184
+ value: null
185
+ model_type:
186
+ desc: null
187
+ value: gpt2
188
+ mp_parameters:
189
+ desc: null
190
+ value: ''
191
+ no_cuda:
192
+ desc: null
193
+ value: false
194
+ num_train_epochs:
195
+ desc: null
196
+ value: 25.0
197
+ output_dir:
198
+ desc: null
199
+ value: ../gpt-2-tamil/
200
+ overwrite_cache:
201
+ desc: null
202
+ value: false
203
+ overwrite_output_dir:
204
+ desc: null
205
+ value: true
206
+ past_index:
207
+ desc: null
208
+ value: -1
209
+ per_device_eval_batch_size:
210
+ desc: null
211
+ value: 64
212
+ per_device_train_batch_size:
213
+ desc: null
214
+ value: 64
215
+ per_gpu_eval_batch_size:
216
+ desc: null
217
+ value: null
218
+ per_gpu_train_batch_size:
219
+ desc: null
220
+ value: null
221
+ prediction_loss_only:
222
+ desc: null
223
+ value: false
224
+ preprocessing_num_workers:
225
+ desc: null
226
+ value: 90
227
+ push_to_hub:
228
+ desc: null
229
+ value: false
230
+ push_to_hub_model_id:
231
+ desc: null
232
+ value: gpt-2-tamil
233
+ push_to_hub_organization:
234
+ desc: null
235
+ value: null
236
+ push_to_hub_token:
237
+ desc: null
238
+ value: null
239
+ remove_unused_columns:
240
+ desc: null
241
+ value: true
242
+ report_to:
243
+ desc: null
244
+ value:
245
+ - wandb
246
+ resume_from_checkpoint:
247
+ desc: null
248
+ value: null
249
+ run_name:
250
+ desc: null
251
+ value: trial
252
+ save_on_each_node:
253
+ desc: null
254
+ value: false
255
+ save_steps:
256
+ desc: null
257
+ value: 2500
258
+ save_strategy:
259
+ desc: null
260
+ value: IntervalStrategy.STEPS
261
+ save_total_limit:
262
+ desc: null
263
+ value: null
264
+ seed:
265
+ desc: null
266
+ value: 42
267
+ sharded_ddp:
268
+ desc: null
269
+ value: []
270
+ skip_memory_metrics:
271
+ desc: null
272
+ value: true
273
+ tokenizer_name:
274
+ desc: null
275
+ value: ../gpt-2-tamil/
276
+ tpu_metrics_debug:
277
+ desc: null
278
+ value: false
279
+ tpu_num_cores:
280
+ desc: null
281
+ value: null
282
+ train_file:
283
+ desc: null
284
+ value: null
285
+ use_fast_tokenizer:
286
+ desc: null
287
+ value: true
288
+ use_legacy_prediction_loop:
289
+ desc: null
290
+ value: false
291
+ validation_file:
292
+ desc: null
293
+ value: null
294
+ validation_split_percentage:
295
+ desc: null
296
+ value: 5
297
+ warmup_ratio:
298
+ desc: null
299
+ value: 0.0
300
+ warmup_steps:
301
+ desc: null
302
+ value: 1000
303
+ weight_decay:
304
+ desc: null
305
+ value: 0.01
scripts/wandb/run-20210712_164633-1ddv4131/files/events.out.tfevents.1626108395.t1v-n-ebe36c53-w-0.486342.3.v2 ADDED
@@ -0,0 +1 @@
 
 
1
+ /home/tweety_abi/GPT2-Tamil/gpt-2-tamil/events.out.tfevents.1626108395.t1v-n-ebe36c53-w-0.486342.3.v2
scripts/wandb/run-20210712_164633-1ddv4131/files/output.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d9761e7442f5b6b99224ee68ee38f3b7e486ead79f4e390bf7e258dc16de973
3
+ size 4407657
scripts/wandb/run-20210712_164633-1ddv4131/files/requirements.txt ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==0.13.0
2
+ aiohttp==3.7.4.post0
3
+ appdirs==1.4.4
4
+ astunparse==1.6.3
5
+ async-timeout==3.0.1
6
+ attrs==21.2.0
7
+ backcall==0.2.0
8
+ black==21.6b0
9
+ cachetools==4.2.2
10
+ certifi==2021.5.30
11
+ cfgv==3.3.0
12
+ chardet==4.0.0
13
+ chex==0.0.7
14
+ click==8.0.1
15
+ configparser==5.0.2
16
+ cycler==0.10.0
17
+ datasets==1.8.1.dev0
18
+ decorator==5.0.9
19
+ dill==0.3.4
20
+ distlib==0.3.2
21
+ dm-tree==0.1.6
22
+ docker-pycreds==0.4.0
23
+ filelock==3.0.12
24
+ flake8==3.9.2
25
+ flatbuffers==1.12
26
+ flax==0.3.4
27
+ fsspec==2021.6.1
28
+ gast==0.4.0
29
+ gitdb==4.0.7
30
+ gitpython==3.1.18
31
+ google-auth-oauthlib==0.4.4
32
+ google-auth==1.32.1
33
+ google-pasta==0.2.0
34
+ grpcio==1.34.1
35
+ h5py==3.1.0
36
+ huggingface-hub==0.0.12
37
+ identify==2.2.10
38
+ idna==2.10
39
+ ipython-genutils==0.2.0
40
+ ipython==7.25.0
41
+ isort==5.9.1
42
+ jax==0.2.16
43
+ jaxlib==0.1.68
44
+ jedi==0.18.0
45
+ joblib==1.0.1
46
+ keras-nightly==2.5.0.dev2021032900
47
+ keras-preprocessing==1.1.2
48
+ kiwisolver==1.3.1
49
+ libtpu-nightly==0.1.dev20210615
50
+ markdown==3.3.4
51
+ matplotlib-inline==0.1.2
52
+ matplotlib==3.4.2
53
+ mccabe==0.6.1
54
+ msgpack==1.0.2
55
+ multidict==5.1.0
56
+ multiprocess==0.70.12.2
57
+ mypy-extensions==0.4.3
58
+ nodeenv==1.6.0
59
+ numpy==1.19.5
60
+ oauthlib==3.1.1
61
+ opt-einsum==3.3.0
62
+ optax==0.0.8
63
+ packaging==20.9
64
+ pandas==1.2.5
65
+ parso==0.8.2
66
+ pathspec==0.8.1
67
+ pathtools==0.1.2
68
+ pexpect==4.8.0
69
+ pickleshare==0.7.5
70
+ pillow==8.3.0
71
+ pip==20.0.2
72
+ pkg-resources==0.0.0
73
+ pre-commit==2.13.0
74
+ promise==2.3
75
+ prompt-toolkit==3.0.19
76
+ protobuf==3.17.3
77
+ psutil==5.8.0
78
+ ptyprocess==0.7.0
79
+ pyarrow==4.0.1
80
+ pyasn1-modules==0.2.8
81
+ pyasn1==0.4.8
82
+ pycodestyle==2.7.0
83
+ pyflakes==2.3.1
84
+ pygments==2.9.0
85
+ pyparsing==2.4.7
86
+ python-dateutil==2.8.1
87
+ pytz==2021.1
88
+ pyyaml==5.4.1
89
+ regex==2021.7.1
90
+ requests-oauthlib==1.3.0
91
+ requests==2.25.1
92
+ rsa==4.7.2
93
+ sacremoses==0.0.45
94
+ scipy==1.7.0
95
+ sentry-sdk==1.3.0
96
+ setuptools==44.0.0
97
+ shortuuid==1.0.1
98
+ six==1.15.0
99
+ smmap==4.0.0
100
+ subprocess32==3.5.4
101
+ tensorboard-data-server==0.6.1
102
+ tensorboard-plugin-wit==1.8.0
103
+ tensorboard==2.5.0
104
+ tensorflow-estimator==2.5.0
105
+ tensorflow==2.5.0
106
+ termcolor==1.1.0
107
+ tokenizers==0.10.3
108
+ toml==0.10.2
109
+ toolz==0.11.1
110
+ torch==1.9.0
111
+ tqdm==4.61.1
112
+ traitlets==5.0.5
113
+ transformers==4.9.0.dev0
114
+ typing-extensions==3.7.4.3
115
+ urllib3==1.26.6
116
+ virtualenv==20.4.7
117
+ wandb==0.10.33
118
+ wcwidth==0.2.5
119
+ werkzeug==2.0.1
120
+ wheel==0.36.2
121
+ wrapt==1.12.1
122
+ xxhash==2.0.2
123
+ yarl==1.6.3
scripts/wandb/run-20210712_164633-1ddv4131/files/wandb-metadata.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2021-07-12T16:46:35.350252",
5
+ "startedAt": "2021-07-12T16:46:33.416306",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--output_dir=../gpt-2-tamil/",
11
+ "--model_type=gpt2",
12
+ "--config_name=../gpt-2-tamil/",
13
+ "--tokenizer_name=../gpt-2-tamil/",
14
+ "--dataset_name=oscar",
15
+ "--dataset_config_name=unshuffled_deduplicated_ta",
16
+ "--do_train",
17
+ "--do_eval",
18
+ "--block_size=512",
19
+ "--per_device_train_batch_size=64",
20
+ "--per_device_eval_batch_size=64",
21
+ "--learning_rate=3e-5",
22
+ "--warmup_steps=1000",
23
+ "--adam_beta1=0.9",
24
+ "--adam_beta2=0.98",
25
+ "--weight_decay=0.01",
26
+ "--overwrite_output_dir",
27
+ "--num_train_epochs=25",
28
+ "--report_to",
29
+ "wandb",
30
+ "--run_name",
31
+ "trial",
32
+ "--logging_steps=500",
33
+ "--save_steps=2500",
34
+ "--eval_steps=2500",
35
+ "--preprocessing_num_workers=90"
36
+ ],
37
+ "state": "running",
38
+ "program": "../src/run_clm_flax.py",
39
+ "codePath": "src/run_clm_flax.py",
40
+ "git": {
41
+ "remote": "https://github.com/AbinayaM02/GPT2-Tamil.git",
42
+ "commit": "5d59c6a635e952a0f51ef33ed713960a04e9dcb6"
43
+ },
44
+ "email": "abinaya.m02@mphasis.com",
45
+ "root": "/home/tweety_abi/GPT2-Tamil",
46
+ "host": "t1v-n-ebe36c53-w-0",
47
+ "username": "tweety_abi",
48
+ "executable": "/home/tweety_abi/gpt2_env/bin/python"
49
+ }
scripts/wandb/run-20210712_164633-1ddv4131/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"global_step": 132500, "_timestamp": 1626248099.379086, "train_time": 743654.875, "train_learning_rate": 1.1402963906448349e-08, "_step": 264206, "train_loss": 1.1299134492874146, "eval_loss": 1.1545542478561401, "eval_perplexity": 3.1726088523864746}
scripts/wandb/run-20210712_164633-1ddv4131/logs/debug-internal.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:748fffc8fe7bbd39d404a1bae61d5711a3e098491142dc28b16d5d75e32dc937
3
+ size 97283434
scripts/wandb/run-20210712_164633-1ddv4131/logs/debug.log ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36e11410ff19af1db092231a1450397dffb80ef21248540b6d372dcf5606559c
3
+ size 8797
scripts/wandb/run-20210712_164633-1ddv4131/run-1ddv4131.wandb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8211487b4d0a0489ae4728120abad1be7ee4190520afc47fdae166087ae6068
3
+ size 60817322
src/create_config.py CHANGED
@@ -1,6 +1,6 @@
1
  from transformers import GPT2Config
2
 
3
- model_dir = "./gpt2-tamil" # ${MODEL_DIR}
4
 
5
  config = GPT2Config.from_pretrained(
6
  "gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0
 
1
  from transformers import GPT2Config
2
 
3
+ model_dir = "../gpt-2-tamil" # ${MODEL_DIR}
4
 
5
  config = GPT2Config.from_pretrained(
6
  "gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0
src/run_clm_flax.py CHANGED
@@ -31,16 +31,18 @@ from pathlib import Path
31
  from typing import Callable, Optional
32
 
33
  import datasets
 
 
 
34
  import jax
35
  import jax.numpy as jnp
36
  import optax
37
  import transformers
38
- from datasets import Dataset, load_dataset
39
  from flax import jax_utils, traverse_util
40
  from flax.jax_utils import unreplicate
41
  from flax.training import train_state
42
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
43
- from tqdm import tqdm
44
  from transformers import (
45
  CONFIG_MAPPING,
46
  FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
@@ -53,25 +55,8 @@ from transformers import (
53
  )
54
  from transformers.testing_utils import CaptureLogger
55
 
56
- logger = logging.getLogger(__name__)
57
-
58
- # Cache the result
59
- has_tensorboard = is_tensorboard_available()
60
- if has_tensorboard:
61
- try:
62
- from flax.metrics.tensorboard import SummaryWriter
63
- except ImportError as ie:
64
- has_tensorboard = False
65
- print(
66
- f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
67
- )
68
-
69
- else:
70
- print(
71
- "Unable to display metrics through TensorBoard because the package is not installed: "
72
- "Please run pip install tensorboard to enable."
73
- )
74
 
 
75
 
76
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
77
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
@@ -92,34 +77,20 @@ class ModelArguments:
92
  )
93
  model_type: Optional[str] = field(
94
  default=None,
95
- metadata={
96
- "help": "If training from scratch, pass a model type from the list: "
97
- + ", ".join(MODEL_TYPES)
98
- },
99
  )
100
  config_name: Optional[str] = field(
101
- default=None,
102
- metadata={
103
- "help": "Pretrained config name or path if not the same as model_name"
104
- },
105
  )
106
  tokenizer_name: Optional[str] = field(
107
- default=None,
108
- metadata={
109
- "help": "Pretrained tokenizer name or path if not the same as model_name"
110
- },
111
  )
112
  cache_dir: Optional[str] = field(
113
- default=None,
114
- metadata={
115
- "help": "Where do you want to store the pretrained models downloaded from s3"
116
- },
117
  )
118
  use_fast_tokenizer: bool = field(
119
  default=True,
120
- metadata={
121
- "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
122
- },
123
  )
124
  dtype: Optional[str] = field(
125
  default="float32",
@@ -136,26 +107,15 @@ class DataTrainingArguments:
136
  """
137
 
138
  dataset_name: Optional[str] = field(
139
- default=None,
140
- metadata={
141
- "help": "The name of the dataset to use (via the datasets library)."
142
- },
143
  )
144
  dataset_config_name: Optional[str] = field(
145
- default=None,
146
- metadata={
147
- "help": "The configuration name of the dataset to use (via the datasets library)."
148
- },
149
- )
150
- train_file: Optional[str] = field(
151
- default=None,
152
- metadata={"help": "The input training data file (a text file)."},
153
  )
 
154
  validation_file: Optional[str] = field(
155
  default=None,
156
- metadata={
157
- "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
158
- },
159
  )
160
  max_train_samples: Optional[int] = field(
161
  default=None,
@@ -172,8 +132,7 @@ class DataTrainingArguments:
172
  },
173
  )
174
  overwrite_cache: bool = field(
175
- default=False,
176
- metadata={"help": "Overwrite the cached training and evaluation sets"},
177
  )
178
  validation_split_percentage: Optional[int] = field(
179
  default=5,
@@ -190,8 +149,7 @@ class DataTrainingArguments:
190
  },
191
  )
192
  overwrite_cache: bool = field(
193
- default=False,
194
- metadata={"help": "Overwrite the cached training and evaluation sets"},
195
  )
196
  preprocessing_num_workers: Optional[int] = field(
197
  default=None,
@@ -199,43 +157,25 @@ class DataTrainingArguments:
199
  )
200
 
201
  def __post_init__(self):
202
- if (
203
- self.dataset_name is None
204
- and self.train_file is None
205
- and self.validation_file is None
206
- ):
207
- raise ValueError(
208
- "Need either a dataset name or a training/validation file."
209
- )
210
  else:
211
  if self.train_file is not None:
212
  extension = self.train_file.split(".")[-1]
213
- assert extension in [
214
- "csv",
215
- "json",
216
- "txt",
217
- ], "`train_file` should be a csv, a json or a txt file."
218
  if self.validation_file is not None:
219
  extension = self.validation_file.split(".")[-1]
220
- assert extension in [
221
- "csv",
222
- "json",
223
- "txt",
224
- ], "`validation_file` should be a csv, a json or a txt file."
225
 
226
 
227
  class TrainState(train_state.TrainState):
228
  dropout_rng: jnp.ndarray
229
 
230
  def replicate(self):
231
- return jax_utils.replicate(self).replace(
232
- dropout_rng=shard_prng_key(self.dropout_rng)
233
- )
234
 
235
 
236
- def data_loader(
237
- rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False
238
- ):
239
  """
240
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
241
  Shuffle batches if `shuffle` is `True`.
@@ -259,7 +199,7 @@ def data_loader(
259
  yield batch
260
 
261
 
262
- def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
263
  summary_writer.scalar("train_time", train_time, step)
264
 
265
  train_metrics = get_metrics(train_metrics)
@@ -268,31 +208,23 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
268
  for i, val in enumerate(vals):
269
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
270
 
 
 
271
  for metric_name, value in eval_metrics.items():
272
  summary_writer.scalar(f"eval_{metric_name}", value, step)
273
 
274
 
275
  def create_learning_rate_fn(
276
- train_ds_size: int,
277
- train_batch_size: int,
278
- num_train_epochs: int,
279
- num_warmup_steps: int,
280
- learning_rate: float,
281
  ) -> Callable[[int], jnp.array]:
282
  """Returns a linear warmup, linear_decay learning rate function."""
283
  steps_per_epoch = train_ds_size // train_batch_size
284
  num_train_steps = steps_per_epoch * num_train_epochs
285
- warmup_fn = optax.linear_schedule(
286
- init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
287
- )
288
  decay_fn = optax.linear_schedule(
289
- init_value=learning_rate,
290
- end_value=0,
291
- transition_steps=num_train_steps - num_warmup_steps,
292
- )
293
- schedule_fn = optax.join_schedules(
294
- schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
295
  )
 
296
  return schedule_fn
297
 
298
 
@@ -301,15 +233,11 @@ def main():
301
  # or by passing the --help flag to this script.
302
  # We now keep distinct sets of args, for a cleaner separation of concerns.
303
 
304
- parser = HfArgumentParser(
305
- (ModelArguments, DataTrainingArguments, TrainingArguments)
306
- )
307
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
308
  # If we pass only one argument to the script and it's the path to a json file,
309
  # let's parse it to get our arguments.
310
- model_args, data_args, training_args = parser.parse_json_file(
311
- json_file=os.path.abspath(sys.argv[1])
312
- )
313
  else:
314
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
315
 
@@ -351,14 +279,10 @@ def main():
351
  #
352
  # In distributed training, the load_dataset function guarantees that only one local process can concurrently
353
  # download the dataset.
354
- logger.info("Loading dataset....")
355
  if data_args.dataset_name is not None:
356
  # Downloading and loading a dataset from the hub.
357
  dataset = load_dataset(
358
- data_args.dataset_name,
359
- data_args.dataset_config_name,
360
- cache_dir=model_args.cache_dir,
361
- keep_in_memory=False,
362
  )
363
 
364
  if "validation" not in dataset.keys():
@@ -383,10 +307,7 @@ def main():
383
  extension = data_args.train_file.split(".")[-1]
384
  if extension == "txt":
385
  extension = "text"
386
- logger.info(f"Loading dataset....{data_args.train_file}")
387
- dataset = load_dataset(
388
- extension, data_files=data_files, cache_dir=model_args.cache_dir
389
- )
390
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
391
  # https://huggingface.co/docs/datasets/loading_datasets.html.
392
 
@@ -396,28 +317,20 @@ def main():
396
  # The .from_pretrained methods guarantee that only one local process can concurrently
397
  # download model & vocab.
398
  if model_args.config_name:
399
- config = AutoConfig.from_pretrained(
400
- model_args.config_name, cache_dir=model_args.cache_dir
401
- )
402
  elif model_args.model_name_or_path:
403
- config = AutoConfig.from_pretrained(
404
- model_args.model_name_or_path, cache_dir=model_args.cache_dir
405
- )
406
  else:
407
  config = CONFIG_MAPPING[model_args.model_type]()
408
  logger.warning("You are instantiating a new config instance from scratch.")
409
 
410
  if model_args.tokenizer_name:
411
  tokenizer = AutoTokenizer.from_pretrained(
412
- model_args.tokenizer_name,
413
- cache_dir=model_args.cache_dir,
414
- use_fast=model_args.use_fast_tokenizer,
415
  )
416
  elif model_args.model_name_or_path:
417
  tokenizer = AutoTokenizer.from_pretrained(
418
- model_args.model_name_or_path,
419
- cache_dir=model_args.cache_dir,
420
- use_fast=model_args.use_fast_tokenizer,
421
  )
422
  else:
423
  raise ValueError(
@@ -427,10 +340,7 @@ def main():
427
 
428
  if model_args.model_name_or_path:
429
  model = FlaxAutoModelForCausalLM.from_pretrained(
430
- model_args.model_name_or_path,
431
- config=config,
432
- seed=training_args.seed,
433
- dtype=getattr(jnp, model_args.dtype),
434
  )
435
  else:
436
  model = FlaxAutoModelForCausalLM.from_config(
@@ -446,9 +356,7 @@ def main():
446
  text_column_name = "text" if "text" in column_names else column_names[0]
447
 
448
  # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
449
- tok_logger = transformers.utils.logging.get_logger(
450
- "transformers.tokenization_utils_base"
451
- )
452
 
453
  def tokenize_function(examples):
454
  with CaptureLogger(tok_logger) as cl:
@@ -491,7 +399,8 @@ def main():
491
  total_length = len(concatenated_examples[list(examples.keys())[0]])
492
  # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
493
  # customize this part to your needs.
494
- total_length = (total_length // block_size) * block_size
 
495
  # Split by chunks of max_len.
496
  result = {
497
  k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
@@ -529,8 +438,32 @@ def main():
529
  eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
530
 
531
  # Enable tensorboard only on the master node
 
532
  if has_tensorboard and jax.process_index() == 0:
533
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534
 
535
  # Initialize our training
536
  rng = jax.random.PRNGKey(training_args.seed)
@@ -538,12 +471,8 @@ def main():
538
 
539
  # Store some constant
540
  num_epochs = int(training_args.num_train_epochs)
541
- train_batch_size = (
542
- int(training_args.per_device_train_batch_size) * jax.device_count()
543
- )
544
- eval_batch_size = (
545
- int(training_args.per_device_eval_batch_size) * jax.device_count()
546
- )
547
  steps_per_epoch = len(train_dataset) // train_batch_size
548
  total_train_steps = steps_per_epoch * num_epochs
549
 
@@ -566,39 +495,35 @@ def main():
566
  def decay_mask_fn(params):
567
  flat_params = traverse_util.flatten_dict(params)
568
  flat_mask = {
569
- path: (
570
- path[-1] != "bias"
571
- and path[-2:]
572
- not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")]
573
- )
574
  for path in flat_params
575
  }
576
  return traverse_util.unflatten_dict(flat_mask)
577
 
578
  # create adam optimizer
579
- adamw = optax.adamw(
580
- learning_rate=linear_decay_lr_schedule_fn,
581
- b1=training_args.adam_beta1,
582
- b2=training_args.adam_beta2,
583
- eps=training_args.adam_epsilon,
584
- weight_decay=training_args.weight_decay,
585
- mask=decay_mask_fn,
586
- )
 
 
 
 
 
 
 
587
 
588
  # Setup train state
589
- state = TrainState.create(
590
- apply_fn=model.__call__,
591
- params=model.params,
592
- tx=adamw,
593
- dropout_rng=dropout_rng,
594
- )
595
 
596
  def loss_fn(logits, labels):
597
  shift_logits = logits[..., :-1, :]
598
  shift_labels = labels[..., 1:]
599
- loss = optax.softmax_cross_entropy(
600
- shift_logits, onehot(shift_labels, shift_logits.shape[-1])
601
- )
602
  return loss.mean()
603
 
604
  # Define gradient update step fn
@@ -607,9 +532,7 @@ def main():
607
 
608
  def compute_loss(params):
609
  labels = batch.pop("labels")
610
- logits = state.apply_fn(
611
- **batch, params=params, dropout_rng=dropout_rng, train=True
612
- )[0]
613
  loss = loss_fn(logits, labels)
614
  return loss
615
 
@@ -619,10 +542,7 @@ def main():
619
 
620
  new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
621
 
622
- metrics = {
623
- "loss": loss,
624
- "learning_rate": linear_decay_lr_schedule_fn(state.step),
625
- }
626
  metrics = jax.lax.pmean(metrics, axis_name="batch")
627
 
628
  return new_state, metrics
@@ -648,15 +568,12 @@ def main():
648
  logger.info("***** Running training *****")
649
  logger.info(f" Num examples = {len(train_dataset)}")
650
  logger.info(f" Num Epochs = {num_epochs}")
651
- logger.info(
652
- f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
653
- )
654
- logger.info(
655
- f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
656
- )
657
  logger.info(f" Total optimization steps = {total_train_steps}")
658
 
659
  train_time = 0
 
660
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
661
  for epoch in epochs:
662
  # ======================== Training ================================
@@ -664,72 +581,70 @@ def main():
664
 
665
  # Create sampling rng
666
  rng, input_rng = jax.random.split(rng)
667
- train_metrics = []
668
 
669
  # Generate an epoch by shuffling sampling indices from the train dataset
670
- train_loader = data_loader(
671
- input_rng, train_dataset, train_batch_size, shuffle=True
672
- )
673
  steps_per_epoch = len(train_dataset) // train_batch_size
674
  # train
675
- for _ in tqdm(
676
- range(steps_per_epoch), desc="Training...", position=1, leave=False
677
- ):
678
  batch = next(train_loader)
679
  state, train_metric = p_train_step(state, batch)
680
  train_metrics.append(train_metric)
681
 
682
- train_time += time.time() - train_start
683
-
684
- train_metric = unreplicate(train_metric)
685
-
686
- epochs.write(
687
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
688
- )
689
-
690
- # ======================== Evaluating ==============================
691
- eval_metrics = []
692
- eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
693
- eval_steps = len(eval_dataset) // eval_batch_size
694
- for _ in tqdm(
695
- range(eval_steps), desc="Evaluating...", position=2, leave=False
696
- ):
697
- # Model forward
698
- batch = next(eval_loader)
699
- metrics = p_eval_step(state.params, batch)
700
- eval_metrics.append(metrics)
701
-
702
- # normalize eval metrics
703
- eval_metrics = get_metrics(eval_metrics)
704
-
705
- eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
706
-
707
- try:
708
- eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
709
- except OverflowError:
710
- eval_metrics["perplexity"] = float("inf")
711
-
712
- # Print metrics and update progress bar
713
- desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
714
- epochs.write(desc)
715
- epochs.desc = desc
716
-
717
- # Save metrics
718
- if has_tensorboard and jax.process_index() == 0:
719
- cur_step = epoch * (len(train_dataset) // train_batch_size)
720
- write_metric(
721
- summary_writer, train_metrics, eval_metrics, train_time, cur_step
722
- )
723
-
724
- # save checkpoint after each epoch and push checkpoint to the hub
725
- if jax.process_index() == 0:
726
- params = jax.device_get(unreplicate(state.params))
727
- model.save_pretrained(
728
- training_args.output_dir,
729
- params=params,
730
- push_to_hub=training_args.push_to_hub,
731
- commit_message=f"Saving weights and logs of epoch {epoch+1}",
732
- )
 
 
 
733
 
734
 
735
  if __name__ == "__main__":
 
31
  from typing import Callable, Optional
32
 
33
  import datasets
34
+ from datasets import Dataset, load_dataset
35
+ from tqdm import tqdm
36
+
37
  import jax
38
  import jax.numpy as jnp
39
  import optax
40
  import transformers
41
+ import wandb
42
  from flax import jax_utils, traverse_util
43
  from flax.jax_utils import unreplicate
44
  from flax.training import train_state
45
  from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
 
46
  from transformers import (
47
  CONFIG_MAPPING,
48
  FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
 
55
  )
56
  from transformers.testing_utils import CaptureLogger
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ logger = logging.getLogger(__name__)
60
 
61
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
62
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 
77
  )
78
  model_type: Optional[str] = field(
79
  default=None,
80
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
 
 
 
81
  )
82
  config_name: Optional[str] = field(
83
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
 
 
 
84
  )
85
  tokenizer_name: Optional[str] = field(
86
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
 
 
 
87
  )
88
  cache_dir: Optional[str] = field(
89
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
 
 
 
90
  )
91
  use_fast_tokenizer: bool = field(
92
  default=True,
93
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
 
 
94
  )
95
  dtype: Optional[str] = field(
96
  default="float32",
 
107
  """
108
 
109
  dataset_name: Optional[str] = field(
110
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
 
 
 
111
  )
112
  dataset_config_name: Optional[str] = field(
113
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
 
 
 
 
 
 
 
114
  )
115
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
116
  validation_file: Optional[str] = field(
117
  default=None,
118
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
 
 
119
  )
120
  max_train_samples: Optional[int] = field(
121
  default=None,
 
132
  },
133
  )
134
  overwrite_cache: bool = field(
135
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
 
136
  )
137
  validation_split_percentage: Optional[int] = field(
138
  default=5,
 
149
  },
150
  )
151
  overwrite_cache: bool = field(
152
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
 
153
  )
154
  preprocessing_num_workers: Optional[int] = field(
155
  default=None,
 
157
  )
158
 
159
  def __post_init__(self):
160
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
161
+ raise ValueError("Need either a dataset name or a training/validation file.")
 
 
 
 
 
 
162
  else:
163
  if self.train_file is not None:
164
  extension = self.train_file.split(".")[-1]
165
+ assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
 
 
 
 
166
  if self.validation_file is not None:
167
  extension = self.validation_file.split(".")[-1]
168
+ assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
 
 
 
 
169
 
170
 
171
  class TrainState(train_state.TrainState):
172
  dropout_rng: jnp.ndarray
173
 
174
  def replicate(self):
175
+ return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
 
 
176
 
177
 
178
+ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
 
 
179
  """
180
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
181
  Shuffle batches if `shuffle` is `True`.
 
199
  yield batch
200
 
201
 
202
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
203
  summary_writer.scalar("train_time", train_time, step)
204
 
205
  train_metrics = get_metrics(train_metrics)
 
208
  for i, val in enumerate(vals):
209
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
210
 
211
+
212
def write_eval_metric(summary_writer, eval_metrics, step):
    """Log each evaluation metric as a scalar tagged ``eval_<metric name>``.

    Args:
        summary_writer: Object exposing ``scalar(tag, value, step)``.
        eval_metrics: Mapping from metric name to the value to record.
        step: Step index the values are recorded at.
    """
    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, step)
215
 
216
 
217
  def create_learning_rate_fn(
218
+ train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
 
 
 
 
219
  ) -> Callable[[int], jnp.array]:
220
  """Returns a linear warmup, linear_decay learning rate function."""
221
  steps_per_epoch = train_ds_size // train_batch_size
222
  num_train_steps = steps_per_epoch * num_train_epochs
223
+ warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
 
 
224
  decay_fn = optax.linear_schedule(
225
+ init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
 
 
 
 
 
226
  )
227
+ schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
228
  return schedule_fn
229
 
230
 
 
233
  # or by passing the --help flag to this script.
234
  # We now keep distinct sets of args, for a cleaner separation of concerns.
235
 
236
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
 
 
237
  if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
238
  # If we pass only one argument to the script and it's the path to a json file,
239
  # let's parse it to get our arguments.
240
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
 
 
241
  else:
242
  model_args, data_args, training_args = parser.parse_args_into_dataclasses()
243
 
 
279
  #
280
  # In distributed training, the load_dataset function guarantees that only one local process can concurrently
281
  # download the dataset.
 
282
  if data_args.dataset_name is not None:
283
  # Downloading and loading a dataset from the hub.
284
  dataset = load_dataset(
285
+ data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
 
 
 
286
  )
287
 
288
  if "validation" not in dataset.keys():
 
307
  extension = data_args.train_file.split(".")[-1]
308
  if extension == "txt":
309
  extension = "text"
310
+ dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
 
 
 
311
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
312
  # https://huggingface.co/docs/datasets/loading_datasets.html.
313
 
 
317
  # The .from_pretrained methods guarantee that only one local process can concurrently
318
  # download model & vocab.
319
  if model_args.config_name:
320
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
 
 
321
  elif model_args.model_name_or_path:
322
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
 
 
323
  else:
324
  config = CONFIG_MAPPING[model_args.model_type]()
325
  logger.warning("You are instantiating a new config instance from scratch.")
326
 
327
  if model_args.tokenizer_name:
328
  tokenizer = AutoTokenizer.from_pretrained(
329
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
 
 
330
  )
331
  elif model_args.model_name_or_path:
332
  tokenizer = AutoTokenizer.from_pretrained(
333
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
 
 
334
  )
335
  else:
336
  raise ValueError(
 
340
 
341
  if model_args.model_name_or_path:
342
  model = FlaxAutoModelForCausalLM.from_pretrained(
343
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
 
 
 
344
  )
345
  else:
346
  model = FlaxAutoModelForCausalLM.from_config(
 
356
  text_column_name = "text" if "text" in column_names else column_names[0]
357
 
358
  # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
359
+ tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
 
 
360
 
361
  def tokenize_function(examples):
362
  with CaptureLogger(tok_logger) as cl:
 
399
  total_length = len(concatenated_examples[list(examples.keys())[0]])
400
  # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
401
  # customize this part to your needs.
402
+ if total_length >= block_size:
403
+ total_length = (total_length // block_size) * block_size
404
  # Split by chunks of max_len.
405
  result = {
406
  k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
 
438
  eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
439
 
440
  # Enable tensorboard only on the master node
441
+ has_tensorboard = is_tensorboard_available()
442
  if has_tensorboard and jax.process_index() == 0:
443
+ wandb.init(
444
+ entity='abinayam',
445
+ project='hf-flax-gpt-2-tamil',
446
+ sync_tensorboard=True
447
+ )
448
+
449
+ wandb.config.update(training_args) # optional, log your configs
450
+ wandb.config.update(model_args) # optional, log your configs
451
+ wandb.config.update(data_args) # optional, log your configs
452
+
453
+ try:
454
+ from flax.metrics.tensorboard import SummaryWriter
455
+
456
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
457
+ except ImportError as ie:
458
+ has_tensorboard = False
459
+ logger.warning(
460
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
461
+ )
462
+ else:
463
+ logger.warning(
464
+ "Unable to display metrics through TensorBoard because the package is not installed: "
465
+ "Please run pip install tensorboard to enable."
466
+ )
467
 
468
  # Initialize our training
469
  rng = jax.random.PRNGKey(training_args.seed)
 
471
 
472
  # Store some constant
473
  num_epochs = int(training_args.num_train_epochs)
474
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
475
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
 
 
 
 
476
  steps_per_epoch = len(train_dataset) // train_batch_size
477
  total_train_steps = steps_per_epoch * num_epochs
478
 
 
495
  def decay_mask_fn(params):
496
  flat_params = traverse_util.flatten_dict(params)
497
  flat_mask = {
498
+ path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
 
 
 
 
499
  for path in flat_params
500
  }
501
  return traverse_util.unflatten_dict(flat_mask)
502
 
503
  # create the optimizer (Adafactor or AdamW, depending on `training_args.adafactor`)
504
+ if training_args.adafactor:
505
+ # We use the default parameters here to initialize adafactor,
506
+ # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
507
+ optimizer = optax.adafactor(
508
+ learning_rate=linear_decay_lr_schedule_fn,
509
+ )
510
+ else:
511
+ optimizer = optax.adamw(
512
+ learning_rate=linear_decay_lr_schedule_fn,
513
+ b1=training_args.adam_beta1,
514
+ b2=training_args.adam_beta2,
515
+ eps=training_args.adam_epsilon,
516
+ weight_decay=training_args.weight_decay,
517
+ mask=decay_mask_fn,
518
+ )
519
 
520
  # Setup train state
521
+ state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)
 
 
 
 
 
522
 
523
  def loss_fn(logits, labels):
524
  shift_logits = logits[..., :-1, :]
525
  shift_labels = labels[..., 1:]
526
+ loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
 
 
527
  return loss.mean()
528
 
529
  # Define gradient update step fn
 
532
 
533
  def compute_loss(params):
534
  labels = batch.pop("labels")
535
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
 
 
536
  loss = loss_fn(logits, labels)
537
  return loss
538
 
 
542
 
543
  new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
544
 
545
+ metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
 
 
 
546
  metrics = jax.lax.pmean(metrics, axis_name="batch")
547
 
548
  return new_state, metrics
 
568
  logger.info("***** Running training *****")
569
  logger.info(f" Num examples = {len(train_dataset)}")
570
  logger.info(f" Num Epochs = {num_epochs}")
571
+ logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
572
+ logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
 
 
 
 
573
  logger.info(f" Total optimization steps = {total_train_steps}")
574
 
575
  train_time = 0
576
+ train_metrics = []
577
  epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
578
  for epoch in epochs:
579
  # ======================== Training ================================
 
581
 
582
  # Create sampling rng
583
  rng, input_rng = jax.random.split(rng)
 
584
 
585
  # Generate an epoch by shuffling sampling indices from the train dataset
586
+ train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
 
 
587
  steps_per_epoch = len(train_dataset) // train_batch_size
588
  # train
589
+ for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
 
 
590
  batch = next(train_loader)
591
  state, train_metric = p_train_step(state, batch)
592
  train_metrics.append(train_metric)
593
 
594
+ cur_step = epoch * (len(train_dataset) // train_batch_size) + step
595
+
596
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
597
+ # Save metrics
598
+ train_metric = unreplicate(train_metric)
599
+ train_time += time.time() - train_start
600
+ if has_tensorboard and jax.process_index() == 0:
601
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
602
+
603
+ epochs.write(
604
+ f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
605
+ )
606
+
607
+ train_metrics = []
608
+
609
+ if cur_step % training_args.eval_steps == 0 and cur_step > 0:
610
+ # ======================== Evaluating ==============================
611
+ eval_metrics = []
612
+ eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
613
+ eval_steps = len(eval_dataset) // eval_batch_size
614
+ for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
615
+ # Model forward
616
+ batch = next(eval_loader)
617
+ metrics = p_eval_step(state.params, batch)
618
+ eval_metrics.append(metrics)
619
+
620
+ # normalize eval metrics
621
+ eval_metrics = get_metrics(eval_metrics)
622
+ eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
623
+
624
+ try:
625
+ eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
626
+ except OverflowError:
627
+ eval_metrics["perplexity"] = float("inf")
628
+
629
+ # Print metrics and update progress bar
630
+ desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
631
+ epochs.write(desc)
632
+ epochs.desc = desc
633
+
634
+ # Save metrics
635
+ if has_tensorboard and jax.process_index() == 0:
636
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
637
+
638
+ if cur_step % training_args.save_steps == 0 and cur_step > 0:
639
+ # save a checkpoint every `save_steps` steps and push it to the hub
640
+ if jax.process_index() == 0:
641
+ params = jax.device_get(unreplicate(state.params))
642
+ model.save_pretrained(
643
+ training_args.output_dir,
644
+ params=params,
645
+ push_to_hub=training_args.push_to_hub,
646
+ commit_message=f"Saving weights and logs of step {cur_step}",
647
+ )
648
 
649
 
650
  if __name__ == "__main__":
src/train_tokenizer.py CHANGED
@@ -1,7 +1,7 @@
1
  from datasets import load_dataset
2
  from tokenizers import ByteLevelBPETokenizer # Tokenizer, normalizers, trainers
3
 
4
- model_dir = "./gpt2-tamil" # ${MODEL_DIR}
5
 
6
  # load dataset
7
  dataset = load_dataset("oscar", "unshuffled_deduplicated_ta", split="train")
 
1
  from datasets import load_dataset
2
  from tokenizers import ByteLevelBPETokenizer # Tokenizer, normalizers, trainers
3
 
4
+ model_dir = "../gpt-2-tamil" # ${MODEL_DIR}
5
 
6
  # load dataset
7
  dataset = load_dataset("oscar", "unshuffled_deduplicated_ta", split="train")