{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Greek wav2vec2-XLSR inference on Common Voice\n", "\n", "Loads a fine-tuned `Wav2Vec2ForCTC` checkpoint and its processor, preprocesses the Greek (`el`) Common Voice test split, and compares the greedy (argmax) CTC transcription of one clip with its reference sentence." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T18:07:15.328900Z", "start_time": "2021-03-14T18:07:15.326838Z" } }, "outputs": [], "source": [ "import re\n", "\n", "import librosa\n", "import numpy as np\n", "import torch\n", "import torchaudio\n", "from datasets import load_dataset, load_metric\n", "from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preprocessing helpers\n", "\n", "`datasets.map` functions used below to clean the transcripts, load and resample the audio, and encode the model inputs and labels." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T18:07:15.933957Z", "start_time": "2021-03-14T18:07:15.927789Z" } }, "outputs": [], "source": [ "chars_to_ignore_regex = r'[\\,\\?\\.\\!\\-\\;\\:\\\"\\“\\%\\‘\\”\\�]'\n", "\n", "# Sampling rate (Hz) the audio is resampled to before feature extraction.\n", "TARGET_SAMPLING_RATE = 16_000\n", "\n", "def remove_special_characters(batch):\n", "    \"\"\"Strip chars_to_ignore_regex matches from batch['sentence'], lowercase it and store it, plus a trailing space, as batch['text'].\"\"\"\n", "    batch[\"text\"] = re.sub(chars_to_ignore_regex, '', batch[\"sentence\"]).lower() + \" \"\n", "    return batch\n", "\n", "def speech_file_to_array_fn(batch):\n", "    \"\"\"Load batch['path'] with torchaudio, keep its first channel as batch['speech'] and record its sampling rate.\n", "\n", "    Also copies batch['text'] to batch['target_text'].\n", "    \"\"\"\n", "    speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n", "    batch[\"speech\"] = speech_array[0].numpy()\n", "    batch[\"sampling_rate\"] = sampling_rate\n", "    batch[\"target_text\"] = batch[\"text\"]\n", "    return batch\n", "\n", "def resample(batch):\n", "    \"\"\"Resample batch['speech'] from batch['sampling_rate'] to TARGET_SAMPLING_RATE.\"\"\"\n", "    # Use the rate torchaudio reported rather than assuming 48 kHz; keyword\n", "    # arguments are required by librosa >= 0.10 and accepted by older versions.\n", "    batch[\"speech\"] = librosa.resample(\n", "        np.asarray(batch[\"speech\"]),\n", "        orig_sr=batch[\"sampling_rate\"],\n", "        target_sr=TARGET_SAMPLING_RATE,\n", "    )\n", "    batch[\"sampling_rate\"] = TARGET_SAMPLING_RATE\n", "    return batch\n", "\n", "def prepare_dataset(batch):\n", "    \"\"\"Batched map: encode batch['speech'] as input_values and batch['target_text'] as label ids.\n", "\n", "    Relies on the global `processor`, which is created in a later cell.\n", "    \"\"\"\n", "    # check that all files have the sampling rate the feature extractor expects\n", "    expected_rate = processor.feature_extractor.sampling_rate\n", "    assert set(batch[\"sampling_rate\"]) == {expected_rate}, (\n", "        f\"Make sure all inputs have the same sampling rate of {expected_rate}.\"\n", "    )\n", "\n", "    batch[\"input_values\"] = processor(batch[\"speech\"], sampling_rate=batch[\"sampling_rate\"][0]).input_values\n", "\n", "    with processor.as_target_processor():\n", "        batch[\"labels\"] = processor(batch[\"target_text\"]).input_ids\n", "    return batch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the model, processor and test split" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {
"ExecuteTime": { "end_time": "2021-03-14T18:07:22.624226Z", "start_time": "2021-03-14T18:07:16.402381Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.\n" ] } ], "source": [ "# Run on the GPU when one is available, otherwise fall back to the CPU.\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Weights come from a training checkpoint; the processor is loaded from the parent output directory.\n", "model = Wav2Vec2ForCTC.from_pretrained(\"wav2vec2-large-xlsr-greek/checkpoint-18400/\").to(device)\n", "processor = Wav2Vec2Processor.from_pretrained(\"wav2vec2-large-xlsr-greek/\")" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T18:07:25.473609Z", "start_time": "2021-03-14T18:07:22.644765Z" } }, "outputs": [], "source": [ "# Local Common Voice 6.1 extract. Both load_dataset calls use this exact string so they\n", "# resolve to the same dataset configuration instead of creating a second one.\n", "CV_DATA_DIR = \"cv-corpus-6.1-2020-12-11\"\n", "\n", "common_voice_test = load_dataset(\"common_voice\", \"el\", data_dir=CV_DATA_DIR, split=\"test\")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T18:07:25.504511Z", "start_time": "2021-03-14T18:07:25.500688Z" } }, "outputs": [], "source": [ "common_voice_test = common_voice_test.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preprocess the test split" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T18:07:25.540666Z", "start_time": "2021-03-14T18:07:25.536214Z" } }, "outputs": [], "source": [ "common_voice_test = common_voice_test.map(remove_special_characters, remove_columns=[\"sentence\"])" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T18:07:25.578015Z", "start_time": "2021-03-14T18:07:25.568808Z" } }, "outputs": [], "source": [ "common_voice_test = common_voice_test.map(speech_file_to_array_fn, remove_columns=common_voice_test.column_names)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T18:07:26.404914Z", "start_time": "2021-03-14T18:07:25.605177Z" } }, "outputs": [], "source": [ "common_voice_test = common_voice_test.map(resample, num_proc=8)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T18:07:27.032511Z", "start_time": "2021-03-14T18:07:26.432613Z" } }, "outputs": [], "source": [ "common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=8, batched=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transcribe one test clip" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T18:07:29.428864Z", "start_time": "2021-03-14T18:07:27.056686Z" } }, "outputs": [], "source": [ "# Unprocessed copy of the split, used only to print the original reference sentence.\n", "common_voice_test_transcription = load_dataset(\"common_voice\", \"el\", data_dir=CV_DATA_DIR, split=\"test\")" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T18:07:54.722520Z", "start_time": "2021-03-14T18:07:29.451275Z" } }, "outputs": [], "source": [ "# Change this value to try inference on different CommonVoice extracts\n", "example = 678\n", "\n", "# Index the row before the column: common_voice_test[\"input_values\"][example] would first\n", "# materialise the input_values of every clip just to pick one.\n", "input_dict = processor(common_voice_test[example][\"input_values\"], return_tensors=\"pt\", sampling_rate=16_000, padding=True)\n", "\n", "# Inference only: skip building the autograd graph.\n", "with torch.no_grad():\n", "    logits = model(input_dict.input_values.to(device)).logits\n", "\n", "pred_ids = torch.argmax(logits, dim=-1)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "ExecuteTime": { "end_time": "2021-03-14T18:07:54.742988Z", "start_time": "2021-03-14T18:07:54.739626Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction:\n", "πού θέλεις να πάμε ρώτησε φοβισμένα ο βασιλιάς\n", "\n", "Reference:\n", "πού θέλεις να πάμε; ρώτησε φοβισμένα ο βασιλιάς.\n" ] } ], "source": [ "print(\"Prediction:\")\n",
"print(processor.decode(pred_ids[0]))\n", "# πού θέλεις να πάμε ρώτησε φοβισμένα ο βασιλιάς\n", "\n", "print(\"\\nReference:\")\n", "# Index the row first so only this example's sentence is read, not the whole column.\n", "# The reference keeps the punctuation that chars_to_ignore_regex removes from the labels,\n", "# so punctuation differences against the prediction are expected.\n", "print(common_voice_test_transcription[example][\"sentence\"].lower())\n", "# πού θέλεις να πάμε; ρώτησε φοβισμένα ο βασιλιάς." ] } ], "metadata": { "kernelspec": { "display_name": "cuda110", "language": "python", "name": "cuda110" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }