File size: 3,879 Bytes
d03da7e
 
 
723cd11
 
 
 
d03da7e
 
 
723cd11
 
 
 
 
d03da7e
 
 
 
 
 
 
723cd11
 
 
 
 
 
d03da7e
723cd11
 
 
 
 
 
d03da7e
723cd11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d03da7e
 
723cd11
d03da7e
 
 
 
 
 
 
723cd11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d03da7e
 
 
723cd11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import librosa
from transformers import Wav2Vec2ForCTC, AutoProcessor
import torch
import json

from huggingface_hub import hf_hub_download
from torchaudio.models.decoder import ctc_decoder

ASR_SAMPLING_RATE = 16_000

# Language code -> human-readable name, read from a table with one
# "<code> <name>" entry per line (codes are presumably ISO 639-3, e.g. "eng";
# confirm against the data file). Lines are stripped so stored names carry no
# trailing newline, blank lines are skipped, and the code/name split accepts
# any whitespace (space or tab) because the file has a .tsv extension.
ASR_LANGUAGES = {}
with open("data/asr/all_langs.tsv", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        iso, name = line.split(maxsplit=1)
        ASR_LANGUAGES[iso] = name

MODEL_ID = "facebook/mms-1b-all"

# The processor (feature extractor + tokenizer) and the CTC model are loaded
# once at import time and shared by every transcribe() call.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)


# lm_decoding_config = {}
# lm_decoding_configfile = hf_hub_download(
#     repo_id="facebook/mms-cclms",
#     filename="decoding_config.json",
#     subfolder="mms-1b-all",
# )

# with open(lm_decoding_configfile) as f:
#     lm_decoding_config = json.loads(f.read())

# # allow language model decoding for "eng"

# decoding_config = lm_decoding_config["eng"]

# lm_file = hf_hub_download(
#     repo_id="facebook/mms-cclms",
#     filename=decoding_config["lmfile"].rsplit("/", 1)[1],
#     subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
# )
# token_file = hf_hub_download(
#     repo_id="facebook/mms-cclms",
#     filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
#     subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
# )
# lexicon_file = None
# if decoding_config["lexiconfile"] is not None:
#     lexicon_file = hf_hub_download(
#         repo_id="facebook/mms-cclms",
#         filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
#         subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
#     )
    
# beam_search_decoder = ctc_decoder(
#     lexicon=lexicon_file,
#     tokens=token_file,
#     lm=lm_file,
#     nbest=1,
#     beam_size=500,
#     beam_size_token=50,
#     lm_weight=float(decoding_config["lmweight"]),
#     word_score=float(decoding_config["wordscore"]),
#     sil_score=float(decoding_config["silweight"]),
#     blank_token="<s>",
# )


def _select_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return torch.device("mps")
    return torch.device("cpu")


def transcribe(
    audio_source=None, microphone=None, file_upload=None, lang="eng (English)"
):
    """Transcribe an audio clip with the MMS CTC model using greedy decoding.

    Args:
        audio_source: Input-mode selector. If its string form contains
            "upload" (case-insensitive), ``file_upload`` is used; otherwise
            ``microphone`` is used.
        microphone: Path to a recorded clip, or a dict with a ``"name"`` key
            holding the path (the shape received when running on examples).
        file_upload: Path to an uploaded audio file.
        lang: Language label whose first whitespace-separated token is the
            language code, e.g. ``"eng (English)"``. That code selects both
            the tokenizer vocabulary and the model adapter.

    Returns:
        The decoded transcription, or a string starting with ``"ERROR:"``
        when no audio or no language was provided.
    """
    if isinstance(microphone, dict):
        # HACK: microphone variable is a dict when running on examples
        microphone = microphone["name"]
    audio_fp = (
        file_upload if "upload" in str(audio_source or "").lower() else microphone
    )

    if audio_fp is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    # Validate before the (expensive) audio load; an empty label would
    # otherwise raise IndexError on lang.split()[0].
    lang_tokens = (lang or "").split()
    if not lang_tokens:
        return "ERROR: You have to select a language"
    lang_code = lang_tokens[0]

    audio_samples = librosa.load(audio_fp, sr=ASR_SAMPLING_RATE, mono=True)[0]

    # Switch tokenizer vocabulary and adapter weights to the requested language.
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    device = _select_device()
    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy CTC decoding. LM beam-search decoding (torchaudio ctc_decoder) is
    # currently disabled at module level, so there is no alternative path here.
    ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(ids)


# Example rows in transcribe()'s positional order:
# [audio_source, microphone, file_upload, lang].
ASR_EXAMPLES = [
    [None, "assets/english.mp3", None, "eng (English)"],
    # [None, "assets/tamil.mp3", None, "tam (Tamil)"],
    # [None, "assets/burmese.mp3", None, "mya (Burmese)"],
]

# User-facing note (contains Markdown link syntax; presumably rendered as
# Markdown by the UI — confirm at the call site).
ASR_NOTE = """
The demo above uses greedy CTC decoding, without a language model.
Check out the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
"""