from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers.utils import is_flash_attn_2_available
import torch
import gradio as gr
import matplotlib.pyplot as plt
import time
import os

BATCH_SIZE = 16
# TODO: remove token before release and update ckpt path
TOKEN = os.environ.get("HF_TOKEN", None)

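# Run on GPU in float16 when CUDA is available, otherwise on CPU in float32;
# use Flash Attention 2 if the package is installed.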
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
use_flash_attention_2 = is_flash_attn_2_available()

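# Load the original Whisper large-v2 checkpoint and the distilled Distil-Whisper checkpoint.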
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-large-v2", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=use_flash_attention_2
)
distilled_model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "sanchit-gandhi/distil-large-v2-private", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=use_flash_attention_2, token=TOKEN
)

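# Without Flash Attention 2, fall back to BetterTransformer for faster attention kernels.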
if not use_flash_attention_2:
    model = model.to_bettertransformer()
    distilled_model = distilled_model.to_bettertransformer()

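# Both pipelines share a single Whisper processor (tokenizer + feature extractor).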
processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")

model.to(device)
distilled_model.to(device)

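# Chunked long-form transcription pipeline for the original Whisper model.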
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=30,
    torch_dtype=torch_dtype,
    device=device,
    generate_kwargs={"language": "en", "task": "transcribe"},
)
pipe_forward = pipe._forward

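# Equivalent pipeline for Distil-Whisper, using shorter 15-second chunks.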
distil_pipe = pipeline(
    "automatic-speech-recognition",
    model=distilled_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    torch_dtype=torch_dtype,
    device=device,
)
distil_pipe_forward = distil_pipe._forward

def transcribe(inputs):
    if inputs is None:
        raise gr.Error("No audio file submitted! Please record or upload an audio file before submitting your request.")

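    # Wrap the Distil-Whisper forward pass so that only model inference time is measured.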
    def _forward_distil_time(*args, **kwargs):
        global distil_runtime
        start_time = time.time()
        result = distil_pipe_forward(*args, **kwargs)
        distil_runtime = time.time() - start_time
        return result

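    # Run Distil-Whisper first and stream its transcription to the UI immediately.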
    distil_pipe._forward = _forward_distil_time
    distil_text = distil_pipe(inputs, batch_size=BATCH_SIZE)["text"]
    yield distil_text, distil_runtime, None, None, None

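    # Time the original Whisper model in the same way.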
    def _forward_time(*args, **kwargs):
        global runtime
        start_time = time.time()
        result = pipe_forward(*args, **kwargs)
        runtime = time.time() - start_time
        return result

    pipe._forward = _forward_time
    text = pipe(inputs, batch_size=BATCH_SIZE)["text"]

    # Create figure and axis
    fig, ax = plt.subplots(figsize=(5, 5))

    # Define bar width and positions
    bar_width = 0.1
    positions = [0, 0.1]  # Adjusted positions to bring bars closer

    # Plot data
    ax.bar(positions[0], distil_runtime, bar_width, edgecolor='black')
    ax.bar(positions[1], runtime, bar_width, edgecolor='black')

    # Set title, labels, and xticks
    ax.set_ylabel('Transcription time (s)')
    ax.set_xticks(positions)
    ax.set_xticklabels(['Distil-Whisper', 'Whisper'])

    # Gridlines and other styling
    ax.grid(which='major', axis='y', linestyle='--', linewidth=0.5)

    # Use tight layout to avoid overlaps
    plt.tight_layout()

    yield distil_text, distil_runtime, text, runtime, fig

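# Gradio UI: audio input, transcribe button, runtime comparison plot, and side-by-side transcriptions.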
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.HTML(
            """
                <div style="text-align: center; max-width: 700px; margin: 0 auto;">
                  <div
                    style="
                      display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
                    "
                  >
                    <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
                      Distil-Whisper VS Whisper
                    </h1>
                  </div>
                </div>
            """
        )
        gr.HTML(
            f"""
            This demo evaluates the <a href="https://huggingface.co/distil-whisper/distil-large-v2"> Distil-Whisper </a> model 
            against the <a href="https://huggingface.co/openai/whisper-large-v2"> Whisper </a> model.  
            """
        )
        audio = gr.components.Audio(type="filepath", label="Audio input")
        button = gr.Button("Transcribe")
        plot = gr.components.Plot()
        with gr.Row():
            distil_runtime = gr.components.Textbox(label="Distil-Whisper Transcription Time (s)")
            runtime = gr.components.Textbox(label="Whisper Transcription Time (s)")
        with gr.Row():
            distil_transcription = gr.components.Textbox(label="Distil-Whisper Transcription", show_copy_button=True)
            transcription = gr.components.Textbox(label="Whisper Transcription", show_copy_button=True)

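        # transcribe is a generator, so the intermediate (Distil-Whisper) and final results are streamed to the outputs.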
        button.click(
            fn=transcribe,
            inputs=audio,
            outputs=[distil_transcription, distil_runtime, transcription, runtime, plot],
        )

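    # Enable queueing so the generator's streamed outputs are delivered to the interface.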
    demo.queue().launch()