"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import IPython.display as ipd\n",
"import numpy as np\n",
"import random\n",
"\n",
"rand_int = random.randint(0, len(all)-1)\n",
"\n",
"print(all[rand_int][\"sentence\"])\n",
"ipd.Audio(data=all[rand_int][\"audio\"][\"array\"], autoplay=True, rate=16000)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "5607b522",
"metadata": {
"id": "eJY7I0XAwe9p"
},
"outputs": [],
"source": [
"def prepare_dataset(batch):\n",
" audio = batch[\"audio\"]\n",
"\n",
" # batched output is \"un-batched\"\n",
" batch[\"input_values\"] = processor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_values[0]\n",
" batch[\"input_length\"] = len(batch[\"input_values\"])\n",
" \n",
" with processor.as_target_processor():\n",
" batch[\"labels\"] = processor(batch[\"sentence\"]).input_ids\n",
" return batch"
]
},
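  {
   "cell_type": "markdown",
   "id": "f3a1c0de",
   "metadata": {},
   "source": [
    "A quick sanity check (a hedged sketch, not part of the original run): apply `prepare_dataset` to a single raw example and inspect what it produces. This assumes `all` still holds the raw examples, i.e. it must run before the `map` call below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3a1c0df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# hedged sanity check: process one raw example and inspect the fields\n",
    "example = prepare_dataset(all[0])\n",
    "print(len(example[\"input_values\"]))  # number of 16 kHz audio samples\n",
    "print(example[\"input_length\"])       # same value, kept for length grouping\n",
    "print(example[\"labels\"][:20])        # first label ids of the transcription"
   ]
  },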
{
"cell_type": "code",
"execution_count": 23,
"id": "00e34422",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 81,
"referenced_widgets": [
"c47ea368dd08403aa09b2bafdbb4b580",
"e77cf973d5824ae7b89bafd814805c2a",
"071b7647e1fe49609a48e4281a9efd0f",
"c97c00fcf2e64f18b637337f9244d748",
"9ca82fa27d1043e9ac9f10301e0b33bc",
"cc6c7e9931c140db8ba7a977c4461ce5",
"d207784bda7e4dd8858170f470ae2833",
"0800fef7de6e45d380873f974882d67e",
"926440595aa44c698588e02b86eb8c4c",
"ea2806c776384f1a90e36b72c2c17a44",
"6b72385c07134782995fcd76e675da7c",
"3653b92c9f2a408eac253e1d5153daf4",
"73ffd9b8166c4ec78ff2b62d17690327",
"6b133a1e11e44f68846ff931446559cf",
"7c98818547c84af7ba9284bc20101691",
"41b501a16b2a4f709197af5cdd5227cb",
"3b4fbe2916894e48b8f93ca63e203aca",
"c002386685c0413d8181b054d3f9d49f",
"cfb70829b5e1461abcb01872b74a194c",
"ed943db2b5274022a606ce4103d54425",
"cfb242eb549c4e66afcedefb575b4e38",
"a0313055d29f4a60837e59ac4d8a3870"
]
},
"id": "-np9xYK-wl8q",
"outputId": "573f6f67-e5b2-4977-a564-3919e7903592"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "45634e805c76453a94ffac3f287df33f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"0ex [00:00, ?ex/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading cached processed dataset at /workspace/.cache/huggingface/datasets/mozilla-foundation___common_voice/cs/8.0.0/b8bc4d453193c06a43269b46cd87f075c70f152ac963b7f28f7a2760c45ec3e8/cache-fcc378c48562cf8c.arrow\n"
]
}
],
"source": [
"all = all.map(prepare_dataset, remove_columns=all.column_names)\n",
"common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "4207ffd0",
"metadata": {
"id": "tborvC9hx88e"
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"from dataclasses import dataclass, field\n",
"from typing import Any, Dict, List, Optional, Union\n",
"\n",
"@dataclass\n",
"class DataCollatorCTCWithPadding:\n",
" \"\"\"\n",
" Data collator that will dynamically pad the inputs received.\n",
" Args:\n",
" processor (:class:`~transformers.Wav2Vec2Processor`)\n",
" The processor used for proccessing the data.\n",
" padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):\n",
" Select a strategy to pad the returned sequences (according to the model's padding side and padding index)\n",
" among:\n",
" * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single\n",
" sequence if provided).\n",
" * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the\n",
" maximum acceptable input length for the model if that argument is not provided.\n",
" * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of\n",
" different lengths).\n",
" \"\"\"\n",
"\n",
" processor: Wav2Vec2Processor\n",
" padding: Union[bool, str] = True\n",
"\n",
" def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
" # split inputs and labels since they have to be of different lenghts and need\n",
" # different padding methods\n",
" input_features = [{\"input_values\": feature[\"input_values\"]} for feature in features]\n",
" label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
"\n",
" batch = self.processor.pad(\n",
" input_features,\n",
" padding=self.padding,\n",
" return_tensors=\"pt\",\n",
" )\n",
" with self.processor.as_target_processor():\n",
" labels_batch = self.processor.pad(\n",
" label_features,\n",
" padding=self.padding,\n",
" return_tensors=\"pt\",\n",
" )\n",
"\n",
" # replace padding with -100 to ignore loss correctly\n",
" labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
"\n",
" batch[\"labels\"] = labels\n",
"\n",
" return batch"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "2b346d91",
"metadata": {
"id": "lbQf5GuZyQ4_"
},
"outputs": [],
"source": [
"data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)"
]
},
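  {
   "cell_type": "markdown",
   "id": "0d1e2f3a",
   "metadata": {},
   "source": [
    "An optional hedged demo (assumes the datasets above have already been mapped with `prepare_dataset`): collate two processed examples of different lengths to see the dynamic padding and the `-100` label mask in action."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d1e2f3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# hedged demo: inputs are padded to the longest example in the mini-batch,\n",
    "# and padded label positions are set to -100 so the CTC loss ignores them\n",
    "sample_batch = data_collator([all[0], all[1]])\n",
    "print(sample_batch[\"input_values\"].shape)\n",
    "print(sample_batch[\"labels\"])"
   ]
  },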
{
"cell_type": "code",
"execution_count": 26,
"id": "3afc8d2a",
"metadata": {
"id": "9Xsux2gmyXso"
},
"outputs": [],
"source": [
"from datasets import load_metric\n",
"\n",
"wer_metric = load_metric(\"wer\")\n",
"cer_metric = load_metric(\"cer\")\n",
"metrics = [wer_metric, cer_metric]"
]
},
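  {
   "cell_type": "markdown",
   "id": "7b8c9dab",
   "metadata": {},
   "source": [
    "A toy illustration (made-up strings, not from the dataset): WER counts whole-word edits while CER counts character edits, so a one-character typo costs a full word in WER but only one of fifteen characters in CER."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b8c9dac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# one wrong word out of three -> WER = 1/3; one wrong character out of 15 -> CER = 1/15\n",
    "print(wer_metric.compute(predictions=[\"dobrý den svete\"], references=[\"dobrý den světe\"]))\n",
    "print(cer_metric.compute(predictions=[\"dobrý den svete\"], references=[\"dobrý den světe\"]))"
   ]
  },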
{
"cell_type": "code",
"execution_count": 27,
"id": "9119abc6",
"metadata": {
"id": "1XZ-kjweyTy_"
},
"outputs": [],
"source": [
"def compute_metrics(pred):\n",
" pred_logits = pred.predictions\n",
" pred_ids = np.argmax(pred_logits, axis=-1)\n",
"\n",
" pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id\n",
"\n",
" pred_str = processor.batch_decode(pred_ids)\n",
" # we do not want to group tokens when computing the metrics\n",
" label_str = processor.batch_decode(pred.label_ids, group_tokens=False)\n",
"\n",
" wer = wer_metric.compute(predictions=pred_str, references=label_str)\n",
" cer = cer_metric.compute(predictions=pred_str, references=label_str)\n",
"\n",
" return {\"wer\": wer, \"cer\": cer}"
]
},
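  {
   "cell_type": "markdown",
   "id": "c4d5e6f0",
   "metadata": {},
   "source": [
    "A hedged dry run of `compute_metrics` (synthetic inputs, for illustration only): random logits stand in for model predictions and a trailing `-100` stands in for label padding. Random logits should score close to 1.0 on both metrics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d5e6f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.trainer_utils import EvalPrediction\n",
    "\n",
    "# encode a toy transcription the same way prepare_dataset does\n",
    "with processor.as_target_processor():\n",
    "    toy_labels = processor(\"ahoj\").input_ids\n",
    "\n",
    "fake_pred = EvalPrediction(\n",
    "    predictions=np.random.randn(1, 50, len(processor.tokenizer)).astype(np.float32),\n",
    "    label_ids=np.array([toy_labels + [-100]]),  # -100 mimics label padding\n",
    ")\n",
    "print(compute_metrics(fake_pred))"
   ]
  },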
{
"cell_type": "code",
"execution_count": 28,
"id": "172587ca",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "e7cqAWIayn6w",
"outputId": "7a7ef020-bc8f-41e2-846c-645be598312e"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at facebook/wav2vec2-xls-r-300m were not used when initializing Wav2Vec2ForCTC: ['project_hid.bias', 'project_q.weight', 'quantizer.weight_proj.bias', 'project_hid.weight', 'quantizer.weight_proj.weight', 'project_q.bias', 'quantizer.codevectors']\n",
"- This IS expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing Wav2Vec2ForCTC from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-xls-r-300m and are newly initialized: ['lm_head.weight', 'lm_head.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import Wav2Vec2ForCTC\n",
"\n",
"model = Wav2Vec2ForCTC.from_pretrained(\n",
" #\"comodoro/wav2vec2-xls-r-300m-cs-cv8\", \n",
" \"facebook/wav2vec2-xls-r-300m\", \n",
" attention_dropout=0.1,\n",
" hidden_dropout=0.2,\n",
" feat_proj_dropout=0.0,\n",
" mask_time_prob=0.1,\n",
" layerdrop=0.1,\n",
" ctc_loss_reduction=\"mean\", \n",
" pad_token_id=processor.tokenizer.pad_token_id,\n",
" vocab_size=len(processor.tokenizer),\n",
")"
]
},
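  {
   "cell_type": "markdown",
   "id": "a9b0c1d2",
   "metadata": {},
   "source": [
    "A small hedged check: the warning above says the CTC head (`lm_head`) was freshly initialised, so its output dimension should match the tokenizer vocabulary passed via `vocab_size`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9b0c1d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# total parameter count and CTC-head / vocabulary consistency check\n",
    "print(f\"{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters\")\n",
    "print(model.lm_head.out_features == len(processor.tokenizer))"
   ]
  },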
{
"cell_type": "code",
"execution_count": 29,
"id": "94625b79",
"metadata": {
"id": "oGI8zObtZ3V0"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/workspace/.local/lib/python3.8/site-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:1700: FutureWarning: The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5.Please use the equivalent `freeze_feature_encoder` method instead.\n",
" warnings.warn(\n"
]
}
],
"source": [
"model.freeze_feature_extractor()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "9c173ad4",
"metadata": {
"id": "KbeKSV7uzGPP"
},
"outputs": [],
"source": [
"from transformers import TrainingArguments\n",
"\n",
"training_args = TrainingArguments(\n",
" output_dir=repo_name,\n",
" group_by_length=True,\n",
" per_device_train_batch_size=32,\n",
" gradient_accumulation_steps=1,\n",
" eval_accumulation_steps=1,\n",
" evaluation_strategy=\"steps\",\n",
" num_train_epochs=5,\n",
" gradient_checkpointing=True,\n",
" fp16=True,\n",
" save_steps=800,\n",
" eval_steps=800,\n",
" logging_steps=250,\n",
" learning_rate=1e-4,\n",
" warmup_steps=800,\n",
" save_total_limit=2,\n",
" report_to=\"tensorboard\"\n",
")"
]
},
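  {
   "cell_type": "markdown",
   "id": "e1f2a3b4",
   "metadata": {},
   "source": [
    "Hedged arithmetic for the schedule above: with 159,605 training examples (see the training log below), a per-device batch size of 32, a single GPU and no gradient accumulation, one epoch is ceil(159605 / 32) = 4988 optimizer steps, so 5 epochs give 24,940 steps in total."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1f2a3b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# expected total optimization steps (single GPU, no gradient accumulation)\n",
    "steps_per_epoch = math.ceil(len(all) / training_args.per_device_train_batch_size)\n",
    "print(steps_per_epoch * int(training_args.num_train_epochs))  # 24940"
   ]
  },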
{
"cell_type": "code",
"execution_count": 31,
"id": "38cc611b",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rY7vBmFCPFgC",
"outputId": "a180bf3f-f798-4947-ff58-207d7aaab695"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using amp half precision backend\n"
]
}
],
"source": [
"from transformers import Trainer\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" data_collator=data_collator,\n",
" args=training_args,\n",
" compute_metrics=compute_metrics,\n",
" train_dataset=all,\n",
" eval_dataset=common_voice_test,\n",
" tokenizer=processor.feature_extractor,\n",
")"
]
},
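  {
   "cell_type": "markdown",
   "id": "b5c6d7e8",
   "metadata": {},
   "source": [
    "An optional hedged smoke test before committing to the full run: pull a single collated batch through the trainer's sampler and collator and check the tensor shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5c6d7e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# fetch one batch from the training dataloader and inspect its shapes\n",
    "first_batch = next(iter(trainer.get_train_dataloader()))\n",
    "print({k: tuple(v.shape) for k, v in first_batch.items()})"
   ]
  },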
{
"cell_type": "code",
"execution_count": 32,
"id": "ab7b22fa",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 312
},
"id": "9fRr9TG5pGBl",
"outputId": "8bdf1d11-bca1-46af-db67-518f85586f7a"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The following columns in the training set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"/workspace/.local/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n",
"***** Running training *****\n",
" Num examples = 159605\n",
" Num Epochs = 5\n",
" Instantaneous batch size per device = 32\n",
" Total train batch size (w. parallel, distributed & accumulation) = 32\n",
" Gradient Accumulation steps = 1\n",
" Total optimization steps = 24940\n"
]
},
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
"
\n",
" [24940/24940 14:54:01, Epoch 5/5]\n",
"
\n",
" \n",
" \n",
" \n",
" Step | \n",
" Training Loss | \n",
" Validation Loss | \n",
" Wer | \n",
" Cer | \n",
"
\n",
" \n",
" \n",
" \n",
" 800 | \n",
" 3.420300 | \n",
" 3.314820 | \n",
" 1.000000 | \n",
" 1.000000 | \n",
"
\n",
" \n",
" 1600 | \n",
" 2.815100 | \n",
" 0.850840 | \n",
" 0.893788 | \n",
" 0.234479 | \n",
"
\n",
" \n",
" 2400 | \n",
" 0.941100 | \n",
" 0.333538 | \n",
" 0.372315 | \n",
" 0.084735 | \n",
"
\n",
" \n",
" 3200 | \n",
" 0.740800 | \n",
" 0.257277 | \n",
" 0.283963 | \n",
" 0.064233 | \n",
"
\n",
" \n",
" 4000 | \n",
" 0.651600 | \n",
" 0.236474 | \n",
" 0.258103 | \n",
" 0.059464 | \n",
"
\n",
" \n",
" 4800 | \n",
" 0.624200 | \n",
" 0.203933 | \n",
" 0.243332 | \n",
" 0.054062 | \n",
"
\n",
" \n",
" 5600 | \n",
" 0.575400 | \n",
" 0.183210 | \n",
" 0.215611 | \n",
" 0.048234 | \n",
"
\n",
" \n",
" 6400 | \n",
" 0.562600 | \n",
" 0.182699 | \n",
" 0.209116 | \n",
" 0.046281 | \n",
"
\n",
" \n",
" 7200 | \n",
" 0.534200 | \n",
" 0.174398 | \n",
" 0.203315 | \n",
" 0.046776 | \n",
"
\n",
" \n",
" 8000 | \n",
" 0.496500 | \n",
" 0.170528 | \n",
" 0.196285 | \n",
" 0.044429 | \n",
"
\n",
" \n",
" 8800 | \n",
" 0.504700 | \n",
" 0.160374 | \n",
" 0.188880 | \n",
" 0.042167 | \n",
"
\n",
" \n",
" 9600 | \n",
" 0.481400 | \n",
" 0.160427 | \n",
" 0.182742 | \n",
" 0.041052 | \n",
"
\n",
" \n",
" 10400 | \n",
" 0.447100 | \n",
" 0.156585 | \n",
" 0.182207 | \n",
" 0.040592 | \n",
"
\n",
" \n",
" 11200 | \n",
" 0.450900 | \n",
" 0.161888 | \n",
" 0.185296 | \n",
" 0.043243 | \n",
"
\n",
" \n",
" 12000 | \n",
" 0.441500 | \n",
" 0.151254 | \n",
" 0.176386 | \n",
" 0.039725 | \n",
"
\n",
" \n",
" 12800 | \n",
" 0.431300 | \n",
" 0.151478 | \n",
" 0.173930 | \n",
" 0.039213 | \n",
"
\n",
" \n",
" 13600 | \n",
" 0.416300 | \n",
" 0.144519 | \n",
" 0.169515 | \n",
" 0.037672 | \n",
"
\n",
" \n",
" 14400 | \n",
" 0.414200 | \n",
" 0.147759 | \n",
" 0.169871 | \n",
" 0.038473 | \n",
"
\n",
" \n",
" 15200 | \n",
" 0.418400 | \n",
" 0.143047 | \n",
" 0.166921 | \n",
" 0.037583 | \n",
"
\n",
" \n",
" 16000 | \n",
" 0.388600 | \n",
" 0.143273 | \n",
" 0.164426 | \n",
" 0.037388 | \n",
"
\n",
" \n",
" 16800 | \n",
" 0.379500 | \n",
" 0.142606 | \n",
" 0.164822 | \n",
" 0.037258 | \n",
"
\n",
" \n",
" 17600 | \n",
" 0.385900 | \n",
" 0.135660 | \n",
" 0.160446 | \n",
" 0.036143 | \n",
"
\n",
" \n",
" 18400 | \n",
" 0.376200 | \n",
" 0.134396 | \n",
" 0.155832 | \n",
" 0.034930 | \n",
"
\n",
" \n",
" 19200 | \n",
" 0.384000 | \n",
" 0.137933 | \n",
" 0.157595 | \n",
" 0.035875 | \n",
"
\n",
" \n",
" 20000 | \n",
" 0.376200 | \n",
" 0.134363 | \n",
" 0.153892 | \n",
" 0.034552 | \n",
"
\n",
" \n",
" 20800 | \n",
" 0.355900 | \n",
" 0.133945 | \n",
" 0.152526 | \n",
" 0.035080 | \n",
"
\n",
" \n",
" 21600 | \n",
" 0.368300 | \n",
" 0.131489 | \n",
" 0.151753 | \n",
" 0.034226 | \n",
"
\n",
" \n",
" 22400 | \n",
" 0.357200 | \n",
" 0.130721 | \n",
" 0.150664 | \n",
" 0.034154 | \n",
"
\n",
" \n",
" 23200 | \n",
" 0.349400 | \n",
" 0.129449 | \n",
" 0.149100 | \n",
" 0.033476 | \n",
"
\n",
" \n",
" 24000 | \n",
" 0.347600 | \n",
" 0.128731 | \n",
" 0.149120 | \n",
" 0.033597 | \n",
"
\n",
" \n",
" 24800 | \n",
" 0.347500 | \n",
" 0.127098 | \n",
" 0.147457 | \n",
" 0.032896 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-800\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-800/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-800/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-800/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-30400] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-1600\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-1600/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-1600/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-1600/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-31200] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-2400\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-2400/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-2400/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-2400/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-800] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-3200\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-3200/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-3200/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-3200/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-1600] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-4000\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-4000/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-4000/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-4000/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-2400] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-4800\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-4800/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-4800/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-4800/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-3200] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-5600\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-5600/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-5600/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-5600/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-4000] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-6400\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-6400/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-6400/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-6400/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-4800] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-7200\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-7200/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-7200/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-7200/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-5600] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-8000\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-8000/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-8000/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-8000/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-6400] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-8800\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-8800/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-8800/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-8800/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-7200] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-9600\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-9600/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-9600/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-9600/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-8000] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-10400\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-10400/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-10400/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-10400/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-8800] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-11200\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-11200/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-11200/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-11200/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-9600] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-12000\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-12000/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-12000/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-12000/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-10400] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-12800\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-12800/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-12800/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-12800/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-11200] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-13600\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-13600/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-13600/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-13600/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-12000] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-14400\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-14400/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-14400/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-14400/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-12800] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-15200\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-15200/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-15200/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-15200/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-13600] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-16000\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-16000/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-16000/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-16000/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-14400] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-16800\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-16800/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-16800/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-16800/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-15200] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-17600\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-17600/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-17600/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-17600/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-16000] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-18400\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-18400/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-18400/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-18400/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-16800] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-19200\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-19200/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-19200/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-19200/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-17600] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-20000\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-20000/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-20000/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-20000/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-18400] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-20800\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-20800/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-20800/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-20800/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-19200] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-21600\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-21600/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-21600/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-21600/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-20000] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-22400\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-22400/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-22400/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-22400/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-20800] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-23200\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-23200/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-23200/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-23200/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-21600] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-24000\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-24000/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-24000/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-24000/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-22400] due to args.save_total_limit\n",
"The following columns in the evaluation set don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length.\n",
"***** Running Evaluation *****\n",
" Num examples = 7267\n",
" Batch size = 8\n",
"Saving model checkpoint to wav2vec2-xls-r-300m-cs-250/checkpoint-24800\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-24800/config.json\n",
"Model weights saved in wav2vec2-xls-r-300m-cs-250/checkpoint-24800/pytorch_model.bin\n",
"Configuration saved in wav2vec2-xls-r-300m-cs-250/checkpoint-24800/preprocessor_config.json\n",
"Deleting older checkpoint [wav2vec2-xls-r-300m-cs-250/checkpoint-23200] due to args.save_total_limit\n",
"\n",
"\n",
"Training completed. Do not forget to share your model on huggingface.co/models =)\n",
"\n",
"\n"
]
},
{
"data": {
"text/plain": [
"TrainOutput(global_step=24940, training_loss=0.7262697845434511, metrics={'train_runtime': 53649.1292, 'train_samples_per_second': 14.875, 'train_steps_per_second': 0.465, 'total_flos': 1.1982083586402645e+20, 'train_loss': 0.7262697845434511, 'epoch': 5.0})"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.train()"
]
},
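  {
   "cell_type": "markdown",
   "id": "d0e1f2a3",
   "metadata": {},
   "source": [
    "A hedged post-training check: greedily decode one processed test example with the fine-tuned model. CTC greedy decoding is just an argmax over the vocabulary at each frame; `batch_decode` then collapses repeated tokens and removes blanks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0e1f2a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    input_values = torch.tensor(common_voice_test[0][\"input_values\"]).unsqueeze(0).to(model.device)\n",
    "    logits = model(input_values).logits\n",
    "\n",
    "# greedy CTC decoding: argmax per frame, then collapse repeats and blanks\n",
    "pred_ids = torch.argmax(logits, dim=-1)\n",
    "print(processor.batch_decode(pred_ids)[0])"
   ]
  },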
{
"cell_type": "code",
"execution_count": 33,
"id": "6610bd0a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Dropping the following result as it does not have all the necessary fields:\n",
"{}\n"
]
}
],
"source": [
"trainer.create_model_card()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "ec5a5334",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Configuration saved in ./config.json\n",
"Model weights saved in ./pytorch_model.bin\n"
]
}
],
"source": [
"model.save_pretrained('.')"
]
},
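  {
   "cell_type": "markdown",
   "id": "f4a5b6c7",
   "metadata": {},
   "source": [
    "`model.save_pretrained('.')` only writes `config.json` and the weights. To reload the checkpoint end-to-end, the processor (feature extractor + tokenizer) has to be saved as well; a hedged round-trip sketch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4a5b6c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# persist the processor next to the weights, then reload both as a round-trip check\n",
    "processor.save_pretrained('.')\n",
    "reloaded_model = Wav2Vec2ForCTC.from_pretrained('.')\n",
    "reloaded_processor = Wav2Vec2Processor.from_pretrained('.')"
   ]
  },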
{
"cell_type": "code",
"execution_count": null,
"id": "8d3d92c1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}