minority-asr / app.py
mizoru's picture
.
2a767a3
from html import escape
from urllib.parse import quote

import gradio as gr
from transformers import pipeline, AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
# Available ASR checkpoints, keyed by the English language name.
# "has_lm" marks checkpoints whose processor can be loaded as
# Wav2Vec2ProcessorWithLM (LM-assisted decoding); the others only support
# greedy CTC decoding.
MODELS = {
    "Tatar": {"model_id": "sammy786/wav2vec2-xlsr-tatar", "has_lm": False},
    "Chuvash": {"model_id": "sammy786/wav2vec2-xlsr-chuvash", "has_lm": False},
    "Bashkir": {"model_id": "AigizK/wav2vec2-large-xls-r-300m-bashkir-cv7_opt", "has_lm": True},
    "Erzya": {"model_id": "DrishtiSharma/wav2vec2-large-xls-r-300m-myv-v1", "has_lm": False},
}
# model_id -> loaded AutoModelForCTC instance, filled lazily on first use.
CACHED_MODELS_BY_ID = {}
LANGUAGES_ENG = list(MODELS)
# Russian display names; the order must match LANGUAGES_ENG because the two
# lists are zipped together below.
LANGUAGES_RUS = ["Татарский", "Чувашский", "Башкирский", "Эрзянский"]
RUS2ENG = dict(zip(LANGUAGES_RUS, LANGUAGES_ENG))
# Yandex Translate language codes (ISO 639-1: Chuvash is "cv", Bashkir is
# "ba"). Recording languages are the translation source and interface
# languages the target. None means no translation link is offered.
LANG2YDX = {
    "Tatar": 'tt',
    "Chuvash": "cv",
    "Bashkir": "ba",
    "Erzya": None,
    "English": 'en',
    'Русский': 'ru',
}
# ASR pipelines keyed by (model_id, use_lm). Processors (and the LM decoder)
# are then loaded once per configuration instead of on every request.
_CACHED_PIPELINES = {}


def _get_asr_pipeline(model_id, use_lm):
    """Return a cached ASR pipeline for ``model_id``, building it on first use.

    The CTC model itself is shared through ``CACHED_MODELS_BY_ID``.
    ``use_lm`` selects ``Wav2Vec2ProcessorWithLM`` (LM decoding) instead of
    the plain ``Wav2Vec2Processor`` (greedy decoding).
    """
    key = (model_id, use_lm)
    asr = _CACHED_PIPELINES.get(key)
    if asr is not None:
        return asr
    model_instance = CACHED_MODELS_BY_ID.get(model_id)
    if model_instance is None:
        model_instance = AutoModelForCTC.from_pretrained(model_id)
        CACHED_MODELS_BY_ID[model_id] = model_instance
    if use_lm:
        processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)
        decoder = processor.decoder
    else:
        processor = Wav2Vec2Processor.from_pretrained(model_id)
        decoder = None
    asr = pipeline("automatic-speech-recognition", model=model_instance,
                   tokenizer=processor.tokenizer,
                   feature_extractor=processor.feature_extractor,
                   decoder=decoder)
    _CACHED_PIPELINES[key] = asr
    return asr


def _translation_link(transcription, language, lang):
    """Build an HTML link to Yandex Translate for ``transcription``, or None.

    ``language`` is the English recording-language name (translation source).
    ``lang`` is the interface language, which selects the link label and the
    translation target. Returns None when LANG2YDX has no code for the source
    language.
    """
    source_code = LANG2YDX.get(language)
    if not source_code:
        return None
    # The interface-language control may never have been touched (None).
    # Fall back to Russian, the default interface language.
    target_code = LANG2YDX.get(lang) or 'ru'
    # Percent-encode the transcription so spaces, '&', '#', etc. cannot break
    # the query string. HTML-escape the URL before placing it in href.
    url = ('https://translate.yandex.ru/?lang=' + source_code + '-' + target_code
           + '&text=' + quote(transcription, safe=''))
    label = 'Посмотреть перевод' if lang == "Русский" else 'Check the translation'
    return f'<a href="{escape(url)}" target="_blank">{label}</a>'


