import gradio as gr
import json
import librosa
import os
import soundfile as sf
import tempfile
import uuid

import torch
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
SAMPLE_RATE = 16000  # Hz
MAX_AUDIO_MINUTES = 10  # won't try to transcribe if longer than this

model = ASRModel.from_pretrained("nvidia/canary-1b")
model.eval()

# make sure beam size is always 1 for consistency
model.change_decoding_strategy(None)
decoding_cfg = model.cfg.decoding
decoding_cfg.beam.beam_size = 1
model.change_decoding_strategy(decoding_cfg)

# setup for buffered inference: zero dither so feature extraction is
# deterministic, and disable spectrogram padding so chunk boundaries line up
model.cfg.preprocessor.dither = 0.0
model.cfg.preprocessor.pad_to = 0

feature_stride = model.cfg.preprocessor['window_stride']
model_stride_in_secs = feature_stride * 8  # 8 = model downsampling factor for FastConformer
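
# FrameBatchMultiTaskAED performs buffered (chunked) inference for long audio:
# the waveform is split into frame_len-second chunks, each decoded by the AED
# model with up to total_buffer seconds of context. With both set to 40 s,
# transcribe() below only takes this path for audio of 40 s or longer.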
frame_asr = FrameBatchMultiTaskAED(
    asr_model=model,
    frame_len=40.0,
    total_buffer=40.0,
    batch_size=16,
)

amp_dtype = torch.float16
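
# `convert_audio` is called in transcribe() below but was missing from this
# file. A minimal sketch of what it needs to do, assuming the librosa and
# soundfile imports above: load the audio, enforce the duration cap, resample
# to 16 kHz mono, and write a wav file into the temporary directory.
def convert_audio(audio_filepath, tmpdir, utt_id):
    data, sr = librosa.load(audio_filepath, sr=None, mono=True)

    duration = librosa.get_duration(y=data, sr=sr)
    if duration / 60.0 > MAX_AUDIO_MINUTES:
        raise gr.Error(
            f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio"
        )

    if sr != SAMPLE_RATE:
        data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)

    # save the converted audio where the manifest will point to it
    converted_audio_filepath = os.path.join(tmpdir, utt_id + ".wav")
    sf.write(converted_audio_filepath, data, SAMPLE_RATE)

    return converted_audio_filepath, duration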

# defaults use the long language names expected by the mapping below;
# pnc is a boolean (e.g. from a Gradio checkbox)
def transcribe(audio_filepath, src_lang="English", tgt_lang="English", pnc=True):
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")

    utt_id = uuid.uuid4()
    with tempfile.TemporaryDirectory() as tmpdir:
        converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))

        # map src_lang and tgt_lang from long versions to short
        LANG_LONG_TO_LANG_SHORT = {
            "English": "en",
            "Spanish": "es",
            "French": "fr",
            "German": "de",
        }
        if src_lang not in LANG_LONG_TO_LANG_SHORT:
            raise ValueError(f"src_lang must be one of {list(LANG_LONG_TO_LANG_SHORT)}")
        src_lang = LANG_LONG_TO_LANG_SHORT[src_lang]

        if tgt_lang not in LANG_LONG_TO_LANG_SHORT:
            raise ValueError(f"tgt_lang must be one of {list(LANG_LONG_TO_LANG_SHORT)}")
        tgt_lang = LANG_LONG_TO_LANG_SHORT[tgt_lang]

        # infer taskname from src_lang and tgt_lang
        if src_lang == tgt_lang:
            taskname = "asr"
        else:
            taskname = "s2t_translation"

        # the manifest expects "yes"/"no" rather than a boolean
        pnc = "yes" if pnc else "no"

        # make manifest file and save
        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "source_lang": src_lang,
            "target_lang": tgt_lang,
            "taskname": taskname,
            "pnc": pnc,
            "answer": "predict",
            "duration": str(duration),
        }

        manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
        with open(manifest_filepath, 'w') as fout:
            line = json.dumps(manifest_data)
            fout.write(line + '\n')
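
        # each manifest line is one JSON object, e.g.:
        # {"audio_filepath": "/tmp/.../<utt_id>.wav", "source_lang": "en",
        #  "target_lang": "en", "taskname": "asr", "pnc": "yes",
        #  "answer": "predict", "duration": "12.3"}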

        # call transcribe, passing in manifest filepath
        if duration < 40:
            output_text = model.transcribe(manifest_filepath)[0]
        else:  # do buffered inference
            with torch.cuda.amp.autocast(dtype=amp_dtype):  # TODO: make it work if no cuda
                with torch.no_grad():
                    hyps = get_buffered_pred_feat_multitaskAED(
                        frame_asr,
                        model.cfg.preprocessor,
                        model_stride_in_secs,
                        model.device,
                        manifest=manifest_filepath,
                        filepaths=None,
                    )
                    output_text = hyps[0].text

    return output_text

# `type="filepath"` matters: transcribe() expects a path on disk, while the
# gr.Audio default ("numpy") would pass a (sample_rate, data) tuple instead
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(sources=["microphone", "upload"], type="filepath"),
    outputs="text",
)
iface.launch()