diff --git "a/run_evaluations_on_common_voice_test_7_0.ipynb" "b/run_evaluations_on_common_voice_test_7_0.ipynb" new file mode 100644--- /dev/null +++ "b/run_evaluations_on_common_voice_test_7_0.ipynb" @@ -0,0 +1,3217 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c8c824ea", + "metadata": {}, + "source": [ + "

This notebook is for testing models against common_voice (v7)

" + ] + }, + { + "cell_type": "markdown", + "id": "a9180c0b", + "metadata": { + "papermill": { + "duration": null, + "end_time": null, + "exception": null, + "start_time": null, + "status": "pending" + }, + "tags": [] + }, + "source": [ + "\n", + "\n", + "\n", + "##### TEST WITH RASMUS 1B model with language model added using our own common_voice v7 (processed before event) ###" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "0ea4e3d0", + "metadata": { + "papermill": { + "duration": null, + "end_time": null, + "exception": null, + "start_time": null, + "status": "pending" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# Environment settings: \n", + "import pandas as pd\n", + "pd.set_option('display.max_column', None)\n", + "pd.set_option('display.max_rows', None)\n", + "pd.set_option('display.max_seq_items', None)\n", + "pd.set_option('display.max_colwidth', 500)\n", + "pd.set_option('expand_frame_repr', True)\n", + "\n", + "from datasets import concatenate_datasets, load_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0a810556", + "metadata": { + "papermill": { + "duration": null, + "end_time": null, + "exception": null, + "start_time": null, + "status": "pending" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c56f4efc99cc4320a0ddee55f8c6dfce", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='
\\n\", \"…\", \"–\", \"°\", \"´\", \"ʾ\", \"‹\", \"›\", \"©\", \"®\", \"—\", \"→\", \"。\",\n", + " \"、\", \"﹂\", \"﹁\", \"‧\", \"~\", \"﹏\", \",\", \"{\", \"}\", \"(\", \")\", \"[\", \"]\", \"【\", \"】\", \"‥\", \"〽\",\n", + " \"『\", \"』\", \"〝\", \"〟\", \"⟨\", \"⟩\", \"〜\", \":\", \"!\", \"?\", \"♪\", \"؛\", \"/\", \"\\\\\", \"º\", \"−\", \"^\", \"ʻ\", \"ˆ\"]\n", + "\n", + "\n", + "chars_to_remove_regex = f\"[{re.escape(''.join(CHARS_TO_IGNORE))}]\"\n", + "\n", + "def remove_special_characters(batch):\n", + " batch[\"sentence\"] = re.sub(chars_to_remove_regex, '', batch[\"sentence\"]).lower()\n", + " return batch\n", + "\n", + "common_voice_test = common_voice_test.map(remove_special_characters)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "42423e76", + "metadata": { + "papermill": { + "duration": null, + "end_time": null, + "exception": null, + "start_time": null, + "status": "pending" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['client_id', 'path', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment', 'split', 'audio', 'dataset_name', 'filename', '__index_level_0__'],\n", + " num_rows: 1599\n", + "})" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "common_voice_test" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "21d0937f", + "metadata": { + "papermill": { + "duration": null, + "end_time": null, + "exception": null, + "start_time": null, + "status": "pending" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "common_voice_test_audio = common_voice_test.cast_column(\"audio\", Audio(sampling_rate=16_000))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "02f295cf", + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_dataset(batch):\n", + " audio = batch[\"audio\"]\n", + "\n", + " # batched output is \"un-batched\"\n", + " 
batch[\"input_values\"] = processor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_values[0]\n", + " batch[\"input_length\"] = len(batch[\"input_values\"])\n", + " batch[\"sentence\"] = batch[\"sentence\"]\n", + " \n", + " with processor.as_target_processor():\n", + " batch[\"labels\"] = processor(batch[\"sentence\"]).input_ids\n", + " return batch" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "1341388b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8798d1f90aa241129972f072d31f7686", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1599 [00:00 Test from mozilla common_voice v_7_0 directly from hub \n", + "Using currently the \"old\" preprocessing and not the \"audio\" method\", \"…\", \"–\", \"°\", \"´\", \"ʾ\", \"‹\", \"›\", \"©\", \"®\", \"—\", \"→\", \"。\",\n", + " \"、\", \"﹂\", \"﹁\", \"‧\", \"~\", \"﹏\", \",\", \"{\", \"}\", \"(\", \")\", \"[\", \"]\", \"【\", \"】\", \"‥\", \"〽\",\n", + " \"『\", \"』\", \"〝\", \"〟\", \"⟨\", \"⟩\", \"〜\", \":\", \"!\", \"?\", \"♪\", \"؛\", \"/\", \"\\\\\", \"º\", \"−\", \"^\", \"ʻ\", \"ˆ\"]\n", + "\n", + "\n", + "chars_to_remove_regex = f\"[{re.escape(''.join(CHARS_TO_IGNORE))}]\"\n", + "\n", + "def remove_special_characters(batch):\n", + " batch[\"sentence\"] = re.sub(chars_to_remove_regex, '', batch[\"sentence\"]).lower()\n", + " return batch\n", + "\n", + "common_voice_dataset = common_voice_dataset.map(remove_special_characters)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "396369b6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading cached processed dataset at /workspace/.cache/huggingface/datasets/mozilla-foundation___common_voice/fi/7.0.0/33e08856cfa0d0665e837bcad73ffd920a0bc713ce8c5fffb55dbdf1c084d5ba/cache-b66d07bf277a5504.arrow\n" + ] + } + ], + "source": [ + "def resample_audios(batch):\n", + " sr = 
batch['audio']['sampling_rate']\n", + " batch['audio']['array'] = F.resample(torch.tensor(batch[\"audio\"][\"array\"]), sr, 16_000).numpy()\n", + " return batch\n", + "\n", + "common_voice_dataset = common_voice_dataset.map(resample_audios)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "2852ca1c", + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_dataset(batch):\n", + " batch[\"input_values\"] = processor(batch[\"audio\"][\"array\"], sampling_rate=16000).input_values[0]\n", + " batch[\"input_length\"] = len(batch[\"input_values\"])\n", + " batch[\"sentence\"] = batch[\"sentence\"]\n", + " \n", + " with processor.as_target_processor():\n", + " batch[\"labels\"] = processor(batch[\"sentence\"]).input_ids\n", + " return batch" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "7832bd68", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "58ff04eeda134bec9ffdc7bcf71fc66c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1599 [00:00 ASR PIPELINE PREDICTIONS (Same kind as in eval.py)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "5302f579", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using the latest cached version of the module from /workspace/.cache/huggingface/modules/datasets_modules/datasets/mozilla-foundation--common_voice_7_0/33e08856cfa0d0665e837bcad73ffd920a0bc713ce8c5fffb55dbdf1c084d5ba (last modified on Sun Jan 23 16:17:44 2022) since it couldn't be found locally at mozilla-foundation/common_voice_7_0., or remotely on the Hugging Face Hub.\n", + "Reusing dataset common_voice (/workspace/.cache/huggingface/datasets/mozilla-foundation___common_voice/fi/7.0.0/33e08856cfa0d0665e837bcad73ffd920a0bc713ce8c5fffb55dbdf1c084d5ba)\n" + ] + } + ], + "source": [ + "common_voice_dataset = 
load_dataset(\"mozilla-foundation/common_voice_7_0\", \"fi\", split=\"test\")\n", + "\n", + "common_voice_dataset = common_voice_dataset.cast_column(\"audio\", Audio(sampling_rate=16_000))" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "6c048f8b", + "metadata": {}, + "outputs": [], + "source": [ + "def normalize_text(text: str) -> str:\n", + " \"\"\"DO ADAPT FOR YOUR USE CASE. this function normalizes the target text.\"\"\"\n", + "\n", + " chars_to_ignore_regex = [\",\", \"?\", \"¿\", \".\", \"!\", \"¡\", \";\", \";\", \":\", '\"\"', \"%\", '\"', \"�\", \"ʿ\", \"·\", \"჻\", \"~\", \"՞\",\n", + " \"؟\", \"،\", \"।\", \"॥\", \"«\", \"»\", \"„\", \"“\", \"”\", \"「\", \"」\", \"‘\", \"’\", \"《\", \"》\", \"(\", \")\", \"[\", \"]\",\n", + " \"{\", \"}\", \"=\", \"`\", \"_\", \"+\", \"<\", \">\", \"…\", \"–\", \"°\", \"´\", \"ʾ\", \"‹\", \"›\", \"©\", \"®\", \"—\", \"→\", \"。\",\n", + " \"、\", \"﹂\", \"﹁\", \"‧\", \"~\", \"﹏\", \",\", \"{\", \"}\", \"(\", \")\", \"[\", \"]\", \"【\", \"】\", \"‥\", \"〽\",\n", + " \"『\", \"』\", \"〝\", \"〟\", \"⟨\", \"⟩\", \"〜\", \":\", \"!\", \"?\", \"♪\", \"؛\", \"/\", \"\\\\\", \"º\", \"−\", \"^\", \"ʻ\", \"ˆ\"] \n", + "\n", + "\n", + " chars_to_remove_regex = f\"[{re.escape(''.join(chars_to_ignore_regex))}]\"\n", + " \n", + " \n", + " \n", + " # remove punctuation\n", + " text = re.sub(chars_to_remove_regex, '', text)\n", + " \n", + " text = text.lower()\n", + " \n", + " # Let's also make sure we split on all kinds of newlines, spaces, etc...\n", + " #text = \" \".join(text.split())\n", + " \n", + " return text" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "9fa432f2", + "metadata": {}, + "outputs": [], + "source": [ + "# map function to decode audio\n", + "def map_to_pred(batch):\n", + " prediction = asr(\n", + " batch[\"audio\"][\"array\"]\n", + " )\n", + "\n", + " batch[\"prediction\"] = prediction[\"text\"]\n", + " batch[\"target\"] = normalize_text(batch[\"sentence\"])\n", + " return 
batch" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "d80cf431", + "metadata": { + "papermill": { + "duration": null, + "end_time": null, + "exception": null, + "start_time": null, + "status": "pending" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4dc38b78128f49938ec47a41be469153", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1599 [00:00