def run(input_file, language, decoding_type, lang):
    """Transcribe ``input_file`` and build a translation link for the result.

    Args:
        input_file: path of the recorded audio (the Audio component uses
            ``type="filepath"``), or None if nothing was recorded.
        language: selected recording language, as a Russian or English name.
        decoding_type: value of the decoding Radio. That component is
            declared with ``type='index'``, so this is 0 (Greedy), 1 (LM) or
            None. The label "LM" is accepted as well.
        lang: interface language ('Русский' or 'English'); may be None.

    Returns:
        A ``(transcription, html)`` pair. ``html`` is an anchor to Yandex
        Translate, or None when no translation is available.
    """
    language = RUS2ENG.get(language, language)
    model = MODELS.get(language)
    if model is None or input_file is None:
        # Nothing to transcribe: no recording language chosen or no audio.
        return "", None
    # Use LM decoding only when it was requested AND the checkpoint has an LM.
    # Wav2Vec2ProcessorWithLM cannot be loaded for the other models.
    use_lm = decoding_type in ("LM", 1) and model["has_lm"]
    asr = _get_asr_pipeline(model["model_id"], use_lm)
    transcription = asr(input_file, chunk_length_s=5, stride_length_s=1)["text"]
    return transcription, _translation_link(transcription, language, lang)
def update_decoding(language):
    """Toggle the decoding-type selector for the chosen recording language.

    ``language`` may be a Russian or an English name; it is mapped to the
    English key of ``MODELS``. If that model ships an LM, the selector is
    shown. Otherwise it is hidden and reset to 'Greedy'.
    """
    english_name = RUS2ENG.get(language, language)
    has_lm = MODELS[english_name]['has_lm']
    if not has_lm:
        return gr.Radio.update(visible=False, value='Greedy')
    return gr.Radio.update(visible=True)
def update_interface(lang):
    """Relabel the main widgets for the chosen interface language.

    Returns Gradio updates for ``(languages, audio, decoding)``, in the same
    order as the outputs of the ``lang.change`` handler. 'Русский' selects
    the Russian labels and Russian language names. Any other value
    (including None) falls back to English, so no branch leaves a result
    unassigned.
    """
    if lang == 'Русский':
        languages_label, language_choices = 'Язык записи', LANGUAGES_RUS
        audio_label = 'Скажите что-нибудь...'
        decoding_label = 'Тип декодирования'
    else:
        languages_label, language_choices = 'Language', LANGUAGES_ENG
        audio_label = 'Say something...'
        decoding_label = 'Decoding type'
    languages = gr.Radio.update(label=languages_label, choices=language_choices)
    audio = gr.Audio.update(label=audio_label)
    decoding = gr.Radio.update(label=decoding_label)
    return languages, audio, decoding
with gr.Blocks() as blocks:
    # Interface-language switch. The widgets below start with Russian labels,
    # so Russian is preselected. This also means run() receives a concrete
    # interface language rather than None.
    lang = gr.Radio(label="Выберите язык интерфейса / Interface language",
                    choices=['Русский', 'English'], value='Русский')
    # Recording language. The initial label matches the one update_interface
    # sets for the Russian interface.
    languages = gr.Radio(label="Язык записи", choices=LANGUAGES_RUS)
    audio = gr.Audio(source="microphone", type="filepath", label="Скажите что-нибудь...")
    # Hidden until a model with an LM is selected (see update_decoding).
    # With type='index', handlers receive the choice index (0 = Greedy,
    # 1 = LM) rather than the label.
    decoding = gr.Radio(label="Тип декодирования", choices=["Greedy", "LM"],
                        visible=False, type='index')
    btn = gr.Button('Расшифровать / Transcribe')
    output = gr.Textbox(show_label=False)
    translation = gr.HTML()
    languages.change(fn=update_decoding, inputs=[languages], outputs=[decoding])
    lang.change(fn=update_interface, inputs=[lang], outputs=[languages, audio, decoding])
    btn.click(fn=run, inputs=[audio, languages, decoding, lang],
              outputs=[output, translation])
blocks.launch(enable_queue=True, debug=True)