sanchit-gandhi's picture
tidy
4447566
raw
history blame
8.61 kB
import os
import numpy as np
import unicodedata
import diff_match_patch as dmp_module
from enum import Enum
import gradio as gr
from datasets import load_dataset
import pandas as pd
from jiwer import process_words, wer_default
from nltk import ngrams
class Action(Enum):
    """Operation codes carried by each ``(op, text)`` entry of a character diff.

    The integer values mirror diff_match_patch's ``DIFF_INSERT`` (1),
    ``DIFF_DELETE`` (-1) and ``DIFF_EQUAL`` (0) codes; ``style_text`` compares
    each diff entry's op against ``Action.<member>.value``.
    """

    # Text that appears only in the second (predicted) string.
    INSERTION = 1
    # Text that appears only in the first (reference) string.
    DELETION = -1
    # Text shared by both strings.
    EQUAL = 0
def compare_string(text1: str, text2: str) -> list:
    """Return a semantically cleaned character diff that turns *text1* into *text2*.

    Both inputs are NFKC-normalised first, so compatibility-equivalent
    characters (e.g. full-width vs. ASCII forms) are not reported as edits.

    Returns:
        The diff_match_patch diff: a list of ``(op, text)`` entries.
    """
    differ = dmp_module.diff_match_patch()
    edits = differ.diff_main(
        unicodedata.normalize("NFKC", text1),
        unicodedata.normalize("NFKC", text2),
    )
    # Merge noisy character-level edits into human-readable chunks (in place).
    differ.diff_cleanupSemantic(edits)
    return edits
def style_text(diff):
    """Render a diff as HTML-in-Markdown with coloured insertions and deletions.

    Args:
        diff: iterable of ``(action, text)`` pairs whose ``action`` is one of the
            ``Action`` values (1 insertion, -1 deletion, 0 equal), e.g. the
            output of ``compare_string``.

    Returns:
        A single string in which inserted text is wrapped in a light-green span,
        deleted text in a struck-through light-red span, and unchanged text is
        emitted verbatim.

    Raises:
        ValueError: if an entry carries an unknown action code.
    """
    pieces = []
    for action, text in diff:
        if action == Action.INSERTION.value:
            pieces.append(f"<span style='background-color:Lightgreen'>{text}</span>")
        elif action == Action.DELETION.value:
            pieces.append(f"<span style='background-color:#FFCCCB'><s>{text}</s></span>")
        elif action == Action.EQUAL.value:
            pieces.append(f"{text}")
        else:
            raise ValueError(f"Not Implemented: unknown diff action {action!r}")
    # join() avoids quadratic string concatenation on long transcripts.
    full_text = "".join(pieces)
    # The result is displayed with gr.Markdown, so backslash-escape "](" and "~"
    # to keep them from being read as link / strikethrough syntax. Raw strings
    # avoid the invalid "\(" and "\~" escape sequences (SyntaxWarning on Python
    # 3.12+); the resulting text is unchanged.
    return full_text.replace("](", r"]\(").replace("~", r"\~")
# Long-form TED-LIUM validation split; get_visualisation reads sample audio from it.
dataset = load_dataset(
    "distil-whisper/tedlium-long-form", split="validation", num_proc=os.cpu_count()
)

# Pre-computed normalised targets and predictions, one row per sample. These are
# assumed to be row-aligned with `dataset` (get_visualisation indexes both with
# the same idx). Columns become plain Python lists for positional indexing and
# for passing straight to jiwer.
csv_v2 = pd.read_csv("assets/large-v2.csv")
norm_target = csv_v2["Norm Target"].tolist()
norm_pred_v2 = csv_v2["Norm Pred"].tolist()

# Kept under its own name rather than rebinding `csv_v2` to a different model.
csv_32_2 = pd.read_csv("assets/large-32-2.csv")
norm_pred_32_2 = csv_32_2["Norm Pred"].tolist()

