Chatty_Ashe / app.py
gdnartea's picture
Update app.py
6a5c442 verified
raw
history blame
3.37 kB
import gradio as gr
import json
import librosa
import os
import soundfile as sf
import tempfile
import uuid
import torch
from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED
from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED
# --- Module-level model setup (executed once, when this module is run/imported) ---

# Sample rate that convert_audio() resamples every input file to before writing it.
SAMPLE_RATE = 16000 # Hz
# convert_audio() raises a gr.Error instead of transcribing inputs longer than this.
MAX_AUDIO_MINUTES = 10 # won't try to transcribe if longer than this
# Load the pretrained Canary-1B checkpoint and put the module in eval (inference) mode.
model = ASRModel.from_pretrained("nvidia/canary-1b")
model.eval()
# make sure beam size always 1 for consistency
# NOTE(review): change_decoding_strategy(None) is called first, presumably to
# (re)apply the model's default decoding config before beam_size is overridden
# below — confirm against the NeMo version in use.
model.change_decoding_strategy(None)
decoding_cfg = model.cfg.decoding
decoding_cfg.beam.beam_size = 1
model.change_decoding_strategy(decoding_cfg)
# setup for buffered inference
# These edit model.cfg.preprocessor, which transcribe() passes to
# get_buffered_pred_feat_multitaskAED for long inputs.
# NOTE(review): editing the config presumably does not rebuild the
# preprocessor module already used by model.transcribe() — confirm.
model.cfg.preprocessor.dither = 0.0
model.cfg.preprocessor.pad_to = 0
# Spacing between consecutive feature frames (presumably seconds), read from the preprocessor config.
feature_stride = model.cfg.preprocessor['window_stride']
# Time spanned by one encoder output step: feature stride times the encoder's subsampling factor.
model_stride_in_secs = feature_stride * 8 # 8 = model stride, which is 8 for FastConformer
# Buffered (chunked) inference helper used by transcribe() when the input is 40 s or longer;
# frame_len=40.0 matches the `duration < 40` cutoff used there.
frame_asr = FrameBatchMultiTaskAED(
asr_model=model,
frame_len=40.0,
total_buffer=40.0,
batch_size=16,
)
# Reduced-precision dtype used for autocast during buffered inference in transcribe().
amp_dtype = torch.float16
def convert_audio(audio_filepath, tmpdir, utt_id):
    """
    Load an audio file, enforce the length limit, and save it as mono 16 kHz wav.

    Args:
        audio_filepath: path of the source audio file (any format librosa can read).
        tmpdir: directory in which the converted wav file is written.
        utt_id: base name (without extension) for the output file.

    Returns:
        A ``(out_filename, duration)`` tuple: the path of the written wav file
        and the duration of the *original* audio, in seconds.

    Raises:
        gr.Error: if the audio is longer than MAX_AUDIO_MINUTES minutes; in
            that case nothing is written.
    """
    # Keep the file's native rate (sr=None) so the duration reflects the original audio.
    samples, native_sr = librosa.load(audio_filepath, sr=None, mono=True)
    duration = librosa.get_duration(y=samples, sr=native_sr)

    too_long = duration / 60.0 > MAX_AUDIO_MINUTES
    if too_long:
        raise gr.Error(
            f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. "
            "If you wish, you may trim the audio using the Audio viewer in Step 1 "
            "(click on the scissors icon to start trimming audio)."
        )

    # Only resample when the source rate differs from the model's expected rate.
    resampled = (
        samples
        if native_sr == SAMPLE_RATE
        else librosa.resample(samples, orig_sr=native_sr, target_sr=SAMPLE_RATE)
    )

    out_filename = os.path.join(tmpdir, utt_id + ".wav")
    sf.write(out_filename, resampled, SAMPLE_RATE)
    return out_filename, duration
def transcribe(audio_filepath):
    """
    Transcribe an audio file to English text with the loaded Canary model.

    The input is converted to a mono wav at SAMPLE_RATE inside a temporary
    directory and described by a one-line JSON manifest. It is then transcribed
    either in a single ``model.transcribe`` call (inputs shorter than 40 s) or
    with buffered inference via ``frame_asr`` (40 s and longer).

    Args:
        audio_filepath: path to the audio supplied by the Gradio Audio
            component, or None when the user submitted nothing.

    Returns:
        The transcript as a string.

    Raises:
        gr.Error: if no audio was provided, or (raised by convert_audio) if
            the audio exceeds MAX_AUDIO_MINUTES.
    """
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")

    utt_id = str(uuid.uuid4())
    with tempfile.TemporaryDirectory() as tmpdir:
        converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, utt_id)

        # One-line manifest for this utterance: English -> English ASR task.
        # NOTE(review): "pnc": "no" presumably disables punctuation/capitalization — confirm.
        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "source_lang": "en",
            "target_lang": "en",
            "taskname": "asr",
            "pnc": "no",
            "answer": "predict",
            # Manifest durations are numeric seconds; a string value can break
            # code that does arithmetic on it, while a float works either way.
            "duration": float(duration),
        }
        manifest_filepath = os.path.join(tmpdir, f"{utt_id}.json")
        with open(manifest_filepath, "w", encoding="utf-8") as fout:
            fout.write(json.dumps(manifest_data) + "\n")

        if duration < 40:
            # Short input: single pass. 40 s matches frame_asr's frame_len.
            hyp = model.transcribe(manifest_filepath)[0]
        else:
            # Long input: buffered inference. Autocast to amp_dtype only on CUDA;
            # elsewhere run at full precision instead of relying on the
            # deprecated CUDA-only torch.cuda.amp.autocast.
            device_type = model.device.type
            with torch.autocast(
                device_type=device_type,
                dtype=amp_dtype,
                enabled=device_type == "cuda",
            ):
                with torch.no_grad():
                    hyps = get_buffered_pred_feat_multitaskAED(
                        frame_asr,
                        model.cfg.preprocessor,
                        model_stride_in_secs,
                        model.device,
                        manifest=manifest_filepath,
                        filepaths=None,
                    )
            hyp = hyps[0]

    # Depending on the NeMo version, results are plain strings or Hypothesis
    # objects exposing `.text`; normalize both paths to a string.
    return getattr(hyp, "text", hyp)
# Gradio UI: record audio from the microphone and display the transcript as text.
iface = gr.Interface(
    fn=transcribe,
    # `gr.inputs.*` no longer exists in Gradio 4, and Gradio 3's gr.inputs.Audio
    # took `source=` (singular), so the original call failed on every version.
    # `sources=[...]` is the Gradio 4 keyword on the top-level gr.Audio component.
    inputs=[gr.Audio(sources=["microphone"], type="filepath")],
    outputs="text",
)
# Enable Gradio's request queue, then start the web server.
iface.queue()
iface.launch()