# Spaces:
# Sleeping
# Sleeping
import gradio as gr
from transformers import (
    AutoModelForCTC,
    Wav2Vec2Processor,
    Wav2Vec2ProcessorWithLM,
    pipeline,
)
# Language name -> Hugging Face checkpoint id, plus a flag saying whether an
# LM-backed processor (Wav2Vec2ProcessorWithLM) is available for "LM" decoding.
MODELS = {
    "Tatar": {"model_id": "sammy786/wav2vec2-xlsr-tatar", "has_lm": False},
    "Chuvash": {"model_id": "sammy786/wav2vec2-xlsr-chuvash", "has_lm": False},
}

# Loaded CTC models keyed by checkpoint id, so repeated requests for the same
# language reuse the already-loaded model instead of loading it again.
CACHED_MODELS_BY_ID = {}

# Language choices offered in the UI, in alphabetical order.
LANGUAGES = sorted(MODELS)
def _load_ctc_model(model_id):
    """Return the CTC model for *model_id*, loading and caching it on first use."""
    model_instance = CACHED_MODELS_BY_ID.get(model_id)
    if model_instance is None:
        model_instance = AutoModelForCTC.from_pretrained(model_id)
        CACHED_MODELS_BY_ID[model_id] = model_instance
    return model_instance


def _record_error(history, message):
    """Append *message* to *history* as an error record and return it for display."""
    history.append({"error_message": message})
    return message


def run(input_file, language, decoding_type, history):
    """Transcribe an audio file with the wav2vec2 CTC model for *language*.

    Args:
        input_file: Path to the recorded audio (the microphone input uses
            ``type="filepath"``), or None if nothing was recorded.
        language: A key of ``MODELS``; None or an unknown value is reported
            as an error instead of crashing.
        decoding_type: ``"LM"`` for LM-backed decoding; any other value uses
            plain (greedy) CTC decoding.
        history: Mutable list of per-request records; error records
            ``{"error_message": ...}`` are appended to it. None is treated as
            an empty list.

    Returns:
        The transcription text, or a human-readable error message when the
        request cannot be served. This way the UI always shows something.
    """
    if history is None:
        # e.g. an uninitialised gradio "state" value.
        history = []

    model = MODELS.get(language)
    if model is None:
        # Happens when no language radio button was selected.
        return _record_error(history, "Please select a language.")
    if decoding_type == "LM" and not model["has_lm"]:
        return _record_error(history, f"LM not available for {language} language :(")
    if input_file is None:
        return _record_error(history, "No audio recorded.")

    model_id = model["model_id"]
    model_instance = _load_ctc_model(model_id)
    if decoding_type == "LM":
        processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)
        decoder = processor.decoder
    else:
        processor = Wav2Vec2Processor.from_pretrained(model_id)
        decoder = None

    asr = pipeline(
        "automatic-speech-recognition",
        model=model_instance,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        decoder=decoder,
    )
    # Long recordings are processed in 5 s chunks with 1 s strides on each side
    # (transformers ASR pipeline chunking), then stitched into one transcript.
    return asr(input_file, chunk_length_s=5, stride_length_s=1)["text"]
def _run_with_history(input_file, language, decoding_type, history):
    """Gradio adapter: run ASR and hand the request history back as session state.

    The Interface pairs a "state" input with a "state" output, so the history
    list is returned alongside the transcription text.
    """
    if history is None:
        history = []
    return run(input_file, language, decoding_type, history), history


demo = gr.Interface(
    _run_with_history,
    inputs=[
        gr.Audio(source="microphone", type="filepath", label="Record something..."),
        gr.Radio(label="Language", choices=LANGUAGES),
        gr.Radio(label="Decoding type", choices=["greedy", "LM"]),
        "state",
    ],
    outputs=[
        # The component class is gr.Textbox (lower-case "b").
        gr.Textbox(),
        "state",
    ],
    allow_screenshot=False,
    allow_flagging="never",
    theme="grass",
)
demo.launch(enable_queue=True)