kimbochen commited on
Commit
3538413
·
1 Parent(s): 143aad4

Training in progress, step 200

Browse files
.ipynb_checkpoints/fine-tune-whisper-streaming-checkpoint.ipynb CHANGED
@@ -108,7 +108,7 @@
108
  },
109
  {
110
  "cell_type": "code",
111
- "execution_count": 1,
112
  "id": "065a8cf7-e54f-4ac3-900e-609c80714fca",
113
  "metadata": {},
114
  "outputs": [],
@@ -142,7 +142,7 @@
142
  },
143
  {
144
  "cell_type": "code",
145
- "execution_count": 3,
146
  "id": "a2787582-554f-44ce-9f38-4180a5ed6b44",
147
  "metadata": {},
148
  "outputs": [],
@@ -151,7 +151,7 @@
151
  "\n",
152
  "raw_datasets = IterableDatasetDict()\n",
153
  "\n",
154
- "raw_datasets[\"train\"] = load_streaming_dataset(\"mozilla-foundation/common_voice_11_0\", \"ja\", split=\"train\", use_auth_token=True) # set split=\"train+validation\" for low-resource\n",
155
  "raw_datasets[\"test\"] = load_streaming_dataset(\"mozilla-foundation/common_voice_11_0\", \"ja\", split=\"test\", use_auth_token=True)"
156
  ]
157
  },
@@ -185,109 +185,10 @@
185
  },
186
  {
187
  "cell_type": "code",
188
- "execution_count": 4,
189
  "id": "77d9f0c5-8607-4642-a8ac-c3ab2e223ea6",
190
  "metadata": {},
191
- "outputs": [
192
- {
193
- "data": {
194
- "application/vnd.jupyter.widget-view+json": {
195
- "model_id": "ab8ef1fb2f284e2abd43a1b1bde55882",
196
- "version_major": 2,
197
- "version_minor": 0
198
- },
199
- "text/plain": [
200
- "Downloading: 0%| | 0.00/185k [00:00<?, ?B/s]"
201
- ]
202
- },
203
- "metadata": {},
204
- "output_type": "display_data"
205
- },
206
- {
207
- "data": {
208
- "application/vnd.jupyter.widget-view+json": {
209
- "model_id": "e0c2142f48224f1582e6457dbb8e5276",
210
- "version_major": 2,
211
- "version_minor": 0
212
- },
213
- "text/plain": [
214
- "Downloading: 0%| | 0.00/829 [00:00<?, ?B/s]"
215
- ]
216
- },
217
- "metadata": {},
218
- "output_type": "display_data"
219
- },
220
- {
221
- "data": {
222
- "application/vnd.jupyter.widget-view+json": {
223
- "model_id": "55aa8ea93e924389b339aefec864805d",
224
- "version_major": 2,
225
- "version_minor": 0
226
- },
227
- "text/plain": [
228
- "Downloading: 0%| | 0.00/1.04M [00:00<?, ?B/s]"
229
- ]
230
- },
231
- "metadata": {},
232
- "output_type": "display_data"
233
- },
234
- {
235
- "data": {
236
- "application/vnd.jupyter.widget-view+json": {
237
- "model_id": "5cc4483a4d234f73914d26f285588949",
238
- "version_major": 2,
239
- "version_minor": 0
240
- },
241
- "text/plain": [
242
- "Downloading: 0%| | 0.00/494k [00:00<?, ?B/s]"
243
- ]
244
- },
245
- "metadata": {},
246
- "output_type": "display_data"
247
- },
248
- {
249
- "data": {
250
- "application/vnd.jupyter.widget-view+json": {
251
- "model_id": "806dfeffeb1a4d6ba3a042cadee13450",
252
- "version_major": 2,
253
- "version_minor": 0
254
- },
255
- "text/plain": [
256
- "Downloading: 0%| | 0.00/52.7k [00:00<?, ?B/s]"
257
- ]
258
- },
259
- "metadata": {},
260
- "output_type": "display_data"
261
- },
262
- {
263
- "data": {
264
- "application/vnd.jupyter.widget-view+json": {
265
- "model_id": "b93cdf2091424615927adaefb032132f",
266
- "version_major": 2,
267
- "version_minor": 0
268
- },
269
- "text/plain": [
270
- "Downloading: 0%| | 0.00/2.11k [00:00<?, ?B/s]"
271
- ]
272
- },
273
- "metadata": {},
274
- "output_type": "display_data"
275
- },
276
- {
277
- "data": {
278
- "application/vnd.jupyter.widget-view+json": {
279
- "model_id": "cdb5621656934de2a60214f67530212c",
280
- "version_major": 2,
281
- "version_minor": 0
282
- },
283
- "text/plain": [
284
- "Downloading: 0%| | 0.00/2.06k [00:00<?, ?B/s]"
285
- ]
286
- },
287
- "metadata": {},
288
- "output_type": "display_data"
289
- }
290
- ],
291
  "source": [
292
  "from transformers import WhisperProcessor\n",
293
  "\n",
@@ -312,7 +213,7 @@
312
  },
313
  {
314
  "cell_type": "code",
315
- "execution_count": 5,
316
  "id": "ab5a13b4-9bd4-4aa0-aef2-b3de9b762988",
317
  "metadata": {},
318
  "outputs": [
@@ -332,7 +233,7 @@
332
  " 'segment': Value(dtype='string', id=None)}"
333
  ]
334
  },
335
- "execution_count": 5,
336
  "metadata": {},
337
  "output_type": "execute_result"
338
  }
@@ -358,7 +259,7 @@
358
  },
359
  {
360
  "cell_type": "code",
361
- "execution_count": 6,
362
  "id": "3ab6a724-3d1e-478b-a9e9-d2f85feb6c39",
363
  "metadata": {},
364
  "outputs": [],
@@ -378,7 +279,7 @@
378
  },
379
  {
380
  "cell_type": "code",
381
- "execution_count": 7,
382
  "id": "d041650e-1c48-4439-87b3-5b6f4a514107",
383
  "metadata": {},
384
  "outputs": [],
@@ -405,7 +306,7 @@
405
  },
406
  {
407
  "cell_type": "code",
408
- "execution_count": 8,
409
  "id": "c085911c-a10a-41ef-8874-306e0503e9bb",
410
  "metadata": {},
411
  "outputs": [],
@@ -441,7 +342,7 @@
441
  },
442
  {
443
  "cell_type": "code",
444
- "execution_count": 9,
445
  "id": "a37a7cdb-9013-427f-8de9-6a8d0e9dc684",
446
  "metadata": {},
447
  "outputs": [],
@@ -459,7 +360,7 @@
459
  },
460
  {
461
  "cell_type": "code",
462
- "execution_count": 10,
463
  "id": "1b145699-acfc-4b1d-93a2-a2ad3d62674c",
464
  "metadata": {},
465
  "outputs": [],
@@ -480,7 +381,7 @@
480
  },
481
  {
482
  "cell_type": "code",
483
- "execution_count": 11,
484
  "id": "01cb25ef-4bb0-4325-9461-f59198acadf6",
485
  "metadata": {},
486
  "outputs": [],
@@ -501,7 +402,7 @@
501
  },
502
  {
503
  "cell_type": "code",
504
- "execution_count": 12,
505
  "id": "333f7f6e-6053-4d3b-8924-c733c79b82ac",
506
  "metadata": {},
507
  "outputs": [],
@@ -571,7 +472,7 @@
571
  },
