import os
import unicodedata
from enum import Enum

import diff_match_patch as dmp_module
import gradio as gr
import numpy as np
import pandas as pd
from datasets import load_dataset
from jiwer import process_words, wer_default
from nltk import ngrams

class Action(Enum):
    # Values mirror diff-match-patch's op codes:
    # DIFF_INSERT == 1, DIFF_DELETE == -1, DIFF_EQUAL == 0.
    INSERTION = 1
    DELETION = -1
    EQUAL = 0

def compare_string(text1: str, text2: str) -> list:
    """Return a character-level diff between two NFKC-normalised strings."""
    text1_normalized = unicodedata.normalize("NFKC", text1)
    text2_normalized = unicodedata.normalize("NFKC", text2)

    dmp = dmp_module.diff_match_patch()
    diff = dmp.diff_main(text1_normalized, text2_normalized)
    dmp.diff_cleanupSemantic(diff)

    return diff

def style_text(diff):
    """Render a diff as HTML: insertions in green, deletions struck through in red."""
    full_text = ""
    for action, text in diff:
        if action == Action.INSERTION.value:
            full_text += f"<span style='background-color:Lightgreen'>{text}</span>"
        elif action == Action.DELETION.value:
            full_text += f"<span style='background-color:#FFCCCB'><s>{text}</s></span>"
        elif action == Action.EQUAL.value:
            full_text += text
        else:
            raise NotImplementedError(f"Unknown diff action: {action}")
    # Escape characters that Gradio's Markdown renderer would otherwise interpret.
    full_text = full_text.replace("](", "]\\(").replace("~", "\\~")
    return full_text

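# Illustrative usage of the two helpers above (not executed by the app; the
# exact segmentation depends on diff-match-patch's semantic cleanup):
#
#   diff = compare_string("the fast cat", "the lazy cat")
#   # -> roughly [(0, "the "), (-1, "fast"), (1, "lazy"), (0, " cat")]
#   html = style_text(diff)  # "fast" struck through in red, "lazy" in green
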
dataset = load_dataset(
    "distil-whisper/tedlium-long-form", split="validation", num_proc=os.cpu_count()
)

csv_v2 = pd.read_csv("assets/large-v2.csv")
norm_target = csv_v2["Norm Target"].tolist()
norm_pred_v2 = csv_v2["Norm Pred"].tolist()

csv_32_2 = pd.read_csv("assets/large-32-2.csv")
norm_pred_32_2 = csv_32_2["Norm Pred"].tolist()

# Audio arrays are floats in [-1, 1]; rescale to int16 for Gradio playback.
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max

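# A toy sketch of jiwer's process_words output, for intuition (values assume a
# single reference/hypothesis pair):
#
#   out = process_words(["the cat sat"], ["the cat sat down"], wer_default, wer_default)
#   out.insertions  # 1: "down" is inserted
#   out.wer         # 1/3: one error over three reference words
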
def get_statistics(model="large-v2", round_dp=2, ngram_degree=5):
    """Compute corpus-level WER, IER and repeated n-gram count for one model."""
    text1 = norm_target
    if model == "large-v2":
        text2 = norm_pred_v2
    elif model == "large-32-2":
        text2 = norm_pred_32_2
    else:
        raise ValueError(
            f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`."
        )

    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)
    # IER: insertions as a percentage of the total number of reference words.
    ier_percentage = round(
        100 * wer_output.insertions / sum(len(ref) for ref in wer_output.references),
        round_dp,
    )

    # Count n-grams that occur more than once across the concatenated predictions.
    all_ngrams = list(ngrams(" ".join(text2).split(), ngram_degree))
    repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))

    return wer_percentage, ier_percentage, repeated_ngrams

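# Toy illustration of the repeated n-gram count (bigrams for brevity):
#
#   words = "so so so so so the talk the talk".split()
#   grams = list(ngrams(words, 2))  # 8 bigrams in total
#   len(grams) - len(set(grams))    # -> 4 repeated bigrams
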
def get_overall_table():
    """Aggregate statistics for both models into rows for the Gradio Dataframe."""
    large_v2 = get_statistics(model="large-v2")
    large_32_2 = get_statistics(model="large-32-2")

    return [["Whisper", *large_v2], ["Distil-Whisper", *large_32_2]]

def get_visualisation(idx, model="large-v2", round_dp=2, ngram_degree=5):
    """Per-sample audio, WER/IER statistics and styled text diff for one model."""
    # The slider is 1-indexed; the dataset is 0-indexed.
    idx -= 1
    audio = dataset[idx]["audio"]
    array = (audio["array"] * max_range).astype(target_dtype)
    sampling_rate = audio["sampling_rate"]

    text1 = norm_target[idx]
    if model == "large-v2":
        text2 = norm_pred_v2[idx]
    elif model == "large-32-2":
        text2 = norm_pred_32_2[idx]
    else:
        raise ValueError(
            f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`."
        )

    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)
    ier_percentage = round(
        100 * wer_output.insertions / len(wer_output.references[0]), round_dp
    )

    all_ngrams = list(ngrams(text2.split(), ngram_degree))
    repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))

    diff = compare_string(text1, text2)
    full_text = style_text(diff)

    return (
        (sampling_rate, array),
        wer_percentage,
        ier_percentage,
        repeated_ngrams,
        full_text,
    )

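# Shape of the per-sample output (illustrative): audio first, diff HTML last,
# scalar statistics in between.
#
#   (sr, waveform), wer, ier, dups, html = get_visualisation(1, model="large-v2")
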
def get_side_by_side_visualisation(idx):
    """Assemble the per-sample outputs for both models into the demo layout."""
    large_v2 = get_visualisation(idx, model="large-v2")
    large_32_2 = get_visualisation(idx, model="large-32-2")

    # Drop the audio (first) and diff text (last) entries for the statistics table.
    table = [
        ["Whisper", *large_v2[1:-1]],
        ["Distil-Whisper", *large_32_2[1:-1]],
    ]
    return large_v2[0], table, large_v2[-1], large_32_2[-1]

if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.HTML(
            """
            <div style="text-align: center; max-width: 700px; margin: 0 auto;">
              <div
                style="
                  display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
                "
              >
                <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
                  Whisper Transcription Analysis
                </h1>
              </div>
            </div>
            """
        )
        gr.Markdown(
            """
            One of the major claims of the <a href="https://arxiv.org/abs/2311.00430">Distil-Whisper paper</a> is
            that Distil-Whisper hallucinates less than Whisper on long-form audio. To demonstrate this, we'll analyse the
            transcriptions generated by <a href="https://huggingface.co/openai/whisper-large-v2">Whisper</a>
            and <a href="https://huggingface.co/distil-whisper/distil-large-v2">Distil-Whisper</a> on the
            <a href="https://huggingface.co/datasets/distil-whisper/tedlium-long-form">TED-LIUM</a> validation set.

            To quantify the amount of repetition and hallucination in the predicted transcriptions, we measure the number
            of repeated 5-gram word duplicates (5-Dup.) and the insertion error rate (IER). Analysis is performed at the
            overall level, where statistics are computed over the entire dataset, and at the per-sample level (i.e. on
            an individual example basis).

            The transcriptions for both models are shown at the bottom of the demo. We compute a text difference for each
            model relative to the ground truth transcriptions. Insertions are displayed in <span style='background-color:Lightgreen'>green</span>,
            and deletions in <span style='background-color:#FFCCCB'><s>red</s></span>. Multiple words in <span style='background-color:Lightgreen'>green</span>
            indicate that a model has hallucinated, since it has inserted words not present in the ground truth transcription.

            Overall, Distil-Whisper produces roughly half as many repeated 5-grams and roughly half the IER of Whisper.
            This indicates that it has a lower propensity to hallucinate than the Whisper model. Try both models on some
            of the TED-LIUM examples and view the reduction in hallucinations for yourself!
            """
        )
gr.Markdown("**Overall statistics:**") |
|
table = gr.Dataframe( |
|
value=get_overall_table(), |
|
headers=[ |
|
"Model", |
|
"Word Error Rate (WER)", |
|
"Insertion Error Rate (IER)", |
|
"Repeated 5-grams", |
|
], |
|
row_count=2, |
|
) |
|
gr.Markdown("**Per-sample statistics:**") |
|
slider = gr.Slider( |
|
minimum=1, maximum=len(norm_target), step=1, label="Dataset sample" |
|
) |
|
btn = gr.Button("Analyse") |
|
audio_out = gr.Audio(label="Audio input") |
|
        with gr.Column():
            table = gr.Dataframe(
                headers=[
                    "Model",
                    "Word Error Rate (WER)",
                    "Insertion Error Rate (IER)",
                    "Repeated 5-grams",
                ],
                row_count=2,
            )
        with gr.Row():
            gr.Markdown("**Whisper text diff**")
            gr.Markdown("**Distil-Whisper text diff**")
        with gr.Row():
            text_out_v2 = gr.Markdown(label="Text difference")
            text_out_32_2 = gr.Markdown(label="Text difference")

        btn.click(
            fn=get_side_by_side_visualisation,
            inputs=slider,
            outputs=[audio_out, table, text_out_v2, text_out_32_2],
        )

    demo.launch()