# espnet2_asr_librispeech_100h / gradio_asr_en_libri100.py
# jaekookang
# add resample
# 9f79e93
'''Librispeech 100h English ASR demo.

@ML2 --> @HuggingFace

2022-02-11
2022-02-16
- changed to HF
- server settings commented out
- model cache dir commented out
'''
import os
from glob import glob
from loguru import logger
import soundfile as sf
import librosa
# from scipy.io import wavfile
import gradio as gr
from espnet_model_zoo.downloader import ModelDownloader
from espnet2.bin.asr_inference import Speech2Text
# ---------- Settings ----------
# GPU selection: '-1' means "no GPU", in which case DEVICE is 'cpu'.
GPU_ID = '-1'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID
DEVICE = 'cpu' if GPU_ID == '-1' else 'cuda'

# Server options (the launch call at the bottom of this file currently
# leaves them commented out).
SERVER_NAME = "0.0.0.0"
SERVER_PORT = 42208
SSL_DIR = './keyble_ssl'

# Model storage and bundled example recordings.
# MODEL_DIR = '/home/jkang/HDD4T/jkang/huggingface'
MODEL_DIR = './model'
EXAMPLE_DIR = './examples'

# Every *.wav under EXAMPLE_DIR, in sorted path order.
examples = glob(os.path.join(EXAMPLE_DIR, '*.wav'))
examples.sort()
# ---------- Logging ----------
# Log to app.log in append mode; the banner line marks each process start.
logger.add('app.log', mode='a')
logger.info('============================= App restarted =============================')

# ---------- Model ----------
# Download/unpack the packed model via ModelDownloader (constructed with
# MODEL_DIR), then build the Speech2Text recognizer from the unpacked files.
logger.info('download model')
d = ModelDownloader(MODEL_DIR)
out = d.download_and_unpack("jkang/espnet2_librispeech_100_conformer")
logger.info('model downloaded')
model = Speech2Text.from_pretrained(
    asr_train_config=out['asr_train_config'],
    asr_model_file=out['asr_model_file'],
    # Pass DEVICE so that the GPU_ID setting actually decides where
    # inference runs ('cpu' while GPU_ID is '-1').
    device=DEVICE,
)
logger.info('model loaded')
def predict(wav_file):
    '''Transcribe one audio file with the loaded ASR model.

    Args:
        wav_file: Path of the audio file to decode. May be None or empty
            (e.g. the form was submitted without any audio).

    Returns:
        Text of the first (top) hypothesis, or '' when no audio was given.
    '''
    # No audio supplied: return an empty transcript instead of failing
    # inside librosa on a missing path.
    if not wav_file:
        return ''

    # librosa resamples to 16 kHz while loading (sr=16000), so the returned
    # sampling rate carries no extra information and is discarded.
    speech, _ = librosa.load(wav_file, sr=16000)
    logger.info('wav file loaded')

    # The model returns a list of n-best hypotheses; the first element of
    # the best hypothesis tuple is the decoded text.
    nbests = model(speech)
    text, *_ = nbests[0]
    logger.info('predicted')
    return text
# Gradio UI: a single audio input (passed to predict() as a file path) and
# a single text box that shows the decoded transcript.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.Audio(source='microphone', type='filepath', label='wav file'),
    ],
    outputs=[
        gr.outputs.Textbox(label='decoding result'),
    ],
    examples=examples,
    title='ESPNet2 ASR Librispeech Conformer (trained on clean-100h)',
    description='Upload your wav file to test the model',
    # Footer link to the model card (adjacent literals form one string).
    article=(
        '<p style="text-align:center">Model URL'
        '<a target="_blank" href="https://huggingface.co/jkang/espnet2_librispeech_100_conformer">🤗</a>'
        '</p>'
    ),
)
if __name__ == '__main__':
    # Only debug mode and request queuing are enabled. The server_name,
    # server_port, ssl_keyfile and ssl_certfile options are intentionally
    # not passed (see the module docstring: "server setting commented").
    launch_options = dict(debug=True, enable_queue=True)
    try:
        iface.launch(**launch_options)
    except KeyboardInterrupt as interrupt:
        # Ctrl-C: report it and fall through to the cleanup below.
        print(interrupt)
    finally:
        # Always shut the interface down, however launch() ended.
        iface.close()