572
  {
573
  "cell_type": "code",
574
- "execution_count": 13,
575
  "id": "8326221e-ec13-4731-bb4e-51e5fc1486c5",
576
  "metadata": {},
577
  "outputs": [],
@@ -619,7 +520,7 @@
619
  },
620
  {
621
  "cell_type": "code",
622
- "execution_count": 14,
623
  "id": "fc834702-c0d3-4a96-b101-7b87be32bf42",
624
  "metadata": {},
625
  "outputs": [],
@@ -646,14 +547,14 @@
646
  },
647
  {
648
  "cell_type": "code",
649
- "execution_count": 15,
650
  "id": "b22b4011-f31f-4b57-b684-c52332f92890",
651
  "metadata": {},
652
  "outputs": [
653
  {
654
  "data": {
655
  "application/vnd.jupyter.widget-view+json": {
656
- "model_id": "737faa61d325424ba4b395c4aeb9a58f",
657
  "version_major": 2,
658
  "version_minor": 0
659
  },
@@ -690,7 +591,7 @@
690
  },
691
  {
692
  "cell_type": "code",
693
- "execution_count": 16,
694
  "id": "a11d1bfc-9e28-460f-a287-72d8f7bc1acb",
695
  "metadata": {},
696
  "outputs": [],
@@ -740,14 +641,14 @@
740
  },
741
  {
742
  "cell_type": "code",
743
- "execution_count": 17,
744
  "id": "5a10cc4b-07ec-4ebd-ac1d-7c601023594f",
745
  "metadata": {},
746
  "outputs": [
747
  {
748
  "data": {
749
  "application/vnd.jupyter.widget-view+json": {
750
- "model_id": "a2e3bfa2e47241f193b2e9be28e184b4",
751
  "version_major": 2,
752
  "version_minor": 0
753
  },
@@ -761,7 +662,7 @@
761
  {
762
  "data": {
763
  "application/vnd.jupyter.widget-view+json": {
764
- "model_id": "22bccf0b5b46459ebfbdb17f4641b800",
765
  "version_major": 2,
766
  "version_minor": 0
767
  },
@@ -789,7 +690,7 @@
789
  },
790
  {
791
  "cell_type": "code",
792
- "execution_count": 18,
793
  "id": "62038ba3-88ed-4fce-84db-338f50dcd04f",
794
  "metadata": {},
795
  "outputs": [],
@@ -817,7 +718,7 @@
817
  },
818
  {
819
  "cell_type": "code",
820
- "execution_count": 19,
821
  "id": "0ae3e9af-97b7-4aa0-ae85-20b23b5bcb3a",
822
  "metadata": {},
823
  "outputs": [],
@@ -829,16 +730,16 @@
829
  " per_device_train_batch_size=64,\n",
830
  " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
831
  " learning_rate=1e-5,\n",
832
- " warmup_steps=500,\n",
833
- " max_steps=5000,\n",
834
  " gradient_checkpointing=True,\n",
835
  " fp16=True,\n",
836
  " evaluation_strategy=\"steps\",\n",
837
  " per_device_eval_batch_size=8,\n",
838
  " predict_with_generate=True,\n",
839
  " generation_max_length=225,\n",
840
- " save_steps=1000,\n",
841
- " eval_steps=1000,\n",
842
  " logging_steps=25,\n",
843
  " report_to=[\"tensorboard\"],\n",
844
  " load_best_model_at_end=True,\n",
@@ -867,7 +768,7 @@
867
  },
868
  {
869
  "cell_type": "code",
870
- "execution_count": 20,
871
  "id": "3ac16b62-b3c0-4c68-8f3d-9ecf471534b2",
872
  "metadata": {},
873
  "outputs": [],
@@ -896,7 +797,7 @@
896
  },
897
  {
898
  "cell_type": "code",
899
- "execution_count": 21,
900
  "id": "d546d7fe-0543-479a-b708-2ebabec19493",
901
  "metadata": {},
902
  "outputs": [
@@ -935,7 +836,7 @@
935
  },
936
  {
937
  "cell_type": "code",
938
- "execution_count": 22,
939
  "id": "a1ccb9ed-cbc8-4419-91c0-651e9424b672",
940
  "metadata": {},
941
  "outputs": [
@@ -992,14 +893,15 @@
992
  "/home/ubuntu/.venv/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
993
  " warnings.warn(\n",
994
  "***** Running training *****\n",
995
- " Num examples = 320000\n",
996
  " Num Epochs = 9223372036854775807\n",
997
  " Instantaneous batch size per device = 64\n",
998
  " Total train batch size (w. parallel, distributed & accumulation) = 64\n",
999
  " Gradient Accumulation steps = 1\n",
1000
- " Total optimization steps = 5000\n",
1001
  " Number of trainable parameters = 241734912\n",
1002
- "Reading metadata...: 6505it [00:00, 21991.51it/s]\n",
 
1003
  "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
1004
  ]
1005
  },
@@ -1009,8 +911,8 @@
1009
  "\n",
1010
  " <div>\n",
1011
  " \n",
1012
- " <progress value='424' max='5000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1013
- " [ 424/5000 1:07:10 < 12:08:22, 0.10 it/s, Epoch 4.00/9223372036854775807]\n",
1014
  " </div>\n",
1015
  " <table border=\"1\" class=\"dataframe\">\n",
1016
  " <thead>\n",
@@ -1035,10 +937,13 @@
1035
  "name": "stderr",
1036
  "output_type": "stream",
1037
  "text": [
1038
- "Reading metadata...: 6505it [00:00, 24574.42it/s]\n",
1039
- "Reading metadata...: 6505it [00:00, 24420.15it/s]\n",
1040
- "Reading metadata...: 6505it [00:00, 36254.85it/s]\n",
1041
- "Reading metadata...: 6505it [00:00, 30794.80it/s]\n"
 
 
 
1042
  ]
1043
  }
1044
  ],
@@ -1068,7 +973,7 @@
1068
  },
1069
  {
1070
  "cell_type": "code",
1071
- "execution_count": null,
1072
  "id": "6dd0e310-9b07-4133-ac14-2ed2d7524e22",
1073
  "metadata": {},
1074
  "outputs": [],
@@ -1094,20 +999,282 @@
1094
  },
1095
  {
1096
  "cell_type": "code",
1097
- "execution_count": null,
1098
  "id": "95737cda-c5dd-4887-a4d0-dfcb0d61d977",
1099
  "metadata": {},
1100
- "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1101
  "source": [
1102
  "trainer.push_to_hub(**kwargs)"
1103
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1104
  }
1105
  ],
1106
  "metadata": {
1107
  "kernelspec": {
1108
- "display_name": "hf",
1109
  "language": "python",
1110
- "name": "hf"
1111
  },
1112
  "language_info": {
1113
  "codemirror_mode": {
 
108
  },
109
  {
110
  "cell_type": "code",
111
+ "execution_count": 5,
112
  "id": "065a8cf7-e54f-4ac3-900e-609c80714fca",
113
  "metadata": {},
114
  "outputs": [],
 
142
  },
143
  {
144
  "cell_type": "code",
145
+ "execution_count": 6,
146
  "id": "a2787582-554f-44ce-9f38-4180a5ed6b44",
147
  "metadata": {},
148
  "outputs": [],
 
151
  "\n",
152
  "raw_datasets = IterableDatasetDict()\n",
153
  "\n",
154
+ "raw_datasets[\"train\"] = load_streaming_dataset(\"mozilla-foundation/common_voice_11_0\", \"ja\", split=\"train+validation\", use_auth_token=True) # set split=\"train+validation\" for low-resource\n",
155
  "raw_datasets[\"test\"] = load_streaming_dataset(\"mozilla-foundation/common_voice_11_0\", \"ja\", split=\"test\", use_auth_token=True)"
156
  ]
157
  },
 
185
  },
186
  {
187
  "cell_type": "code",
188
+ "execution_count": 7,
189
  "id": "77d9f0c5-8607-4642-a8ac-c3ab2e223ea6",
190
  "metadata": {},
191
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  "source": [
193
  "from transformers import WhisperProcessor\n",
194
  "\n",
 
213
  },
214
  {
215
  "cell_type": "code",
216
+ "execution_count": 8,
217
  "id": "ab5a13b4-9bd4-4aa0-aef2-b3de9b762988",
218
  "metadata": {},
219
  "outputs": [
 
233
  " 'segment': Value(dtype='string', id=None)}"
234
  ]
235
  },
236
+ "execution_count": 8,
237
  "metadata": {},
238
  "output_type": "execute_result"
239
  }
 
259
  },
260
  {
261
  "cell_type": "code",
262
+ "execution_count": 9,
263
  "id": "3ab6a724-3d1e-478b-a9e9-d2f85feb6c39",
264
  "metadata": {},
265
  "outputs": [],
 
279
  },
280
  {
281
  "cell_type": "code",
282
+ "execution_count": 10,
283
  "id": "d041650e-1c48-4439-87b3-5b6f4a514107",
284
  "metadata": {},
285
  "outputs": [],
 
306
  },
307
  {
308
  "cell_type": "code",
309
+ "execution_count": 11,
310
  "id": "c085911c-a10a-41ef-8874-306e0503e9bb",
311
  "metadata": {},
312
  "outputs": [],
 
342
  },
343
  {
344
  "cell_type": "code",
345
+ "execution_count": 12,
346
  "id": "a37a7cdb-9013-427f-8de9-6a8d0e9dc684",
347
  "metadata": {},
348
  "outputs": [],
 
360
  },
361
  {
362
  "cell_type": "code",
363
+ "execution_count": 13,
364
  "id": "1b145699-acfc-4b1d-93a2-a2ad3d62674c",
365
  "metadata": {},
366
  "outputs": [],
 
381
  },
382
  {
383
  "cell_type": "code",
384
+ "execution_count": 14,
385
  "id": "01cb25ef-4bb0-4325-9461-f59198acadf6",
386
  "metadata": {},
387
  "outputs": [],
 
402
  },
403
  {
404
  "cell_type": "code",
405
+ "execution_count": 15,
406
  "id": "333f7f6e-6053-4d3b-8924-c733c79b82ac",
407
  "metadata": {},
408
  "outputs": [],
 
472
  },
473
  {
474
  "cell_type": "code",
475
+ "execution_count": 16,
476
  "id": "8326221e-ec13-4731-bb4e-51e5fc1486c5",
477
  "metadata": {},
478
  "outputs": [],
 
520
  },
521
  {
522
  "cell_type": "code",
523
+ "execution_count": 17,
524
  "id": "fc834702-c0d3-4a96-b101-7b87be32bf42",
525
  "metadata": {},
526
  "outputs": [],
 
547
  },
548
  {
549
  "cell_type": "code",
550
+ "execution_count": 18,
551
  "id": "b22b4011-f31f-4b57-b684-c52332f92890",
552
  "metadata": {},
553
  "outputs": [
554
  {
555
  "data": {
556
  "application/vnd.jupyter.widget-view+json": {
557
+ "model_id": "bffdd7b1fed44295954d9eed41a9cfd5",
558
  "version_major": 2,
559
  "version_minor": 0
560
  },
 
591
  },
592
  {
593
  "cell_type": "code",
594
+ "execution_count": 19,
595
  "id": "a11d1bfc-9e28-460f-a287-72d8f7bc1acb",
596
  "metadata": {},
597
  "outputs": [],
 
641
  },
642
  {
643
  "cell_type": "code",
644
+ "execution_count": 20,
645
  "id": "5a10cc4b-07ec-4ebd-ac1d-7c601023594f",
646
  "metadata": {},
647
  "outputs": [
648
  {
649
  "data": {
650
  "application/vnd.jupyter.widget-view+json": {
651
+ "model_id": "48fee2fd3b2a4a67b3a35666fda4dfe9",
652
  "version_major": 2,
653
  "version_minor": 0
654
  },
 
662
  {
663
  "data": {
664
  "application/vnd.jupyter.widget-view+json": {
665
+ "model_id": "51cdba284e8f44318868fbd013970280",
666
  "version_major": 2,
667
  "version_minor": 0
668
  },
 
690
  },
691
  {
692
  "cell_type": "code",
693
+ "execution_count": 21,
694
  "id": "62038ba3-88ed-4fce-84db-338f50dcd04f",
695
  "metadata": {},
696
  "outputs": [],
 
718
  },
719
  {
720
  "cell_type": "code",
721
+ "execution_count": 22,
722
  "id": "0ae3e9af-97b7-4aa0-ae85-20b23b5bcb3a",
723
  "metadata": {},
724
  "outputs": [],
 
730
  " per_device_train_batch_size=64,\n",
731
  " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
732
  " learning_rate=1e-5,\n",
733
+ " warmup_steps=200,\n",
734
+ " max_steps=1000,\n",
735
  " gradient_checkpointing=True,\n",
736
  " fp16=True,\n",
737
  " evaluation_strategy=\"steps\",\n",
738
  " per_device_eval_batch_size=8,\n",
739
  " predict_with_generate=True,\n",
740
  " generation_max_length=225,\n",
741
+ " save_steps=200,\n",
742
+ " eval_steps=200,\n",
743
  " logging_steps=25,\n",
744
  " report_to=[\"tensorboard\"],\n",
745
  " load_best_model_at_end=True,\n",
 
768
  },
769
  {
770
  "cell_type": "code",
771
+ "execution_count": 23,
772
  "id": "3ac16b62-b3c0-4c68-8f3d-9ecf471534b2",
773
  "metadata": {},
774
  "outputs": [],
 
797
  },
798
  {
799
  "cell_type": "code",
800
+ "execution_count": 24,
801
  "id": "d546d7fe-0543-479a-b708-2ebabec19493",
802
  "metadata": {},
803
  "outputs": [
 
836
  },
837
  {
838
  "cell_type": "code",
839
+ "execution_count": 25,
840
  "id": "a1ccb9ed-cbc8-4419-91c0-651e9424b672",
841
  "metadata": {},
842
  "outputs": [
 
893
  "/home/ubuntu/.venv/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
894
  " warnings.warn(\n",
895
  "***** Running training *****\n",
896
+ " Num examples = 64000\n",
897
  " Num Epochs = 9223372036854775807\n",
898
  " Instantaneous batch size per device = 64\n",
899
  " Total train batch size (w. parallel, distributed & accumulation) = 64\n",
900
  " Gradient Accumulation steps = 1\n",
901
+ " Total optimization steps = 1000\n",
902
  " Number of trainable parameters = 241734912\n",
903
+ "Reading metadata...: 6505it [00:00, 31331.40it/s]\n",
904
+ "Reading metadata...: 4485it [00:00, 41376.86it/s]\n",
905
  "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
906
  ]
907
  },
 
911
  "\n",
912
  " <div>\n",
913
  " \n",
914
+ " <progress value='201' max='1000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
915
+ " [ 201/1000 22:31 < 1:30:27, 0.15 it/s, Epoch 1.06/9223372036854775807]\n",
916
  " </div>\n",
917
  " <table border=\"1\" class=\"dataframe\">\n",
918
  " <thead>\n",
 
937
  "name": "stderr",
938
  "output_type": "stream",
939
  "text": [
940
+ "Reading metadata...: 6505it [00:00, 64162.65it/s]\n",
941
+ "Reading metadata...: 4485it [00:00, 27834.06it/s]\n",
942
+ "***** Running Evaluation *****\n",
943
+ " Num examples: Unknown\n",
944
+ " Batch size = 8\n",
945
+ "Reading metadata...: 4604it [00:00, 27155.92it/s]\n",
946
+ "The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
947
  ]
948
  }
949
  ],
 
973
  },
974
  {
975
  "cell_type": "code",
976
+ "execution_count": 24,
977
  "id": "6dd0e310-9b07-4133-ac14-2ed2d7524e22",
978
  "metadata": {},
979
  "outputs": [],
 
999
  },
1000
  {
1001
  "cell_type": "code",
1002
+ "execution_count": 31,
1003
  "id": "95737cda-c5dd-4887-a4d0-dfcb0d61d977",
1004
  "metadata": {},
1005
+ "outputs": [
1006
+ {
1007
+ "name": "stderr",
1008
+ "output_type": "stream",
1009
+ "text": [
1010
+ "Saving model checkpoint to ./\n",
1011
+ "Configuration saved in ./config.json\n",
1012
+ "Model weights saved in ./pytorch_model.bin\n",
1013
+ "Feature extractor saved in ./preprocessor_config.json\n",
1014
+ "tokenizer config file saved in ./tokenizer_config.json\n",
1015
+ "Special tokens file saved in ./special_tokens_map.json\n",
1016
+ "added tokens file saved in ./added_tokens.json\n"
1017
+ ]
1018
+ },
1019
+ {
1020
+ "data": {
1021
+ "application/vnd.jupyter.widget-view+json": {
1022
+ "model_id": "695c170663c94560a567be198b7181ff",
1023
+ "version_major": 2,
1024
+ "version_minor": 0
1025
+ },
1026
+ "text/plain": [
1027
+ "Upload file runs/Dec10_16-23-25_129-213-27-84/1670689420.7830398/events.out.tfevents.1670689420.129-213-27-84.…"
1028
+ ]
1029
+ },
1030
+ "metadata": {},
1031
+ "output_type": "display_data"
1032
+ },
1033
+ {
1034
+ "data": {
1035
+ "application/vnd.jupyter.widget-view+json": {
1036
+ "model_id": "2318836d6dd3405fabafca4370232e34",
1037
+ "version_major": 2,
1038
+ "version_minor": 0
1039
+ },
1040
+ "text/plain": [
1041
+ "Upload file training_args.bin: 100%|##########| 3.50k/3.50k [00:00<?, ?B/s]"
1042
+ ]
1043
+ },
1044
+ "metadata": {},
1045
+ "output_type": "display_data"
1046
+ },
1047
+ {
1048
+ "data": {
1049
+ "application/vnd.jupyter.widget-view+json": {
1050
+ "model_id": "9b673eb134984bdda227d23929b66479",
1051
+ "version_major": 2,
1052
+ "version_minor": 0
1053
+ },
1054
+ "text/plain": [
1055
+ "Upload file runs/Dec10_16-23-25_129-213-27-84/events.out.tfevents.1670689420.129-213-27-84.69598.2: 100%|#####…"
1056
+ ]
1057
+ },
1058
+ "metadata": {},
1059
+ "output_type": "display_data"
1060
+ },
1061
+ {
1062
+ "name": "stderr",
1063
+ "output_type": "stream",
1064
+ "text": [
1065
+ "remote: Scanning LFS files for validity, may be slow... \n",
1066
+ "remote: LFS file scan complete. \n",
1067
+ "To https://huggingface.co/kimbochen/whisper-small-jp\n",
1068
+ " 3a44fa5..05da956 main -> main\n",
1069
+ "\n",
1070
+ "To https://huggingface.co/kimbochen/whisper-small-jp\n",
1071
+ " 05da956..30906c5 main -> main\n",
1072
+ "\n"
1073
+ ]
1074
+ },
1075
+ {
1076
+ "data": {
1077
+ "text/plain": [
1078
+ "'https://huggingface.co/kimbochen/whisper-small-jp/commit/05da956fdc97e7c01112f45c20e56c8f6a127502'"
1079
+ ]
1080
+ },
1081
+ "execution_count": 31,
1082
+ "metadata": {},
1083
+ "output_type": "execute_result"
1084
+ }
1085
+ ],
1086
  "source": [
1087
  "trainer.push_to_hub(**kwargs)"
1088
  ]
1089
+ },
1090
+ {
1091
+ "cell_type": "code",
1092
+ "execution_count": 28,
1093
+ "id": "4df1603c-ef35-40f1-ae57-3214441073c8",
1094
+ "metadata": {},
1095
+ "outputs": [
1096
+ {
1097
+ "name": "stderr",
1098
+ "output_type": "stream",
1099
+ "text": [
1100
+ "PyTorch: setting up devices\n"
1101
+ ]
1102
+ }
1103
+ ],
1104
+ "source": [
1105
+ "training_args = Seq2SeqTrainingArguments(\n",
1106
+ " output_dir=\"./\",\n",
1107
+ " per_device_train_batch_size=64,\n",
1108
+ " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
1109
+ " learning_rate=1e-5,\n",
1110
+ " max_steps=1000,\n",
1111
+ " num_train_epochs=-1,\n",
1112
+ " gradient_checkpointing=True,\n",
1113
+ " fp16=True,\n",
1114
+ " evaluation_strategy=\"steps\",\n",
1115
+ " per_device_eval_batch_size=8,\n",
1116
+ " predict_with_generate=True,\n",
1117
+ " generation_max_length=225,\n",
1118
+ " save_steps=1000,\n",
1119
+ " eval_steps=1000,\n",
1120
+ " logging_steps=25,\n",
1121
+ " report_to=[\"tensorboard\"],\n",
1122
+ " load_best_model_at_end=True,\n",
1123
+ " metric_for_best_model=\"wer\",\n",
1124
+ " greater_is_better=False,\n",
1125
+ " push_to_hub=True,\n",
1126
+ ")"
1127
+ ]
1128
+ },
1129
+ {
1130
+ "cell_type": "code",
1131
+ "execution_count": 29,
1132
+ "id": "afc2b554-7171-48c7-95aa-b7e61b70ab20",
1133
+ "metadata": {},
1134
+ "outputs": [
1135
+ {
1136
+ "name": "stderr",
1137
+ "output_type": "stream",
1138
+ "text": [
1139
+ "/home/ubuntu/whisper-small-jp/./ is already a clone of https://huggingface.co/kimbochen/whisper-small-jp. Make sure you pull the latest changes with `repo.git_pull()`.\n",
1140
+ "max_steps is given, it will override any value given in num_train_epochs\n",
1141
+ "Using cuda_amp half precision backend\n"
1142
+ ]
1143
+ }
1144
+ ],
1145
+ "source": [
1146
+ "trainer = Seq2SeqTrainer(\n",
1147
+ " args=training_args,\n",
1148
+ " model=model,\n",
1149
+ " train_dataset=vectorized_datasets[\"train\"],\n",
1150
+ " eval_dataset=vectorized_datasets[\"test\"],\n",
1151
+ " data_collator=data_collator,\n",
1152
+ " compute_metrics=compute_metrics,\n",
1153
+ " tokenizer=processor,\n",
1154
+ " callbacks=[ShuffleCallback()],\n",
1155
+ ")"
1156
+ ]
1157
+ },
1158
+ {
1159
+ "cell_type": "code",
1160
+ "execution_count": 30,
1161
+ "id": "b029a1d8-24de-46e7-b067-0f900b1db342",
1162
+ "metadata": {},
1163
+ "outputs": [
1164
+ {
1165
+ "name": "stderr",
1166
+ "output_type": "stream",
1167
+ "text": [
1168
+ "Loading model from checkpoint-4000.\n",
1169
+ "/home/ubuntu/.venv/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
1170
+ " warnings.warn(\n",
1171
+ "***** Running training *****\n",
1172
+ " Num examples = 64000\n",
1173
+ " Num Epochs = 9223372036854775807\n",
1174
+ " Instantaneous batch size per device = 64\n",
1175
+ " Total train batch size (w. parallel, distributed & accumulation) = 64\n",
1176
+ " Gradient Accumulation steps = 1\n",
1177
+ " Total optimization steps = 1000\n",
1178
+ " Number of trainable parameters = 241734912\n",
1179
+ " Continuing training from checkpoint, will skip to saved global_step\n",
1180
+ " Continuing training from epoch 4\n",
1181
+ " Continuing training from global step 4000\n",
1182
+ " Will skip the first 4 epochs then the first 0 batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` flag to your launch command, but you will resume the training on data already seen by your model.\n"
1183
+ ]
1184
+ },
1185
+ {
1186
+ "data": {
1187
+ "application/vnd.jupyter.widget-view+json": {
1188
+ "model_id": "01337298313740d98d3cc75b6d5e3ff7",
1189
+ "version_major": 2,
1190
+ "version_minor": 0
1191
+ },
1192
+ "text/plain": [
1193
+ "0it [00:00, ?it/s]"
1194
+ ]
1195
+ },
1196
+ "metadata": {},
1197
+ "output_type": "display_data"
1198
+ },
1199
+ {
1200
+ "name": "stderr",
1201
+ "output_type": "stream",
1202
+ "text": [
1203
+ "\n",
1204
+ "Reading metadata...: 0it [00:00, ?it/s]\u001b[A\n",
1205
+ "Reading metadata...: 6505it [00:00, 34246.80it/s]\n",
1206
+ "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n",
1207
+ "\n",
1208
+ "Reading metadata...: 6505it [00:00, 84823.64it/s]\n",
1209
+ "\n",
1210
+ "Reading metadata...: 6505it [00:00, 88617.62it/s]\n",
1211
+ "\n",
1212
+ "Reading metadata...: 6505it [00:00, 90289.78it/s]\n",
1213
+ "\n",
1214
+ "Reading metadata...: 6505it [00:00, 91816.92it/s]\n"
1215
+ ]
1216
+ },
1217
+ {
1218
+ "data": {
1219
+ "text/html": [
1220
+ "\n",
1221
+ " <div>\n",
1222
+ " \n",
1223
+ " <progress value='4001' max='1000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1224
+ " [1000/1000 00:00, Epoch 4/9223372036854775807]\n",
1225
+ " </div>\n",
1226
+ " <table border=\"1\" class=\"dataframe\">\n",
1227
+ " <thead>\n",
1228
+ " <tr style=\"text-align: left;\">\n",
1229
+ " <th>Step</th>\n",
1230
+ " <th>Training Loss</th>\n",
1231
+ " <th>Validation Loss</th>\n",
1232
+ " </tr>\n",
1233
+ " </thead>\n",
1234
+ " <tbody>\n",
1235
+ " </tbody>\n",
1236
+ "</table><p>"
1237
+ ],
1238
+ "text/plain": [
1239
+ "<IPython.core.display.HTML object>"
1240
+ ]
1241
+ },
1242
+ "metadata": {},
1243
+ "output_type": "display_data"
1244
+ },
1245
+ {
1246
+ "name": "stderr",
1247
+ "output_type": "stream",
1248
+ "text": [
1249
+ "\n",
1250
+ "\n",
1251
+ "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
1252
+ "\n",
1253
+ "\n",
1254
+ "Loading best model from ./checkpoint-4000 (score: 88.31039863810469).\n"
1255
+ ]
1256
+ },
1257
+ {
1258
+ "data": {
1259
+ "text/plain": [
1260
+ "TrainOutput(global_step=4001, training_loss=8.343380785802548e-08, metrics={'train_runtime': 169.0541, 'train_samples_per_second': 378.577, 'train_steps_per_second': 5.915, 'total_flos': 7.363747084345344e+19, 'train_loss': 8.343380785802548e-08, 'epoch': 4.0})"
1261
+ ]
1262
+ },
1263
+ "execution_count": 30,
1264
+ "metadata": {},
1265
+ "output_type": "execute_result"
1266
+ }
1267
+ ],
1268
+ "source": [
1269
+ "trainer.train(\"checkpoint-4000\")"
1270
+ ]
1271
  }
1272
  ],
1273
  "metadata": {
1274
  "kernelspec": {
1275
+ "display_name": "wspsr",
1276
  "language": "python",
1277
+ "name": "wspsr"
1278
  },
1279
  "language_info": {
1280
  "codemirror_mode": {
fine-tune-whisper-streaming.ipynb CHANGED
@@ -108,7 +108,7 @@
108
  },
109
  {
110
  "cell_type": "code",
111
- "execution_count": 1,
112
  "id": "065a8cf7-e54f-4ac3-900e-609c80714fca",
113
  "metadata": {},
114
  "outputs": [],
@@ -142,7 +142,7 @@
142
  },
143
  {
144
  "cell_type": "code",
145
- "execution_count": 3,
146
  "id": "a2787582-554f-44ce-9f38-4180a5ed6b44",
147
  "metadata": {},
148
  "outputs": [],
@@ -151,7 +151,7 @@
151
  "\n",
152
  "raw_datasets = IterableDatasetDict()\n",
153
  "\n",
154
- "raw_datasets[\"train\"] = load_streaming_dataset(\"mozilla-foundation/common_voice_11_0\", \"ja\", split=\"train\", use_auth_token=True) # set split=\"train+validation\" for low-resource\n",
155
  "raw_datasets[\"test\"] = load_streaming_dataset(\"mozilla-foundation/common_voice_11_0\", \"ja\", split=\"test\", use_auth_token=True)"
156
  ]
157
  },
@@ -185,109 +185,10 @@
185
  },
186
  {
187
  "cell_type": "code",
188
- "execution_count": 4,
189
  "id": "77d9f0c5-8607-4642-a8ac-c3ab2e223ea6",
190
  "metadata": {},
191
- "outputs": [
192
- {
193
- "data": {
194
- "application/vnd.jupyter.widget-view+json": {
195
- "model_id": "ab8ef1fb2f284e2abd43a1b1bde55882",
196
- "version_major": 2,
197
- "version_minor": 0
198
- },
199
- "text/plain": [
200
- "Downloading: 0%| | 0.00/185k [00:00<?, ?B/s]"
201
- ]
202
- },
203
- "metadata": {},
204
- "output_type": "display_data"
205
- },
206
- {
207
- "data": {
208
- "application/vnd.jupyter.widget-view+json": {
209
- "model_id": "e0c2142f48224f1582e6457dbb8e5276",
210
- "version_major": 2,
211
- "version_minor": 0
212
- },
213
- "text/plain": [
214
- "Downloading: 0%| | 0.00/829 [00:00<?, ?B/s]"
215
- ]
216
- },
217
- "metadata": {},
218
- "output_type": "display_data"
219
- },
220
- {
221
- "data": {
222
- "application/vnd.jupyter.widget-view+json": {
223
- "model_id": "55aa8ea93e924389b339aefec864805d",
224
- "version_major": 2,
225
- "version_minor": 0
226
- },
227
- "text/plain": [
228
- "Downloading: 0%| | 0.00/1.04M [00:00<?, ?B/s]"
229
- ]
230
- },
231
- "metadata": {},
232
- "output_type": "display_data"
233
- },
234
- {
235
- "data": {
236
- "application/vnd.jupyter.widget-view+json": {
237
- "model_id": "5cc4483a4d234f73914d26f285588949",
238
- "version_major": 2,
239
- "version_minor": 0
240
- },
241
- "text/plain": [
242
- "Downloading: 0%| | 0.00/494k [00:00<?, ?B/s]"
243
- ]
244
- },
245
- "metadata": {},
246
- "output_type": "display_data"
247
- },
248
- {
249
- "data": {
250
- "application/vnd.jupyter.widget-view+json": {
251
- "model_id": "806dfeffeb1a4d6ba3a042cadee13450",
252
- "version_major": 2,
253
- "version_minor": 0
254
- },
255
- "text/plain": [
256
- "Downloading: 0%| | 0.00/52.7k [00:00<?, ?B/s]"
257
- ]
258
- },
259
- "metadata": {},
260
- "output_type": "display_data"
261
- },
262
- {
263
- "data": {
264
- "application/vnd.jupyter.widget-view+json": {
265
- "model_id": "b93cdf2091424615927adaefb032132f",
266
- "version_major": 2,
267
- "version_minor": 0
268
- },
269
- "text/plain": [
270
- "Downloading: 0%| | 0.00/2.11k [00:00<?, ?B/s]"
271
- ]
272
- },
273
- "metadata": {},
274
- "output_type": "display_data"
275
- },
276
- {
277
- "data": {
278
- "application/vnd.jupyter.widget-view+json": {
279
- "model_id": "cdb5621656934de2a60214f67530212c",
280
- "version_major": 2,
281
- "version_minor": 0
282
- },
283
- "text/plain": [
284
- "Downloading: 0%| | 0.00/2.06k [00:00<?, ?B/s]"
285
- ]
286
- },
287
- "metadata": {},
288
- "output_type": "display_data"
289
- }
290
- ],
291
  "source": [
292
  "from transformers import WhisperProcessor\n",
293
  "\n",
@@ -312,7 +213,7 @@
312
  },
313
  {
314
  "cell_type": "code",
315
- "execution_count": 5,
316
  "id": "ab5a13b4-9bd4-4aa0-aef2-b3de9b762988",
317
  "metadata": {},
318
  "outputs": [
@@ -332,7 +233,7 @@
332
  " 'segment': Value(dtype='string', id=None)}"
333
  ]
334
  },
335
- "execution_count": 5,
336
  "metadata": {},
337
  "output_type": "execute_result"
338
  }
@@ -358,7 +259,7 @@
358
  },
359
  {
360
  "cell_type": "code",
361
- "execution_count": 6,
362
  "id": "3ab6a724-3d1e-478b-a9e9-d2f85feb6c39",
363
  "metadata": {},
364
  "outputs": [],
@@ -378,7 +279,7 @@
378
  },
379
  {
380
  "cell_type": "code",
381
- "execution_count": 7,
382
  "id": "d041650e-1c48-4439-87b3-5b6f4a514107",
383
  "metadata": {},
384
  "outputs": [],
@@ -405,7 +306,7 @@
405
  },
406
  {
407
  "cell_type": "code",
408
- "execution_count": 8,
409
  "id": "c085911c-a10a-41ef-8874-306e0503e9bb",
410
  "metadata": {},
411
  "outputs": [],
@@ -441,7 +342,7 @@
441
  },
442
  {
443
  "cell_type": "code",
444
- "execution_count": 9,
445
  "id": "a37a7cdb-9013-427f-8de9-6a8d0e9dc684",
446
  "metadata": {},
447
  "outputs": [],
@@ -459,7 +360,7 @@
459
  },
460
  {
461
  "cell_type": "code",
462
- "execution_count": 10,
463
  "id": "1b145699-acfc-4b1d-93a2-a2ad3d62674c",
464
  "metadata": {},
465
  "outputs": [],
@@ -480,7 +381,7 @@
480
  },
481
  {
482
  "cell_type": "code",
483
- "execution_count": 11,
484
  "id": "01cb25ef-4bb0-4325-9461-f59198acadf6",
485
  "metadata": {},
486
  "outputs": [],
@@ -501,7 +402,7 @@
501
  },
502
  {
503
  "cell_type": "code",
504
- "execution_count": 12,
505
  "id": "333f7f6e-6053-4d3b-8924-c733c79b82ac",
506
  "metadata": {},
507
  "outputs": [],
@@ -571,7 +472,7 @@
571
  },
572
  {
573
  "cell_type": "code",
574
- "execution_count": 13,
575
  "id": "8326221e-ec13-4731-bb4e-51e5fc1486c5",
576
  "metadata": {},
577
  "outputs": [],
@@ -619,7 +520,7 @@
619
  },
620
  {
621
  "cell_type": "code",
622
- "execution_count": 14,
623
  "id": "fc834702-c0d3-4a96-b101-7b87be32bf42",
624
  "metadata": {},
625
  "outputs": [],
@@ -646,14 +547,14 @@
646
  },
647
  {
648
  "cell_type": "code",
649
- "execution_count": 15,
650
  "id": "b22b4011-f31f-4b57-b684-c52332f92890",
651
  "metadata": {},
652
  "outputs": [
653
  {
654
  "data": {
655
  "application/vnd.jupyter.widget-view+json": {
656
- "model_id": "737faa61d325424ba4b395c4aeb9a58f",
657
  "version_major": 2,
658
  "version_minor": 0
659
  },
@@ -690,7 +591,7 @@
690
  },
691
  {
692
  "cell_type": "code",
693
- "execution_count": 16,
694
  "id": "a11d1bfc-9e28-460f-a287-72d8f7bc1acb",
695
  "metadata": {},
696
  "outputs": [],
@@ -740,14 +641,14 @@
740
  },
741
  {
742
  "cell_type": "code",
743
- "execution_count": 17,
744
  "id": "5a10cc4b-07ec-4ebd-ac1d-7c601023594f",
745
  "metadata": {},
746
  "outputs": [
747
  {
748
  "data": {
749
  "application/vnd.jupyter.widget-view+json": {
750
- "model_id": "a2e3bfa2e47241f193b2e9be28e184b4",
751
  "version_major": 2,
752
  "version_minor": 0
753
  },
@@ -761,7 +662,7 @@
761
  {
762
  "data": {
763
  "application/vnd.jupyter.widget-view+json": {
764
- "model_id": "22bccf0b5b46459ebfbdb17f4641b800",
765
  "version_major": 2,
766
  "version_minor": 0
767
  },
@@ -789,7 +690,7 @@
789
  },
790
  {
791
  "cell_type": "code",
792
- "execution_count": 18,
793
  "id": "62038ba3-88ed-4fce-84db-338f50dcd04f",
794
  "metadata": {},
795
  "outputs": [],
@@ -817,7 +718,7 @@
817
  },
818
  {
819
  "cell_type": "code",
820
- "execution_count": 19,
821
  "id": "0ae3e9af-97b7-4aa0-ae85-20b23b5bcb3a",
822
  "metadata": {},
823
  "outputs": [],
@@ -829,16 +730,16 @@
829
  " per_device_train_batch_size=64,\n",
830
  " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
831
  " learning_rate=1e-5,\n",
832
- " warmup_steps=500,\n",
833
- " max_steps=5000,\n",
834
  " gradient_checkpointing=True,\n",
835
  " fp16=True,\n",
836
  " evaluation_strategy=\"steps\",\n",
837
  " per_device_eval_batch_size=8,\n",
838
  " predict_with_generate=True,\n",
839
  " generation_max_length=225,\n",
840
- " save_steps=1000,\n",
841
- " eval_steps=1000,\n",
842
  " logging_steps=25,\n",
843
  " report_to=[\"tensorboard\"],\n",
844
  " load_best_model_at_end=True,\n",
@@ -867,7 +768,7 @@
867
  },
868
  {
869
  "cell_type": "code",
870
- "execution_count": 20,
871
  "id": "3ac16b62-b3c0-4c68-8f3d-9ecf471534b2",
872
  "metadata": {},
873
  "outputs": [],
@@ -896,7 +797,7 @@
896
  },
897
  {
898
  "cell_type": "code",
899
- "execution_count": 21,
900
  "id": "d546d7fe-0543-479a-b708-2ebabec19493",
901
  "metadata": {},
902
  "outputs": [
@@ -935,7 +836,7 @@
935
  },
936
  {
937
  "cell_type": "code",
938
- "execution_count": 22,
939
  "id": "a1ccb9ed-cbc8-4419-91c0-651e9424b672",
940
  "metadata": {},
941
  "outputs": [
@@ -992,14 +893,15 @@
992
  "/home/ubuntu/.venv/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
993
  " warnings.warn(\n",
994
  "***** Running training *****\n",
995
- " Num examples = 320000\n",
996
  " Num Epochs = 9223372036854775807\n",
997
  " Instantaneous batch size per device = 64\n",
998
  " Total train batch size (w. parallel, distributed & accumulation) = 64\n",
999
  " Gradient Accumulation steps = 1\n",
1000
- " Total optimization steps = 5000\n",
1001
  " Number of trainable parameters = 241734912\n",
1002
- "Reading metadata...: 6505it [00:00, 21991.51it/s]\n",
 
1003
  "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
1004
  ]
1005
  },
@@ -1009,8 +911,8 @@
1009
  "\n",
1010
  " <div>\n",
1011
  " \n",
1012
- " <progress value='2231' max='5000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1013
- " [2231/5000 6:49:21 < 8:28:31, 0.09 it/s, Epoch 21.02/9223372036854775807]\n",
1014
  " </div>\n",
1015
  " <table border=\"1\" class=\"dataframe\">\n",
1016
  " <thead>\n",
@@ -1018,22 +920,9 @@
1018
  " <th>Step</th>\n",
1019
  " <th>Training Loss</th>\n",
1020
  " <th>Validation Loss</th>\n",
1021
- " <th>Wer</th>\n",
1022
  " </tr>\n",
1023
  " </thead>\n",
1024
  " <tbody>\n",
1025
- " <tr>\n",
1026
- " <td>1000</td>\n",
1027
- " <td>0.006600</td>\n",
1028
- " <td>0.468024</td>\n",
1029
- " <td>90.537665</td>\n",
1030
- " </tr>\n",
1031
- " <tr>\n",
1032
- " <td>2000</td>\n",
1033
- " <td>0.003000</td>\n",
1034
- " <td>0.512834</td>\n",
1035
- " <td>89.360193</td>\n",
1036
- " </tr>\n",
1037
  " </tbody>\n",
1038
  "</table><p>"
1039
  ],
@@ -1048,11 +937,13 @@
1048
  "name": "stderr",
1049
  "output_type": "stream",
1050
  "text": [
1051
- "Reading metadata...: 6505it [00:00, 24574.42it/s]\n",
1052
- "Reading metadata...: 6505it [00:00, 24420.15it/s]\n",
1053
- "Reading metadata...: 6505it [00:00, 36254.85it/s]\n",
1054
- "Reading metadata...: 6505it [00:00, 30794.80it/s]\n",
1055
- "Reading metadata...: 6505it [00:00, 27712.44it/s]\n"
 
 
1056
  ]
1057
  }
1058
  ],
@@ -1381,9 +1272,9 @@
1381
  ],
1382
  "metadata": {
1383
  "kernelspec": {
1384
- "display_name": "hf",
1385
  "language": "python",
1386
- "name": "hf"
1387
  },
1388
  "language_info": {
1389
  "codemirror_mode": {
 
108
  },
109
  {
110
  "cell_type": "code",
111
+ "execution_count": 5,
112
  "id": "065a8cf7-e54f-4ac3-900e-609c80714fca",
113
  "metadata": {},
114
  "outputs": [],
 
142
  },
143
  {
144
  "cell_type": "code",
145
+ "execution_count": 6,
146
  "id": "a2787582-554f-44ce-9f38-4180a5ed6b44",
147
  "metadata": {},
148
  "outputs": [],
 
151
  "\n",
152
  "raw_datasets = IterableDatasetDict()\n",
153
  "\n",
154
+ "raw_datasets[\"train\"] = load_streaming_dataset(\"mozilla-foundation/common_voice_11_0\", \"ja\", split=\"train+validation\", use_auth_token=True) # set split=\"train+validation\" for low-resource\n",
155
  "raw_datasets[\"test\"] = load_streaming_dataset(\"mozilla-foundation/common_voice_11_0\", \"ja\", split=\"test\", use_auth_token=True)"
156
  ]
157
  },
 
185
  },
186
  {
187
  "cell_type": "code",
188
+ "execution_count": 7,
189
  "id": "77d9f0c5-8607-4642-a8ac-c3ab2e223ea6",
190
  "metadata": {},
191
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  "source": [
193
  "from transformers import WhisperProcessor\n",
194
  "\n",
 
213
  },
214
  {
215
  "cell_type": "code",
216
+ "execution_count": 8,
217
  "id": "ab5a13b4-9bd4-4aa0-aef2-b3de9b762988",
218
  "metadata": {},
219
  "outputs": [
 
233
  " 'segment': Value(dtype='string', id=None)}"
234
  ]
235
  },
236
+ "execution_count": 8,
237
  "metadata": {},
238
  "output_type": "execute_result"
239
  }
 
259
  },
260
  {
261
  "cell_type": "code",
262
+ "execution_count": 9,
263
  "id": "3ab6a724-3d1e-478b-a9e9-d2f85feb6c39",
264
  "metadata": {},
265
  "outputs": [],
 
279
  },
280
  {
281
  "cell_type": "code",
282
+ "execution_count": 10,
283
  "id": "d041650e-1c48-4439-87b3-5b6f4a514107",
284
  "metadata": {},
285
  "outputs": [],
 
306
  },
307
  {
308
  "cell_type": "code",
309
+ "execution_count": 11,
310
  "id": "c085911c-a10a-41ef-8874-306e0503e9bb",
311
  "metadata": {},
312
  "outputs": [],
 
342
  },
343
  {
344
  "cell_type": "code",
345
+ "execution_count": 12,
346
  "id": "a37a7cdb-9013-427f-8de9-6a8d0e9dc684",
347
  "metadata": {},
348
  "outputs": [],
 
360
  },
361
  {
362
  "cell_type": "code",
363
+ "execution_count": 13,
364
  "id": "1b145699-acfc-4b1d-93a2-a2ad3d62674c",
365
  "metadata": {},
366
  "outputs": [],
 
381
  },
382
  {
383
  "cell_type": "code",
384
+ "execution_count": 14,
385
  "id": "01cb25ef-4bb0-4325-9461-f59198acadf6",
386
  "metadata": {},
387
  "outputs": [],
 
402
  },
403
  {
404
  "cell_type": "code",
405
+ "execution_count": 15,
406
  "id": "333f7f6e-6053-4d3b-8924-c733c79b82ac",
407
  "metadata": {},
408
  "outputs": [],
 
472
  },
473
  {
474
  "cell_type": "code",
475
+ "execution_count": 16,
476
  "id": "8326221e-ec13-4731-bb4e-51e5fc1486c5",
477
  "metadata": {},
478
  "outputs": [],
 
520
  },
521
  {
522
  "cell_type": "code",
523
+ "execution_count": 17,
524
  "id": "fc834702-c0d3-4a96-b101-7b87be32bf42",
525
  "metadata": {},
526
  "outputs": [],
 
547
  },
548
  {
549
  "cell_type": "code",
550
+ "execution_count": 18,
551
  "id": "b22b4011-f31f-4b57-b684-c52332f92890",
552
  "metadata": {},
553
  "outputs": [
554
  {
555
  "data": {
556
  "application/vnd.jupyter.widget-view+json": {
557
+ "model_id": "bffdd7b1fed44295954d9eed41a9cfd5",
558
  "version_major": 2,
559
  "version_minor": 0
560
  },
 
591
  },
592
  {
593
  "cell_type": "code",
594
+ "execution_count": 19,
595
  "id": "a11d1bfc-9e28-460f-a287-72d8f7bc1acb",
596
  "metadata": {},
597
  "outputs": [],
 
641
  },
642
  {
643
  "cell_type": "code",
644
+ "execution_count": 20,
645
  "id": "5a10cc4b-07ec-4ebd-ac1d-7c601023594f",
646
  "metadata": {},
647
  "outputs": [
648
  {
649
  "data": {
650
  "application/vnd.jupyter.widget-view+json": {
651
+ "model_id": "48fee2fd3b2a4a67b3a35666fda4dfe9",
652
  "version_major": 2,
653
  "version_minor": 0
654
  },
 
662
  {
663
  "data": {
664
  "application/vnd.jupyter.widget-view+json": {
665
+ "model_id": "51cdba284e8f44318868fbd013970280",
666
  "version_major": 2,
667
  "version_minor": 0
668
  },
 
690
  },
691
  {
692
  "cell_type": "code",
693
+ "execution_count": 21,
694
  "id": "62038ba3-88ed-4fce-84db-338f50dcd04f",
695
  "metadata": {},
696
  "outputs": [],
 
718
  },
719
  {
720
  "cell_type": "code",
721
+ "execution_count": 22,
722
  "id": "0ae3e9af-97b7-4aa0-ae85-20b23b5bcb3a",
723
  "metadata": {},
724
  "outputs": [],
 
730
  " per_device_train_batch_size=64,\n",
731
  " gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size\n",
732
  " learning_rate=1e-5,\n",
733
+ " warmup_steps=200,\n",
734
+ " max_steps=1000,\n",
735
  " gradient_checkpointing=True,\n",
736
  " fp16=True,\n",
737
  " evaluation_strategy=\"steps\",\n",
738
  " per_device_eval_batch_size=8,\n",
739
  " predict_with_generate=True,\n",
740
  " generation_max_length=225,\n",
741
+ " save_steps=200,\n",
742
+ " eval_steps=200,\n",
743
  " logging_steps=25,\n",
744
  " report_to=[\"tensorboard\"],\n",
745
  " load_best_model_at_end=True,\n",
 
768
  },
769
  {
770
  "cell_type": "code",
771
+ "execution_count": 23,
772
  "id": "3ac16b62-b3c0-4c68-8f3d-9ecf471534b2",
773
  "metadata": {},
774
  "outputs": [],
 
797
  },
798
  {
799
  "cell_type": "code",
800
+ "execution_count": 24,
801
  "id": "d546d7fe-0543-479a-b708-2ebabec19493",
802
  "metadata": {},
803
  "outputs": [
 
836
  },
837
  {
838
  "cell_type": "code",
839
+ "execution_count": 25,
840
  "id": "a1ccb9ed-cbc8-4419-91c0-651e9424b672",
841
  "metadata": {},
842
  "outputs": [
 
893
  "/home/ubuntu/.venv/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
894
  " warnings.warn(\n",
895
  "***** Running training *****\n",
896
+ " Num examples = 64000\n",
897
  " Num Epochs = 9223372036854775807\n",
898
  " Instantaneous batch size per device = 64\n",
899
  " Total train batch size (w. parallel, distributed & accumulation) = 64\n",
900
  " Gradient Accumulation steps = 1\n",
901
+ " Total optimization steps = 1000\n",
902
  " Number of trainable parameters = 241734912\n",
903
+ "Reading metadata...: 6505it [00:00, 31331.40it/s]\n",
904
+ "Reading metadata...: 4485it [00:00, 41376.86it/s]\n",
905
  "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
906
  ]
907
  },
 
911
  "\n",
912
  " <div>\n",
913
  " \n",
914
+ " <progress value='201' max='1000' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
915
+ " [ 201/1000 22:31 < 1:30:27, 0.15 it/s, Epoch 1.06/9223372036854775807]\n",
916
  " </div>\n",
917
  " <table border=\"1\" class=\"dataframe\">\n",
918
  " <thead>\n",
 
920
  " <th>Step</th>\n",
921
  " <th>Training Loss</th>\n",
922
  " <th>Validation Loss</th>\n",
 
923
  " </tr>\n",
924
  " </thead>\n",
925
  " <tbody>\n",
 
 
 
 
 
 
 
 
 
 
 
 
926
  " </tbody>\n",
927
  "</table><p>"
928
  ],
 
937
  "name": "stderr",
938
  "output_type": "stream",
939
  "text": [
940
+ "Reading metadata...: 6505it [00:00, 64162.65it/s]\n",
941
+ "Reading metadata...: 4485it [00:00, 27834.06it/s]\n",
942
+ "***** Running Evaluation *****\n",
943
+ " Num examples: Unknown\n",
944
+ " Batch size = 8\n",
945
+ "Reading metadata...: 4604it [00:00, 27155.92it/s]\n",
946
+ "The following columns in the evaluation set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: input_length. If input_length are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n"
947
  ]
948
  }
949
  ],
 
1272
  ],
1273
  "metadata": {
1274
  "kernelspec": {
1275
+ "display_name": "wspsr",
1276
  "language": "python",
1277
+ "name": "wspsr"
1278
  },
1279
  "language_info": {
1280
  "codemirror_mode": {
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b3246529f086b22124c7901ea81e50c3e83cfe22009b2ee44ddc94f5bea88d86
3
  size 967102601
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56c4b0bb4897d70e1953cf26927fc51e19cecc3225658657daa32a0c0d1e1cb0
3
  size 967102601
runs/Dec12_04-37-47_150-136-44-233/1670819878.783822/events.out.tfevents.1670819878.150-136-44-233.69039.1 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d13318210207986e1d4965c6206a303c2bcd72da40d33ba3b859c8e3111cf764
3
+ size 5864
runs/Dec12_04-37-47_150-136-44-233/events.out.tfevents.1670819878.150-136-44-233.69039.0 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da946657c9377166580c41662af45f086a478b101a9862b40eb5174e55e6f75a
3
+ size 5844
training_args.bin CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b15a0138008e0c490133c10ef48941adc9502d5f778b86dcc1d39f32d25062dc
3
  size 3579
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:728d6cd7b154a86029fc38c737217977eb35dd910ed073d6628129742d876d7e
3
  size 3579