# Scale factor for converting float audio to 16-bit PCM for the gr.Audio output.
target_dtype = np.int16
max_range = np.iinfo(target_dtype).max
def get_statistics(model="large-v2", round_dp=2, ngram_degree=5):
    """Compute dataset-level error and repetition statistics for one model.

    Args:
        model: prediction set to score, ``"large-v2"`` or ``"large-32-2"``.
        round_dp: decimal places to round the percentages to.
        ngram_degree: word n-gram order used for the repetition count.

    Returns:
        ``(wer_percentage, ier_percentage, repeated_ngrams)``: jiwer's word error
        rate over all samples, inserted words as a percentage of all reference
        words, and the number of n-grams in the predictions that duplicate an
        earlier n-gram.

    Raises:
        ValueError: if ``model`` is not one of the supported names.
    """
    text1 = norm_target
    if model == "large-v2":
        text2 = norm_pred_v2
    elif model == "large-32-2":
        text2 = norm_pred_32_2
    else:
        raise ValueError(
            f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`."
        )

    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)
    # IER denominator: total number of reference words across every sample.
    num_ref_words = sum(len(ref) for ref in wer_output.references)
    ier_percentage = round(100 * wer_output.insertions / num_ref_words, round_dp)

    # NOTE: predictions are joined into one word stream, so n-grams that span two
    # consecutive samples are counted as well.
    all_ngrams = list(ngrams(" ".join(text2).split(), ngram_degree))
    # Every occurrence beyond the first of a given n-gram counts as a repeat. A
    # set gives the distinct count in O(n), replacing the O(n^2) list scan.
    repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))
    return wer_percentage, ier_percentage, repeated_ngrams
def get_overall_table():
    """Build the dataset-level statistics table shown at the top of the demo.

    Returns:
        Two rows of ``[display name, WER %, IER %, repeated n-grams]``: first
        Whisper (``large-v2``), then Distil-Whisper (``large-32-2``).
    """
    models = (("Whisper", "large-v2"), ("Distil-Whisper", "large-32-2"))
    rows = []
    for display_name, model_id in models:
        rows.append([display_name, *get_statistics(model=model_id)])
    return rows
def get_visualisation(idx, model="large-v2", round_dp=2, ngram_degree=5):
    """Compute per-sample statistics, audio and a text diff for one model.

    Args:
        idx: 1-based sample index, as produced by the demo's slider.
        model: prediction set to use, ``"large-v2"`` or ``"large-32-2"``.
        round_dp: decimal places to round the percentages to.
        ngram_degree: word n-gram order used for the repetition count.

    Returns:
        ``((sampling_rate, int16_audio), wer_percentage, ier_percentage,
        repeated_ngrams, diff_markdown)`` for the requested sample.

    Raises:
        IndexError: if ``idx`` is outside ``1..len(norm_target)``.
        ValueError: if ``model`` is not one of the supported names.
    """
    # Slider values are 1-based (and may arrive as floats). Without this check,
    # idx=0 would silently wrap around to the last sample via negative indexing.
    position = int(idx)
    if not 1 <= position <= len(norm_target):
        raise IndexError(
            f"Sample index {idx} out of range, should be between 1 and {len(norm_target)}."
        )
    idx = position - 1

    audio = dataset[idx]["audio"]
    # The decoded waveform is expected to be float in [-1, 1]. Clip before scaling
    # so any samples slightly outside that range saturate instead of overflowing
    # the 16-bit integer cast.
    array = (np.clip(audio["array"], -1.0, 1.0) * max_range).astype(target_dtype)
    sampling_rate = audio["sampling_rate"]

    text1 = norm_target[idx]
    if model == "large-v2":
        text2 = norm_pred_v2[idx]
    elif model == "large-32-2":
        text2 = norm_pred_32_2[idx]
    else:
        raise ValueError(
            f"Got unknown model {model}, should be one of `'large-v2'` or `'large-32-2'`."
        )

    wer_output = process_words(text1, text2, wer_default, wer_default)
    wer_percentage = round(100 * wer_output.wer, round_dp)
    # Single sample, so the IER denominator is the length of its one reference.
    ier_percentage = round(
        100 * wer_output.insertions / len(wer_output.references[0]), round_dp
    )

    all_ngrams = list(ngrams(text2.split(), ngram_degree))
    # Occurrences beyond the first of each n-gram; set-based to stay O(n).
    repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))

    full_text = style_text(compare_string(text1, text2))
    return (
        (sampling_rate, array),
        wer_percentage,
        ier_percentage,
        repeated_ngrams,
        full_text,
    )
def get_side_by_side_visualisation(idx):
    """Compare Whisper and Distil-Whisper on one sample for the demo's outputs.

    Args:
        idx: 1-based sample index, forwarded to ``get_visualisation``.

    Returns:
        ``(audio, table, whisper_diff, distil_diff)``: the sample audio (taken
        from the Whisper run), a two-row ``[name, WER %, IER %, repeated n-grams]``
        table, and the rendered text diff for each model.
    """
    whisper_audio, *whisper_stats, whisper_diff = get_visualisation(idx, model="large-v2")
    _, *distil_stats, distil_diff = get_visualisation(idx, model="large-32-2")
    stats_table = [
        ["Whisper", *whisper_stats],
        ["Distil-Whisper", *distil_stats],
    ]
    return whisper_audio, stats_table, whisper_diff, distil_diff
if __name__ == "__main__":
    # Column headers shared by the overall and the per-sample statistics tables.
    stats_headers = [
        "Model",
        "Word Error Rate (WER)",
        "Insertion Error Rate (IER)",
        "Repeated 5-grams",
    ]

    with gr.Blocks() as demo:
        gr.HTML(
            """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<div
style="
display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
Whisper Transcription Analysis
</h1>
</div>
</div>
"""
        )
        gr.Markdown(
            """
One of the major claims of the <a href="https://arxiv.org/abs/2311.00430"> Distil-Whisper paper</a> is that
Distil-Whisper hallucinates less than Whisper on long-form audio. To demonstrate this, we'll analyse the
transcriptions generated by <a href="https://huggingface.co/openai/whisper-large-v2"> Whisper</a>
and <a href="https://huggingface.co/distil-whisper/distil-large-v2"> Distil-Whisper</a> on the
<a href="https://huggingface.co/datasets/distil-whisper/tedlium-long-form"> TED-LIUM</a> validation set.

To quantify the amount of repetition and hallucination in the predicted transcriptions, we measure the number
of repeated 5-gram word duplicates (5-Dup.) and the insertion error rate (IER). Analysis is performed at the
overall level, where statistics are computed over the entire dataset, and also at a per-sample level (i.e. on
an individual example basis).

The transcriptions for both models are shown at the bottom of the demo. We compute a text difference for each
relative to the ground truth transcriptions. Insertions are displayed in <span style='background-color:Lightgreen'>green</span>,
and deletions in <span style='background-color:#FFCCCB'><s>red</s></span>. Multiple words in <span style='background-color:Lightgreen'>green</span>
indicate that a model has hallucinated, since it has inserted words not present in the ground truth transcription.

Overall, Distil-Whisper has roughly half the number of 5-Dup. and IER. This indicates that it has a lower
propensity to hallucinate compared to the Whisper model. Try both models with some of the TED-LIUM examples
and view the reduction in hallucinations for yourself!
"""
        )

        gr.Markdown("**Overall statistics:**")
        # Computed once at start-up over the whole dataset for both models.
        overall_table = gr.Dataframe(
            value=get_overall_table(),
            headers=stats_headers,
            row_count=2,
        )

        gr.Markdown("**Per-sample statistics:**")
        # 1-based sample index; get_visualisation converts it to a 0-based one.
        slider = gr.Slider(
            minimum=1, maximum=len(norm_target), step=1, label="Dataset sample"
        )
        btn = gr.Button("Analyse")
        audio_out = gr.Audio(label="Audio input")
        with gr.Column():
            sample_table = gr.Dataframe(
                headers=stats_headers,
                row_count=2,
            )
        with gr.Row():
            gr.Markdown("**Whisper text diff**")
            gr.Markdown("**Distil-Whisper text diff**")
        with gr.Row():
            text_out_v2 = gr.Markdown(label="Text difference")
            text_out_32_2 = gr.Markdown(label="Text difference")

        btn.click(
            fn=get_side_by_side_visualisation,
            inputs=slider,
            outputs=[audio_out, sample_table, text_out_v2, text_out_32_2],
        )

    demo.launch()