Spaces:
Paused
Paused
Upload 4 files
Browse files- app.py +172 -0
- languages.py +147 -0
- requirements.txt +5 -0
- subtitle_manager.py +52 -0
app.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import time
|
3 |
+
import logging
|
4 |
+
import torch
|
5 |
+
from sys import platform
|
6 |
+
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
|
7 |
+
from transformers.utils import is_flash_attn_2_available
|
8 |
+
from languages import get_language_names
|
9 |
+
from subtitle_manager import Subtitle
|
10 |
+
|
11 |
+
|
12 |
+
# Configure root logging once at import time so the INFO-level progress
# messages emitted throughout this module are visible on the console.
logging.basicConfig(level=logging.INFO)
# Name of the most recently loaded Whisper model; transcribe_webui_simple_progress
# compares against it to decide whether the pipeline must be rebuilt.
last_model = None
|
14 |
+
|
15 |
+
def write_file(output_file, subtitle):
    """Write the rendered subtitle text to `output_file`, UTF-8 encoded."""
    with open(output_file, "w", encoding="utf-8") as handle:
        handle.write(subtitle)
|
18 |
+
|
19 |
+
def create_pipe(model, flash):
    """Load `model` from the Hub and wrap it in a transformers ASR pipeline.

    Device selection: CUDA if available, else Apple MPS on macOS, else CPU.
    Uses float16 on CUDA and float32 elsewhere. When `flash` is true and the
    flash-attn package is installed, flash-attention 2 is used; otherwise
    PyTorch SDPA (requires torch>=2.1.1).
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        device = "cuda:0"
    else:
        device = "mps" if platform == "darwin" else "cpu"
    torch_dtype = torch.float16 if use_cuda else torch.float32

    # Attention backend: "eager", "flash_attention_2", or "sdpa".
    if flash and is_flash_attn_2_available():
        attn_impl = "flash_attention_2"
    else:
        attn_impl = "sdpa"

    asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation=attn_impl,
    )
    asr_model.to(device)

    processor = AutoProcessor.from_pretrained(model)

    return pipeline(
        "automatic-speech-recognition",
        model=asr_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
|
56 |
+
|
57 |
+
def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                     chunk_length_s, batch_size, progress=gr.Progress()):
    """Transcribe (or translate) audio from a URL, uploaded files, and/or microphone.

    Parameters mirror the Gradio inputs declared below. For each input an
    .srt, .vtt and .txt file is written to the working directory.

    Returns:
        (list of generated file paths, VTT text of the last input,
         TXT text of the last input) for the three Gradio outputs.
    """
    global last_model, last_pipe

    progress(0, desc="Loading Audio..")
    logging.info(f"urlData:{urlData}")
    logging.info(f"multipleFiles:{multipleFiles}")
    logging.info(f"microphoneData:{microphoneData}")
    logging.info(f"task: {task}")
    logging.info(f"is_flash_attn_2_available: {is_flash_attn_2_available()}")
    logging.info(f"chunk_length_s: {chunk_length_s}")
    logging.info(f"batch_size: {batch_size}")

    if last_model is None:
        logging.info("first model")
        progress(0.1, desc="Loading Model..")
        pipe = create_pipe(modelName, flash)
    elif modelName != last_model:
        logging.info("new model")
        torch.cuda.empty_cache()  # release VRAM held by the previous model
        progress(0.1, desc="Loading Model..")
        pipe = create_pipe(modelName, flash)
    else:
        logging.info("Model not changed")
        # Bug fix: reuse the cached pipeline. Previously `pipe` was never
        # assigned on this branch, so a second call with the same model
        # raised NameError.
        pipe = last_pipe
    last_model = modelName
    last_pipe = pipe

    srt_sub = Subtitle("srt")
    vtt_sub = Subtitle("vtt")
    txt_sub = Subtitle("txt")

    # Collect every audio source the user provided (any subset may be set).
    files = []
    if multipleFiles:
        files += multipleFiles
    if urlData:
        files.append(urlData)
    if microphoneData:
        files.append(microphoneData)
    logging.info(files)

    generate_kwargs = {}
    # English-only checkpoints (*.en) accept neither a language nor a task.
    if languageName != "Automatic Detection" and not modelName.endswith(".en"):
        generate_kwargs["language"] = languageName
    if not modelName.endswith(".en"):
        generate_kwargs["task"] = task

    files_out = []
    # Bug fix: ensure the text outputs are defined even when no input was
    # provided (previously this raised UnboundLocalError at the return).
    vtt = txt = ""
    for file in progress.tqdm(files, desc="Working..."):
        start_time = time.time()
        logging.info(file)
        outputs = pipe(
            file,
            chunk_length_s=chunk_length_s,  # UI default: 30
            batch_size=batch_size,          # UI default: 24
            generate_kwargs=generate_kwargs,
            return_timestamps=True,
        )
        logging.debug(outputs)
        # Bug fix: log the message directly (was `logging.info(print(...))`,
        # which printed to stdout and logged `None`).
        logging.info(f"transcribe: {time.time() - start_time} sec.")

        # Derive the output base name from the last path/URL component.
        file_out = file.split('/')[-1]
        srt = srt_sub.get_subtitle(outputs["chunks"])
        vtt = vtt_sub.get_subtitle(outputs["chunks"])
        txt = txt_sub.get_subtitle(outputs["chunks"])
        write_file(file_out + ".srt", srt)
        write_file(file_out + ".vtt", vtt)
        write_file(file_out + ".txt", txt)
        files_out += [file_out + ".srt", file_out + ".vtt", file_out + ".txt"]

    progress(1, desc="Completed!")

    return files_out, vtt, txt
|
128 |
+
|
129 |
+
|
130 |
+
# Gradio UI: a single Interface wrapped in Blocks so the page title can be set.
with gr.Blocks(title="Insanely Fast Whisper") as demo:
    description = "An opinionated CLI to transcribe Audio files w/ Whisper on-device! Powered by 🤗 Transformers, Optimum & flash-attn"
    article = "Read the [documentation here](https://github.com/Vaibhavs10/insanely-fast-whisper#cli-options)."
    # Hugging Face model IDs offered in the "Model" dropdown.
    whisper_models = [
        "openai/whisper-tiny", "openai/whisper-tiny.en",
        "openai/whisper-base", "openai/whisper-base.en",
        "openai/whisper-small", "openai/whisper-small.en", "distil-whisper/distil-small.en",
        "openai/whisper-medium", "openai/whisper-medium.en", "distil-whisper/distil-medium.en",
        "openai/whisper-large",
        "openai/whisper-large-v1",
        "openai/whisper-large-v2", "distil-whisper/distil-large-v2",
        "openai/whisper-large-v3", "xaviviro/whisper-large-v3-catalan-finetuned-v2",
    ]
    # Styling for the microphone waveform widget.
    waveform_options=gr.WaveformOptions(
        waveform_color="#01C6FF",
        waveform_progress_color="#0066B4",
        skip_length=2,
        show_controls=False,
    )

    # Inputs must stay in the same order as transcribe_webui_simple_progress's
    # parameters; outputs match its (files, vtt, txt) return tuple.
    # NOTE(review): given that return order, "Transcription" shows the VTT
    # text and "Segments" the plain text — labels look swapped; confirm intent.
    simple_transcribe = gr.Interface(fn=transcribe_webui_simple_progress,
        description=description,
        article=article,
        inputs=[
            gr.Dropdown(choices=whisper_models, value="distil-whisper/distil-large-v2", label="Model", info="Select whisper model", interactive = True,),
            gr.Dropdown(choices=["Automatic Detection"] + sorted(get_language_names()), value="Automatic Detection", label="Language", info="Select audio voice language", interactive = True,),
            gr.Text(label="URL", info="(YouTube, etc.)", interactive = True),
            gr.File(label="Upload Files", file_count="multiple"),
            gr.Audio(sources=["microphone"], type="filepath", label="Microphone Input", waveform_options = waveform_options),
            gr.Dropdown(choices=["transcribe", "translate"], label="Task", value="transcribe", interactive = True),
            gr.Checkbox(label='Flash',info='Use Flash Attention 2'),
            gr.Number(label='chunk_length_s',value=30, interactive = True),
            gr.Number(label='batch_size',value=24, interactive = True)
        ], outputs=[
            gr.File(label="Download"),
            gr.Text(label="Transcription"),
            gr.Text(label="Segments")
        ]
    )

if __name__ == "__main__":
    demo.launch()
|
172 |
+
|
languages.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class Language():
    """A spoken language: an ISO 639-1 style `code` plus its English `name`."""

    def __init__(self, code, name):
        self.code = code
        self.name = name

    def __str__(self):
        # Same textual form as before, via an f-string.
        return f"Language(code={self.code}, name={self.name})"
|
8 |
+
|
9 |
+
# The language table offered in the UI: each entry pairs a short code with
# the English display name. Order is preserved as authored; codes look like
# ISO 639-1 with a few legacy Whisper codes ('haw', 'jw') — TODO confirm
# against the Whisper tokenizer's language list.
LANGUAGES = [
    Language('en', 'English'),
    Language('zh', 'Chinese'),
    Language('de', 'German'),
    Language('es', 'Spanish'),
    Language('ru', 'Russian'),
    Language('ko', 'Korean'),
    Language('fr', 'French'),
    Language('ja', 'Japanese'),
    Language('pt', 'Portuguese'),
    Language('tr', 'Turkish'),
    Language('pl', 'Polish'),
    Language('ca', 'Catalan'),
    Language('nl', 'Dutch'),
    Language('ar', 'Arabic'),
    Language('sv', 'Swedish'),
    Language('it', 'Italian'),
    Language('id', 'Indonesian'),
    Language('hi', 'Hindi'),
    Language('fi', 'Finnish'),
    Language('vi', 'Vietnamese'),
    Language('he', 'Hebrew'),
    Language('uk', 'Ukrainian'),
    Language('el', 'Greek'),
    Language('ms', 'Malay'),
    Language('cs', 'Czech'),
    Language('ro', 'Romanian'),
    Language('da', 'Danish'),
    Language('hu', 'Hungarian'),
    Language('ta', 'Tamil'),
    Language('no', 'Norwegian'),
    Language('th', 'Thai'),
    Language('ur', 'Urdu'),
    Language('hr', 'Croatian'),
    Language('bg', 'Bulgarian'),
    Language('lt', 'Lithuanian'),
    Language('la', 'Latin'),
    Language('mi', 'Maori'),
    Language('ml', 'Malayalam'),
    Language('cy', 'Welsh'),
    Language('sk', 'Slovak'),
    Language('te', 'Telugu'),
    Language('fa', 'Persian'),
    Language('lv', 'Latvian'),
    Language('bn', 'Bengali'),
    Language('sr', 'Serbian'),
    Language('az', 'Azerbaijani'),
    Language('sl', 'Slovenian'),
    Language('kn', 'Kannada'),
    Language('et', 'Estonian'),
    Language('mk', 'Macedonian'),
    Language('br', 'Breton'),
    Language('eu', 'Basque'),
    Language('is', 'Icelandic'),
    Language('hy', 'Armenian'),
    Language('ne', 'Nepali'),
    Language('mn', 'Mongolian'),
    Language('bs', 'Bosnian'),
    Language('kk', 'Kazakh'),
    Language('sq', 'Albanian'),
    Language('sw', 'Swahili'),
    Language('gl', 'Galician'),
    Language('mr', 'Marathi'),
    Language('pa', 'Punjabi'),
    Language('si', 'Sinhala'),
    Language('km', 'Khmer'),
    Language('sn', 'Shona'),
    Language('yo', 'Yoruba'),
    Language('so', 'Somali'),
    Language('af', 'Afrikaans'),
    Language('oc', 'Occitan'),
    Language('ka', 'Georgian'),
    Language('be', 'Belarusian'),
    Language('tg', 'Tajik'),
    Language('sd', 'Sindhi'),
    Language('gu', 'Gujarati'),
    Language('am', 'Amharic'),
    Language('yi', 'Yiddish'),
    Language('lo', 'Lao'),
    Language('uz', 'Uzbek'),
    Language('fo', 'Faroese'),
    Language('ht', 'Haitian creole'),
    Language('ps', 'Pashto'),
    Language('tk', 'Turkmen'),
    Language('nn', 'Nynorsk'),
    Language('mt', 'Maltese'),
    Language('sa', 'Sanskrit'),
    Language('lb', 'Luxembourgish'),
    Language('my', 'Myanmar'),
    Language('bo', 'Tibetan'),
    Language('tl', 'Tagalog'),
    Language('mg', 'Malagasy'),
    Language('as', 'Assamese'),
    Language('tt', 'Tatar'),
    Language('haw', 'Hawaiian'),
    Language('ln', 'Lingala'),
    Language('ha', 'Hausa'),
    Language('ba', 'Bashkir'),
    Language('jw', 'Javanese'),
    Language('su', 'Sundanese')
]
|
110 |
+
|
111 |
+
# Primary lookup: ISO code -> Language, e.g. "en" -> Language('en', 'English').
_CODE_TO_LANGUAGE = {language.code: language for language in LANGUAGES}

# Bug fix: the alias entries previously mapped to bare code strings (e.g.
# "burmese" -> "my") while code entries mapped to Language objects, so
# get_language_from_code returned inconsistent types. Aliases now resolve
# to the same Language object as their canonical code.
_TO_LANGUAGE_CODE = {
    **_CODE_TO_LANGUAGE,
    "burmese": _CODE_TO_LANGUAGE["my"],
    "valencian": _CODE_TO_LANGUAGE["ca"],
    "flemish": _CODE_TO_LANGUAGE["nl"],
    "haitian": _CODE_TO_LANGUAGE["ht"],
    "letzeburgesch": _CODE_TO_LANGUAGE["lb"],
    "pushto": _CODE_TO_LANGUAGE["ps"],
    "panjabi": _CODE_TO_LANGUAGE["pa"],
    "moldavian": _CODE_TO_LANGUAGE["ro"],
    "moldovan": _CODE_TO_LANGUAGE["ro"],
    "sinhalese": _CODE_TO_LANGUAGE["si"],
    "castilian": _CODE_TO_LANGUAGE["es"],
}
|
125 |
+
|
126 |
+
# Lowercased English name -> Language, e.g. "english" -> Language('en', 'English').
_FROM_LANGUAGE_NAME = {language.name.lower(): language for language in LANGUAGES}
|
129 |
+
|
130 |
+
def get_language_from_code(language_code, default=None) -> Language:
    """Return the Language for an ISO code (or known alias), else `default`."""
    try:
        return _TO_LANGUAGE_CODE[language_code]
    except KeyError:
        return default
|
133 |
+
|
134 |
+
def get_language_from_name(language, default=None) -> Language:
    """Return the Language whose English name matches (case-insensitive), else `default`."""
    key = language.lower() if language else None
    return _FROM_LANGUAGE_NAME.get(key, default)
|
137 |
+
|
138 |
+
def get_language_names():
    """Return the English display names of all supported languages, in table order."""
    return [entry.name for entry in LANGUAGES]
|
141 |
+
|
142 |
+
if __name__ == "__main__":
    # Manual smoke test: exercise each lookup helper.
    for result in (get_language_from_code('en'), get_language_from_name('English')):
        print(result)

    print(get_language_names())
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
--extra-index-url https://download.pytorch.org/whl/cu121
|
3 |
+
torch>=2.1.1
|
4 |
+
torchvision
|
5 |
+
torchaudio
transformers
|
subtitle_manager.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
|
3 |
+
class Subtitle():
    """Render Whisper pipeline "chunks" as SRT, VTT, or plain-text subtitles.

    Each chunk is expected to be a dict with 'timestamp' (start, end) in
    seconds and 'text' — the shape produced by the transformers ASR
    pipeline with return_timestamps=True.
    """

    def __init__(self, ext="srt"):
        # Per-format configuration:
        #   coma   - decimal separator in timestamps (SRT uses ',', VTT '.')
        #   header - text the output document must start with
        #   format - renders one (index, chunk) pair as one subtitle entry
        sub_dict = {
            "srt": {
                "coma": ",",
                "header": "",
                # A missing end timestamp (None) falls back to the start time.
                "format": lambda i, segment: f"{i + 1}\n{self.timeformat(segment['timestamp'][0])} --> {self.timeformat(segment['timestamp'][1] if segment['timestamp'][1] != None else segment['timestamp'][0])}\n{segment['text']}\n\n",
            },
            "vtt": {
                "coma": ".",
                # Bug fix: the WebVTT spec requires the file to begin with the
                # exact signature "WEBVTT"; the previous "WebVTT" header
                # produced files rejected by conforming parsers.
                "header": "WEBVTT\n\n",
                "format": lambda i, segment: f"{self.timeformat(segment['timestamp'][0])} --> {self.timeformat(segment['timestamp'][1] if segment['timestamp'][1] != None else segment['timestamp'][0])}\n{segment['text']}\n\n",
            },
            "txt": {
                "coma": "",
                "header": "",
                "format": lambda i, segment: f"{segment['text']}\n",
            },
        }

        self.ext = ext
        self.coma = sub_dict[ext]["coma"]
        self.header = sub_dict[ext]["header"]
        self.format = sub_dict[ext]["format"]

    def timeformat(self, time):
        """Format `time` seconds (float) as HH:MM:SS<sep>mmm using this format's separator."""
        hours = time // 3600
        minutes = (time - hours * 3600) // 60
        seconds = time - hours * 3600 - minutes * 60
        milliseconds = (time - int(time)) * 1000
        return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}{self.coma}{int(milliseconds):03d}"

    def get_subtitle(self, segments):
        """Return the complete subtitle document for `segments` as one string."""
        output = self.header
        for i, segment in enumerate(segments):
            # Whisper chunk text usually carries one leading space; drop it.
            # (NOTE(review): this mutates the caller's dict, preserved from
            # the original behavior.)
            if segment['text'].startswith(' '):
                segment['text'] = segment['text'][1:]
            try:
                output += self.format(i, segment)
            except Exception as e:
                # Best effort: skip a malformed chunk rather than losing the
                # whole document; report what was skipped.
                print(e, segment)

        return output

    def write_subtitle(self, segments, output_file):
        """Render `segments` and write them to `output_file` + '.<ext>' as UTF-8."""
        output_file += "." + self.ext
        subtitle = self.get_subtitle(segments)

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(subtitle)
|