import os
import unicodedata
from enum import Enum

import diff_match_patch as dmp_module
import gradio as gr
import numpy as np
import pandas as pd
from datasets import load_dataset
from jiwer import process_words, wer_default
from nltk import ngrams


class Action(Enum):
    INSERTION = 1
    DELETION = -1
    EQUAL = 0


def compare_string(text1: str, text2: str) -> list:
    """Compute a character-level diff between two strings after Unicode normalisation."""
    text1_normalized = unicodedata.normalize("NFKC", text1)
    text2_normalized = unicodedata.normalize("NFKC", text2)
    dmp = dmp_module.diff_match_patch()
    diff = dmp.diff_main(text1_normalized, text2_normalized)
    dmp.diff_cleanupSemantic(diff)
    return diff


def style_text(diff):
    """Render a diff as HTML: insertions highlighted in green, deletions in red."""
    fullText = ""
    for action, text in diff:
        if action == Action.INSERTION.value:
            fullText += f"<span style='background-color:Lightgreen'>{text}</span>"
        elif action == Action.DELETION.value:
            fullText += f"<span style='background-color:#FFCCCB'><s>{text}</s></span>"
        elif action == Action.EQUAL.value:
            fullText += f"{text}"
        else:
            raise Exception("Not Implemented")
    # escape characters that would otherwise be interpreted as Markdown syntax
    fullText = fullText.replace("](", "]\\(").replace("~", "\\~")
    return fullText


dataset = load_dataset(
    "distil-whisper/tedlium-long-form", split="validation", num_proc=os.cpu_count()
)

csv_v2 = pd.read_csv("assets/large-v2.csv")
norm_target = csv_v2["Norm Target"].tolist()
norm_pred_v2 = csv_v2["Norm Pred"].tolist()

csv_32_2 = pd.read_csv("assets/large-32-2.csv")
norm_pred_32_2 = csv_32_2["Norm Pred"].tolist()

target_dtype = np.int16
max_range = np.iinfo(target_dtype).max


def get_statistics(model="large-v2", round_dp=2, ngram_degree=5):
    text1 = norm_target
    if model == "large-v2":
        text2 = norm_pred_v2
    elif model == "large-32-2":
        text2 = norm_pred_32_2
    else:
        raise ValueError(
            f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`."
        )
    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)
    ier_percentage = round(
        100 * wer_output.insertions / sum(len(ref) for ref in wer_output.references),
        round_dp,
    )
    all_ngrams = list(ngrams(" ".join(text2).split(), ngram_degree))
    repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))
    return wer_percentage, ier_percentage, repeated_ngrams


def get_overall_table():
    large_v2 = get_statistics(model="large-v2")
    large_32_2 = get_statistics(model="large-32-2")
    # format the rows
    table = [large_v2, large_32_2]
    # format the model names
    table[0] = ["Whisper", *table[0]]
    table[1] = ["Distil-Whisper", *table[1]]
    return table
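
# Illustrative example (for exposition only, not part of the app): the repeated
# n-gram count flags looping hallucinations, since a prediction that repeats
# itself contains duplicate 5-grams, while a faithful transcription typically
# does not. With a hypothetical looping prediction:
#
#     grams = list(ngrams("so we went to the store so we went to the store".split(), 5))
#     repeats = len(grams) - len(set(grams))  # 8 5-grams, 6 unique -> 2 repeats
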
def get_visualisation(idx, model="large-v2", round_dp=2, ngram_degree=5):
    # convert the 1-indexed slider value to a 0-indexed dataset row
    idx -= 1
    audio = dataset[idx]["audio"]
    array = (audio["array"] * max_range).astype(target_dtype)
    sampling_rate = audio["sampling_rate"]
    text1 = norm_target[idx]
    if model == "large-v2":
        text2 = norm_pred_v2[idx]
    elif model == "large-32-2":
        text2 = norm_pred_32_2[idx]
    else:
        raise ValueError(
            f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`."
        )
    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)
    ier_percentage = round(
        100 * wer_output.insertions / len(wer_output.references[0]), round_dp
    )
    all_ngrams = list(ngrams(text2.split(), ngram_degree))
    repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))
    diff = compare_string(text1, text2)
    full_text = style_text(diff)
    return (
        (sampling_rate, array),
        wer_percentage,
        ier_percentage,
        repeated_ngrams,
        full_text,
    )


def get_side_by_side_visualisation(idx):
    large_v2 = get_visualisation(idx, model="large-v2")
    large_32_2 = get_visualisation(idx, model="large-32-2")
    # format the rows: drop the audio (first) and text diff (last) entries
    table = [large_v2[1:-1], large_32_2[1:-1]]
    # format the model names
    table[0] = ["Whisper", *table[0]]
    table[1] = ["Distil-Whisper", *table[1]]
    return large_v2[0], table, large_v2[-1], large_32_2[-1]


if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.HTML(
            """

            <h1 style="text-align: center;">Whisper Transcription Analysis</h1>

""" ) gr.Markdown( """ One of the major claims of the Distil-Whisper paper is that that Distil-Whisper hallucinates less than Whisper on long-form audio. To demonstrate this, we'll analyse the transcriptions generated by Whisper and Distil-Whisper on the TED-LIUM validation set. To quantify the amount of repetition and hallucination in the predicted transcriptions, we measure the number of repeated 5-gram word duplicates (5-Dup.) and the insertion error rate (IER). Analysis is performed on the overall level, where statistics are computed over the entire dataset, and also a per-sample level (i.e. an on an individual example basis). The transcriptions for both models are shown at the bottom of the demo. We compute a text difference for each relative to the ground truth transcriptions. Insertions are displayed in green, and deletions in red. Multiple words in green indicates that a model has hallucinated, since it has inserted words not present in the ground truth transcription. Overall, Distil-Whisper has roughly half the number of 5-Dup. and IER. This indicates that it has a lower propensity to hallucinate compared to the Whisper model. Try both models with some of the TED-LIUM examples and view the reduction in hallucinations for yourself! """ ) gr.Markdown("**Overall statistics:**") table = gr.Dataframe( value=get_overall_table(), headers=[ "Model", "Word Error Rate (WER)", "Insertion Error Rate (IER)", "Repeated 5-grams", ], row_count=2, ) gr.Markdown("**Per-sample statistics:**") slider = gr.Slider( minimum=1, maximum=len(norm_target), step=1, label="Dataset sample" ) btn = gr.Button("Analyse") audio_out = gr.Audio(label="Audio input") with gr.Column(): table = gr.Dataframe( headers=[ "Model", "Word Error Rate (WER)", "Insertion Error Rate (IER)", "Repeated 5-grams", ], row_count=2, ) with gr.Row(): gr.Markdown("**Whisper text diff**") gr.Markdown("**Distil-Whisper text diff**") with gr.Row(): text_out_v2 = gr.Markdown(label="Text difference") text_out_32_2 = gr.Markdown(label="Text difference") btn.click( fn=get_side_by_side_visualisation, inputs=slider, outputs=[audio_out, table, text_out_v2, text_out_32_2], ) demo.launch()