File size: 2,287 Bytes
2b10872
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
from transformers import pipeline, AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM

# Languages offered by the app, mapped to the checkpoint id passed to
# ``from_pretrained`` and a flag saying whether "LM" decoding may be used.
MODELS = {
    "Tatar": {
        "model_id": "sammy786/wav2vec2-xlsr-tatar",
        "has_lm": False,
    },
    "Chuvash": {
        "model_id": "sammy786/wav2vec2-xlsr-chuvash",
        "has_lm": False,
    },
}

# Loaded CTC models keyed by model_id; filled lazily so repeated requests
# reuse an already-instantiated model instead of loading it again.
CACHED_MODELS_BY_ID = {}

# UI language choices in alphabetical order (iterating a dict yields its keys).
LANGUAGES = sorted(MODELS)

def run(input_file, language, decoding_type, history):
    """Transcribe a recorded audio file with the wav2vec2 CTC model for *language*.

    Args:
        input_file: Path to the audio file (the UI uses ``type="filepath"``).
        language: A key of ``MODELS``.
        decoding_type: ``"LM"`` decodes with ``Wav2Vec2ProcessorWithLM`` and its
            decoder; any other value uses ``Wav2Vec2Processor`` with no decoder.
        history: Session-state list; an ``{"error_message": ...}`` entry is
            appended for every request that cannot be served. ``None`` is
            treated as an empty list (a list created here is not seen by the
            caller).

    Returns:
        The transcription text, or a human-readable error message when the
        request cannot be served. Previously the error path raised
        ``UnboundLocalError`` because ``transcription`` was never assigned.
    """
    if history is None:
        history = []

    def fail(message):
        # Record the error in the session history and show it as the output.
        history.append({"error_message": message})
        return message

    model = MODELS.get(language)
    if model is None:
        return fail(f"Unknown language: {language!r}")

    if decoding_type == "LM" and not model["has_lm"]:
        return fail(f"LM not available for {language} language :(")

    if input_file is None:
        return fail("No audio recorded.")

    model_id = model["model_id"]

    # Model weights are the expensive part; load each checkpoint once and reuse it.
    model_instance = CACHED_MODELS_BY_ID.get(model_id)
    if model_instance is None:
        model_instance = AutoModelForCTC.from_pretrained(model_id)
        CACHED_MODELS_BY_ID[model_id] = model_instance

    if decoding_type == "LM":
        processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_id)
        decoder = processor.decoder
    else:
        processor = Wav2Vec2Processor.from_pretrained(model_id)
        decoder = None

    asr = pipeline(
        "automatic-speech-recognition",
        model=model_instance,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        decoder=decoder,
    )

    # Chunked inference: 5 s windows with a 1 s stride on each side
    # (see the transformers ASR pipeline's chunk_length_s / stride_length_s).
    return asr(input_file, chunk_length_s=5, stride_length_s=1)["text"]
    
def _run_with_state(input_file, language, decoding_type, history):
    """Adapt ``run`` to the Interface's paired ``"state"`` input/output.

    The session history is returned alongside the transcription so the state
    persists between calls. A ``None`` state (as Gradio supplies before any
    value has been stored) is replaced with an empty list so ``run`` can
    append to it.
    """
    if history is None:
        history = []
    return run(input_file, language, decoding_type, history), history


gr.Interface(
    _run_with_state,
    inputs=[
        gr.Audio(source="microphone", type="filepath", label="Record something..."),
        # Preselect defaults so ``run`` never receives ``None`` for these choices.
        gr.Radio(label="Language", choices=LANGUAGES, value=LANGUAGES[0]),
        gr.Radio(label="Decoding type", choices=["greedy", "LM"], value="greedy"),
        "state",
    ],
    outputs=[
        # Gradio's text component is ``Textbox`` (``gr.TextBox`` does not exist)
        # and is passed as an instance.
        gr.Textbox(label="Transcription"),
        # A "state" input must be matched by a "state" output.
        "state",
    ],
    allow_screenshot=False,
    allow_flagging="never",
    theme="grass",
).launch(enable_queue=True)