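# Demo: compare Whisper large-v3 and large-v3-turbo transcriptions of Arabic
# Common Voice clips, reporting WER and CER for each model.
# Assumed (not declared in this file): gradio, datasets, jiwer, transformers,
# and torch are installed, and the local arabic_normalizer module sits next to
# this script.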
import gradio as gr
from datasets import Audio, load_dataset
from jiwer import cer, wer
from transformers import pipeline

from arabic_normalizer import ArabicTextNormalizer
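# ArabicTextNormalizer is a local helper in this repo; it is assumed to strip
# diacritics/punctuation and unify Arabic character variants so that WER/CER
# reflect recognition errors rather than orthographic differences.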

# Load the Arabic split of Common Voice 11.0 and keep only the columns we use
common_voice = load_dataset("mozilla-foundation/common_voice_11_0", name="ar", split="train", trust_remote_code=True)
common_voice = common_voice.select_columns(["audio", "sentence"])
# Resample to 16 kHz once, up front, since Whisper expects 16 kHz input
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
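# Note: the "train" split is a sizeable download; load_dataset(..., streaming=True)
# would avoid fetching it all, but the random-access indexing used in
# generate_audio() below would then have to change.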

# Force Arabic transcription so Whisper neither auto-detects the language nor translates
generate_kwargs = {
    "language": "arabic",
    "task": "transcribe"
}
# Initialize the two ASR pipelines; device=0 assumes a CUDA GPU (use device=-1 for CPU-only)
asr_whisper_large = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3",
                             device=0, generate_kwargs=generate_kwargs)
asr_whisper_large_turbo = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3-turbo",
                                   device=0, generate_kwargs=generate_kwargs)
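# large-v3-turbo is a pruned variant of large-v3 with far fewer decoder layers;
# it runs much faster, and the WER/CER gap measured here shows what that
# speedup costs in accuracy.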
normalizer = ArabicTextNormalizer()


def generate_audio():
    """Pick a random dataset sample and transcribe it with both models."""
    # Shuffle and take the first row. shuffle() builds a fresh index mapping on
    # every call, so picking a random index directly would be cheaper at scale.
    example = common_voice.shuffle()[0]
    audio = example["audio"]

    # Ground truth transcription (for WER/CER calculations)
    reference_text = normalizer(example["sentence"])

    # Prepare the audio input: a dict holding the raw waveform and its sampling
    # rate; the ASR pipeline also accepts the datasets-style "array" key
    audio_data = {
        "raw": audio["array"],
        "sampling_rate": audio["sampling_rate"]
    }

    # Run both models on the same clip; each call gets its own copy because the
    # pipeline may consume the dict's keys in place
    asr_output = asr_whisper_large(dict(audio_data))
    asr_output_turbo = asr_whisper_large_turbo(dict(audio_data))

    # Extract and normalize the transcriptions from both model outputs
    predicted_text = normalizer(asr_output["text"])
    predicted_text_turbo = normalizer(asr_output_turbo["text"])

    # Compute WER and CER for each model against the normalized reference
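    # WER = (substitutions + deletions + insertions) / number of reference words;
    # CER is the same edit-distance ratio computed over characters. jiwer
    # implements both, and either can exceed 1.0 for badly misaligned outputs.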
    wer_score = wer(reference_text, predicted_text)
    cer_score = cer(reference_text, predicted_text)

    wer_score_turbo = wer(reference_text, predicted_text_turbo)
    cer_score_turbo = cer(reference_text, predicted_text_turbo)

    # Label for the audio widget: reference sentence plus its sampling rate
    sentence_info = " - ".join([reference_text, str(audio["sampling_rate"])])

    # Gradio's Audio component takes a (sampling_rate, waveform) tuple
    return ((audio["sampling_rate"], audio["array"]), sentence_info,
            predicted_text, wer_score, cer_score,
            predicted_text_turbo, wer_score_turbo, cer_score_turbo)


with gr.Blocks() as demo:
    gr.HTML("""
        <h1>Whisper Arabic: ASR Comparison (large and large turbo)</h1>""")
    gr.Markdown("""
        This demo compares the transcriptions, WER, and CER of two ASR models (Whisper large-v3 and
        large-v3-turbo) on the Arabic split of the mozilla-foundation/common_voice_11_0 dataset.
    """)
    num_samples_input = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of audio samples")
    generate_button = gr.Button("Generate Samples")


    @gr.render(inputs=num_samples_input, triggers=[generate_button.click])
    def render(num_samples):
        with gr.Column():
            for i in range(num_samples):
                # Generate audio and associated data
                (_audio, label, asr_text, wer_score, cer_score,
                 asr_text_turbo, wer_score_turbo, cer_score_turbo) = generate_audio()

                # Create Gradio components to display the audio, transcription, and metrics
                gr.Audio(_audio, label=label)
                with gr.Row():
                    with gr.Column():
                        gr.Textbox(value=asr_text, label="Whisper large output")
                        gr.Textbox(value=f"WER: {wer_score:.2f}", label="Word Error Rate")
                        gr.Textbox(value=f"CER: {cer_score:.2f}", label="Character Error Rate")
                    with gr.Column():
                        gr.Textbox(value=asr_text_turbo, label="Whisper large turbo output")
                        gr.Textbox(value=f"WER: {wer_score_turbo:.2f}", label="Word Error Rate - TURBO")
                        gr.Textbox(value=f"CER: {cer_score_turbo:.2f}", label="Character Error Rate - TURBO")

if __name__ == '__main__':
    demo.launch(show_error=True)