File size: 4,955 Bytes
022d425 be77fdc 022d425 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import os
import random

import gradio as gr
import torch
from datasets import Audio
from datasets import load_dataset
from jiwer import wer, cer
from transformers import pipeline

from arabic_normalizer import ArabicTextNormalizer
# Load the Arabic training split of Common Voice 11.
common_voice = load_dataset(
    "mozilla-foundation/common_voice_11_0",
    trust_remote_code=True,
    name="ar",
    split="train",
)
# Keep only the columns the demo uses: the audio clip and its reference sentence.
common_voice = common_voice.select_columns(["audio", "sentence"])
# Resample the audio column to 16 kHz once, at load time, rather than per request.
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

# Force Arabic transcription (no language auto-detection, no translation).
generate_kwargs = {
    "language": "arabic",
    "task": "transcribe"
}

# Use the first GPU when one is available; otherwise fall back to CPU (-1)
# instead of crashing at startup on CPU-only hardware.
device = 0 if torch.cuda.is_available() else -1

# Initialize the two ASR pipelines being compared.
asr_whisper_large = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3",
    device=device,
    generate_kwargs=generate_kwargs,
)
asr_whisper_large_turbo = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-large-v3-turbo",
    device=device,
    generate_kwargs=generate_kwargs,
)

# Applied to both references and predictions before WER/CER are computed.
normalizer = ArabicTextNormalizer()
def generate_audio(index=None):
    """Pick one Common Voice sample, transcribe it with both Whisper models, and score them.

    Args:
        index: Optional integer row index into ``common_voice``. When ``None``
            (the default) a row is chosen uniformly at random.

    Returns:
        A tuple ``((sampling_rate, audio_array), sentence_info, predicted_text,
        wer_score, cer_score, predicted_text_turbo, wer_score_turbo,
        cer_score_turbo)`` where:

        * ``(sampling_rate, audio_array)`` is the clip resampled to 16 kHz;
        * ``sentence_info`` is the normalized reference sentence and the
          sampling rate joined with ``"-"``;
        * ``predicted_text`` / ``predicted_text_turbo`` are the normalized
          transcriptions from Whisper large-v3 and large-v3-turbo;
        * the ``wer_*`` / ``cer_*`` values are jiwer word / character error
          rates of each transcription against the normalized reference.
    """
    global common_voice

    # Resample to 16 kHz, but only cast when the column is not already at that
    # rate; re-casting an already-resampled column on every call is wasted work.
    target_sampling_rate = 16000
    if common_voice.features["audio"].sampling_rate != target_sampling_rate:
        common_voice = common_voice.cast_column(
            "audio", Audio(sampling_rate=target_sampling_rate)
        )

    # Read a single row directly. shuffle() would build a permuted index over
    # the entire split just to take its first element.
    if index is None:
        index = random.randrange(len(common_voice))
    example = common_voice[index]
    audio = example["audio"]
    array = audio["array"]
    sampling_rate = audio["sampling_rate"]

    # Ground-truth transcription, normalized like the model outputs so that
    # WER/CER compare like with like.
    reference_text = normalizer(example["sentence"])

    # Each pipeline call gets its own input dict with the same ("raw") key.
    # NOTE(review): the transformers ASR pipeline appears to pop keys from a
    # dict input, so the dict is deliberately not shared between calls.
    asr_output = asr_whisper_large({"raw": array, "sampling_rate": sampling_rate})
    asr_output_turbo = asr_whisper_large_turbo({"raw": array, "sampling_rate": sampling_rate})

    # Extract and normalize the transcriptions.
    predicted_text = normalizer(asr_output["text"])
    predicted_text_turbo = normalizer(asr_output_turbo["text"])

    # Word and character error rates of each model against the reference.
    wer_score = wer(reference_text, predicted_text)
    cer_score = cer(reference_text, predicted_text)
    wer_score_turbo = wer(reference_text, predicted_text_turbo)
    cer_score_turbo = cer(reference_text, predicted_text_turbo)

    # Display label: "<normalized sentence>-<sampling rate>".
    sentence_info = "-".join([reference_text, str(sampling_rate)])
    return (
        (sampling_rate, array),
        sentence_info,
        predicted_text,
        wer_score,
        cer_score,
        predicted_text_turbo,
        wer_score_turbo,
        cer_score_turbo,
    )
def update_ui(count=4):
    """Build ``count`` textboxes labelled "Label 0", "Label 1", ... and return them as a list.

    Args:
        count: Number of textboxes to create (default 4, the original fixed count).

    Returns:
        A list of ``gr.Textbox`` components in label order.

    NOTE(review): not referenced anywhere in the visible file; possibly leftover
    scaffolding -- confirm before relying on it.
    """
    return [gr.Textbox(label=f"Label {i}") for i in range(count)]
# Gradio UI: a slider picks how many samples to generate; the button re-renders
# one audio player plus side-by-side outputs/metrics for the two models per sample.
with gr.Blocks() as demo:
    gr.HTML("""
<h1>Whisper Arabic: ASR Comparison (large and large turbo)</h1>""")
    gr.Markdown("""
This is a demo to compare the outputs, WER & CER of two ASR models (Whisper large and large turbo) using
arabic dataset from mozilla-foundation/common_voice_11_0
""")
    num_samples_input = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of audio samples")
    generate_button = gr.Button("Generate Samples")

    @gr.render(inputs=num_samples_input, triggers=[generate_button.click])
    def render(num_samples):
        """Render ``num_samples`` freshly drawn, transcribed and scored samples."""
        with gr.Column():
            # int(): defensive, in case the slider value arrives as a float.
            for _ in range(int(num_samples)):
                # Generate audio and associated data for one random sample.
                (
                    _audio,
                    label,
                    asr_text,
                    wer_score,
                    cer_score,
                    asr_text_turbo,
                    wer_score_turbo,
                    cer_score_turbo,
                ) = generate_audio()
                # Components created inside the render context are displayed;
                # plain calls (no trailing commas) avoid building throwaway tuples.
                gr.Audio(_audio, label=label)
                with gr.Row():
                    with gr.Column():
                        gr.Textbox(value=asr_text, label="Whisper large output")
                        gr.Textbox(value=f"WER: {wer_score:.2f}", label="Word Error Rate")
                        gr.Textbox(value=f"CER: {cer_score:.2f}", label="Character Error Rate")
                    with gr.Column():
                        gr.Textbox(value=asr_text_turbo, label="Whisper large turbo output")
                        gr.Textbox(value=f"WER: {wer_score_turbo:.2f}", label="Word Error Rate - "
                                   "TURBO ")
                        gr.Textbox(value=f"CER: {cer_score_turbo:.2f}", label="Character Error "
                                   "Rate - TURBO")
# Script entry point: start the Gradio app with show_error=True.
if __name__ == "__main__":
    demo.launch(show_error=True)
|