sanchit-gandhi committed on
Commit
7bd1e74
·
1 Parent(s): 4487a27

short-form

Browse files
Files changed (1) hide show
  1. app.py +52 -17
app.py CHANGED
@@ -1,6 +1,7 @@
1
- from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
2
  from transformers.utils import is_flash_attn_2_available
3
  from transformers.pipelines.audio_utils import ffmpeg_read
 
4
  import torch
5
  import gradio as gr
6
  import time
@@ -25,6 +26,7 @@ if not use_flash_attention_2:
25
  distilled_model = distilled_model.to_bettertransformer()
26
 
27
  processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")
 
28
 
29
  model.to(device)
30
  distilled_model.to(device)
@@ -72,32 +74,65 @@ def transcribe(inputs):
72
  f"Got an audio of length {round(audio_length_mins, 3)} minutes."
73
  )
74
 
75
- inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
 
76
 
77
- def _forward_distil_time(*args, **kwargs):
78
- global distil_runtime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  start_time = time.time()
80
- result = distil_pipe_forward(*args, **kwargs)
 
 
 
 
81
  distil_runtime = time.time() - start_time
82
  distil_runtime = round(distil_runtime, 2)
83
- return result
84
 
85
- distil_pipe._forward = _forward_distil_time
86
- distil_text = distil_pipe(inputs.copy(), batch_size=BATCH_SIZE)["text"]
87
- yield distil_text, distil_runtime, None, None, None
88
 
89
- def _forward_time(*args, **kwargs):
90
- global runtime
91
  start_time = time.time()
92
- result = pipe_forward(*args, **kwargs)
 
 
 
 
93
  runtime = time.time() - start_time
94
  runtime = round(runtime, 2)
95
- return result
96
-
97
- pipe._forward = _forward_time
98
- text = pipe(inputs, batch_size=BATCH_SIZE)["text"]
99
 
100
- yield distil_text, distil_runtime, text, runtime
101
 
102
  if __name__ == "__main__":
103
  with gr.Blocks() as demo:
 
1
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, TextIteratorStreamer
2
  from transformers.utils import is_flash_attn_2_available
3
  from transformers.pipelines.audio_utils import ffmpeg_read
4
+ from threading import Thread
5
  import torch
6
  import gradio as gr
7
  import time
 
26
  distilled_model = distilled_model.to_bettertransformer()
27
 
28
  processor = AutoProcessor.from_pretrained("openai/whisper-large-v2")
29
+ streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)
30
 
31
  model.to(device)
32
  distilled_model.to(device)
 
74
  f"Got an audio of length {round(audio_length_mins, 3)} minutes."
75
  )
76
 
77
+ if audio_length_mins >= 0.5:
78
+ inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
79
 
80
+ def _forward_distil_time(*args, **kwargs):
81
+ global distil_runtime
82
+ start_time = time.time()
83
+ result = distil_pipe_forward(*args, **kwargs)
84
+ distil_runtime = time.time() - start_time
85
+ distil_runtime = round(distil_runtime, 2)
86
+ return result
87
+
88
+ distil_pipe._forward = _forward_distil_time
89
+ distil_text = distil_pipe(inputs.copy(), batch_size=BATCH_SIZE)["text"]
90
+ yield distil_text, distil_runtime, None, None
91
+
92
+ def _forward_time(*args, **kwargs):
93
+ global runtime
94
+ start_time = time.time()
95
+ result = pipe_forward(*args, **kwargs)
96
+ runtime = time.time() - start_time
97
+ runtime = round(runtime, 2)
98
+ return result
99
+
100
+ pipe._forward = _forward_time
101
+ text = pipe(inputs, batch_size=BATCH_SIZE)["text"]
102
+
103
+ yield distil_text, distil_runtime, text, runtime
104
+
105
+ else:
106
+ input_features = processor(inputs, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="pt").input_features
107
+
108
+ # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
109
+ generation_kwargs = dict(input_features, streamer=streamer, max_new_tokens=128, language="en", task="transcribe")
110
+ thread = Thread(target=distilled_model.generate, kwargs=generation_kwargs)
111
+
112
+ thread.start()
113
  start_time = time.time()
114
+ distil_text = ""
115
+ for generated_text in streamer:
116
+ distil_text += generated_text
117
+ yield distil_text, None, None, None
118
+
119
  distil_runtime = time.time() - start_time
120
  distil_runtime = round(distil_runtime, 2)
121
+ yield distil_text, distil_runtime, None, None
122
 
123
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
 
 
124
 
125
+ thread.start()
 
126
  start_time = time.time()
127
+ text = ""
128
+ for generated_text in streamer:
129
+ text += generated_text
130
+ yield distil_text, distil_runtime, text, None
131
+
132
  runtime = time.time() - start_time
133
  runtime = round(runtime, 2)
134
+ yield distil_text, distil_runtime, text, runtime
 
 
 
135
 
 
136
 
137
  if __name__ == "__main__":
138
  with gr.Blocks() as demo: