# minority-asr / app.py
# Hugging Face Space; last committed by mizoru — commit 04026a3 ("Update app.py"), 2.45 kB.
import gradio as gr
from transformers import pipeline, AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM
# Language name -> model metadata.
#   model_id: Hugging Face Hub repo id passed to `from_pretrained`.
#   has_lm:   whether the repo ships a processor with an LM decoder
#             (loadable via Wav2Vec2ProcessorWithLM) for "LM" decoding.
MODELS = {
    "Tatar": {"model_id": "sammy786/wav2vec2-xlsr-tatar", "has_lm": False},
    "Chuvash": {"model_id": "sammy786/wav2vec2-xlsr-chuvash", "has_lm": False},
    "Bashkir": {"model_id": "AigizK/wav2vec2-large-xls-r-300m-bashkir-cv7_opt", "has_lm": True},
    "Erzya": {"model_id": "DrishtiSharma/wav2vec2-large-xls-r-300m-myv-v1", "has_lm": True},
}

# model_id -> loaded CTC model, so weights are downloaded/loaded only once per process.
CACHED_MODELS_BY_ID = {}

# Plain list (not a dict_keys view) because it is handed to gr.Radio(choices=...),
# which expects a list of strings.
LANGUAGES = list(MODELS)
def run(input_file, language, decoding_type):
#logger.info(f"Running ASR {language}-{model_size}-{decoding_type} for {input_file}")
model = MODELS.get(language, None)
if decoding_type == "LM" and not model["has_lm"]:
history.append({
"error_message": f"LM not available for {language} language :("
})
else:
# model_instance = AutoModelForCTC.from_pretrained(model["model_id"])
model_instance = CACHED_MODELS_BY_ID.get(model["model_id"], None)
if model_instance is None:
model_instance = AutoModelForCTC.from_pretrained(model["model_id"])
CACHED_MODELS_BY_ID[model["model_id"]] = model_instance
if decoding_type == "LM":
processor = Wav2Vec2ProcessorWithLM.from_pretrained(model["model_id"])
asr = pipeline("automatic-speech-recognition", model=model_instance, tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor, decoder=processor.decoder)
else:
processor = Wav2Vec2Processor.from_pretrained(model["model_id"])
asr = pipeline("automatic-speech-recognition", model=model_instance, tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor, decoder=None)
transcription = asr(input_file, chunk_length_s=5, stride_length_s=1)["text"]
return transcription
# Web UI: a microphone recording plus language / decoder choices go in,
# the transcription (or an explanatory message) comes out of `run`.
demo = gr.Interface(
    fn=run,
    inputs=[
        gr.Audio(source="microphone", type="filepath", label="Record something..."),
        gr.Radio(label="Language", choices=LANGUAGES),
        gr.Radio(label="Decoding type", choices=["greedy", "LM"]),
    ],
    outputs=[gr.Textbox()],
    theme="grass",
    allow_flagging="never",
    allow_screenshot=False,
)

# NOTE(review): `source=`, `allow_screenshot`, `theme="grass"` and
# `enable_queue` are Gradio 3-era arguments; confirm against the pinned
# Gradio version before upgrading.
demo.launch(enable_queue=True)