Chatty_Ashe / app.py
gdnartea's picture
Update app.py
b09cd28 verified
raw
history blame
3.3 kB
import gradio as gr
import json
import librosa
import os
import soundfile as sf
import tempfile
import uuid
import torch
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
# Audio sample rate, in Hz.
SAMPLE_RATE = 16000 # Hz
# Upper bound on input length, in minutes.
MAX_AUDIO_MINUTES = 10 # won't try to transcribe if longer than this

# Load the pretrained "nvidia/canary-1b" checkpoint once at import time and put
# it in eval mode; every request handled by this module shares this instance.
model = ASRModel.from_pretrained("nvidia/canary-1b")
model.eval()

# make sure beam size always 1 for consistency
# The strategy is first reset by passing None, then the model's own decoding
# config is edited (beam_size = 1) and re-applied.
model.change_decoding_strategy(None)
decoding_cfg = model.cfg.decoding
decoding_cfg.beam.beam_size = 1
model.change_decoding_strategy(decoding_cfg)

# setup for buffered inference
# Dither and padding are zeroed in the preprocessor config.
# NOTE(review): presumably so feature extraction is deterministic across
# buffered chunks — confirm against the NeMo buffered-inference examples.
model.cfg.preprocessor.dither = 0.0
model.cfg.preprocessor.pad_to = 0

feature_stride = model.cfg.preprocessor['window_stride']
model_stride_in_secs = feature_stride * 8 # 8 = model stride, which is 8 for FastConformer

# Helper object handed to get_buffered_pred_feat_multitaskAED() for long audio.
# NOTE(review): frame_len / total_buffer are presumably in seconds, matching
# the `duration < 40` threshold used in transcribe() — confirm.
frame_asr = FrameBatchMultiTaskAED(
    asr_model=model,
    frame_len=40.0,
    total_buffer=40.0,
    batch_size=16,
)

# dtype passed to autocast around buffered inference in transcribe().
amp_dtype = torch.float16
# Language names accepted from callers, mapped to the short codes that are
# written into the manifest.
_LANG_LONG_TO_LANG_SHORT = {
    "English": "en",
    "Spanish": "es",
    "French": "fr",
    "German": "de",
}


def _convert_audio(audio_filepath, tmpdir, utt_id):
    """Write a mono, SAMPLE_RATE-Hz WAV copy of the input into tmpdir.

    Returns a ``(wav_path, duration_in_seconds)`` tuple. Raises ``gr.Error``
    when the audio is longer than MAX_AUDIO_MINUTES.
    """
    # sr=None keeps the file's native rate so the duration is measured on the
    # untouched signal; resampling happens only when actually needed.
    data, sr = librosa.load(audio_filepath, sr=None, mono=True)
    duration = librosa.get_duration(y=data, sr=sr)

    if duration / 60.0 > MAX_AUDIO_MINUTES:
        raise gr.Error(
            f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio."
        )

    if sr != SAMPLE_RATE:
        data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)

    out_filename = os.path.join(tmpdir, utt_id + ".wav")
    sf.write(out_filename, data, SAMPLE_RATE)
    return out_filename, duration


def _resolve_lang(lang, arg_name):
    """Return the short language code for a long name or an already-short code."""
    if lang in _LANG_LONG_TO_LANG_SHORT:
        return _LANG_LONG_TO_LANG_SHORT[lang]
    if lang in _LANG_LONG_TO_LANG_SHORT.values():
        return lang
    accepted = list(_LANG_LONG_TO_LANG_SHORT) + list(_LANG_LONG_TO_LANG_SHORT.values())
    raise ValueError(f"{arg_name} must be one of {accepted}")


def _resolve_pnc(pnc):
    """Normalize a bool-like or "yes"/"no" value to the manifest's "yes"/"no"."""
    if isinstance(pnc, str):
        # A non-empty string is always truthy, so "no" must be checked by value.
        pnc = pnc.strip().lower() not in ("", "no", "false", "0")
    return "yes" if pnc else "no"


def transcribe(audio_filepath, src_lang="en", tgt_lang="en", pnc="yes"):
    """Transcribe (or translate) one audio file with the module-level model.

    Args:
        audio_filepath: Path to the input audio file. ``None`` is rejected
            with a user-facing ``gr.Error``.
        src_lang: Language spoken in the audio, either a long name
            ("English", "Spanish", "French", "German") or its short code
            ("en", "es", "fr", "de").
        tgt_lang: Language of the output text, same accepted values. When it
            differs from ``src_lang`` the task is speech-to-text translation,
            otherwise plain ASR.
        pnc: Whether to request punctuation and capitalization. Accepts a
            truthy/falsy value or the strings "yes" / "no".

    Returns:
        The first result returned by the model for the input: the first item
        of ``model.transcribe(...)`` for short audio, or the ``.text`` of the
        first hypothesis from buffered inference for long audio.

    Raises:
        gr.Error: If no audio was provided or the audio is too long.
        ValueError: If ``src_lang`` or ``tgt_lang`` is not a supported language.
    """
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")

    # Validate the cheap arguments before doing any audio work.
    src_lang = _resolve_lang(src_lang, "src_lang")
    tgt_lang = _resolve_lang(tgt_lang, "tgt_lang")
    pnc = _resolve_pnc(pnc)

    # infer taskname from src_lang and tgt_lang
    taskname = "asr" if src_lang == tgt_lang else "s2t_translation"

    utt_id = uuid.uuid4()
    # Everything below stays inside the temporary directory's lifetime: the
    # converted audio and the manifest are both read from it during inference.
    with tempfile.TemporaryDirectory() as tmpdir:
        converted_audio_filepath, duration = _convert_audio(audio_filepath, tmpdir, str(utt_id))

        # make manifest file and save
        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "source_lang": src_lang,
            "target_lang": tgt_lang,
            "taskname": taskname,
            "pnc": pnc,
            "answer": "predict",
            "duration": str(duration),
        }
        manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
        with open(manifest_filepath, 'w', encoding="utf-8") as fout:
            fout.write(json.dumps(manifest_data) + '\n')

        # call transcribe, passing in manifest filepath
        if duration < 40:
            output_text = model.transcribe(manifest_filepath)[0]
        else:  # do buffered inference
            # Autocast is enabled only when CUDA is present, so the same code
            # path also runs (in full precision) on CPU-only hosts.
            use_amp = torch.cuda.is_available()
            with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=use_amp), torch.no_grad():
                hyps = get_buffered_pred_feat_multitaskAED(
                    frame_asr,
                    model.cfg.preprocessor,
                    model_stride_in_secs,
                    model.device,
                    manifest=manifest_filepath,
                    filepaths=None,
                )
            output_text = hyps[0].text

    return output_text
# Gradio front end: record from the microphone and display the returned text.
# type="filepath" makes Gradio pass transcribe() a path to the recorded audio
# file; the component's default ("numpy") would pass a (sample_rate, array)
# tuple, which transcribe() cannot use as a file path.
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),
    outputs="text",
)
iface.launch()