File size: 3,879 Bytes
d03da7e 723cd11 d03da7e 723cd11 d03da7e 723cd11 d03da7e 723cd11 d03da7e 723cd11 d03da7e 723cd11 d03da7e 723cd11 d03da7e 723cd11 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import librosa
from transformers import Wav2Vec2ForCTC, AutoProcessor
import torch
import json
from huggingface_hub import hf_hub_download
from torchaudio.models.decoder import ctc_decoder
ASR_SAMPLING_RATE = 16_000  # audio is resampled to this rate before inference


def _load_asr_languages(path="data/asr/all_langs.tsv"):
    """Read the supported-language list into an ``{iso_code: name}`` dict.

    Each non-blank line is expected to look like ``"<iso> <name>"``; only the
    first space separates the two fields, so multi-word names stay intact.
    Surrounding whitespace (including the line terminator) is stripped, so a
    label built as ``"<iso> (<name>)"`` matches strings like the default
    ``"eng (English)"`` used by ``transcribe``.

    Args:
        path: Location of the language list file.

    Returns:
        Mapping from ISO code to human-readable language name.

    Raises:
        OSError: if ``path`` cannot be opened.
        ValueError: if a non-blank line has no space separating code and name.
    """
    languages = {}
    with open(path, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue  # tolerate blank / trailing empty lines
            iso, name = line.split(" ", 1)
            languages[iso] = name.strip()
    return languages


ASR_LANGUAGES = _load_asr_languages()

MODEL_ID = "facebook/mms-1b-all"
# Loaded once at import time and shared by every call to ``transcribe``, which
# switches the tokenizer's target language and the model's adapter per request.
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
# lm_decoding_config = {}
# lm_decoding_configfile = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename="decoding_config.json",
# subfolder="mms-1b-all",
# )
# with open(lm_decoding_configfile) as f:
# lm_decoding_config = json.loads(f.read())
# # allow language model decoding for "eng"
# decoding_config = lm_decoding_config["eng"]
# lm_file = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename=decoding_config["lmfile"].rsplit("/", 1)[1],
# subfolder=decoding_config["lmfile"].rsplit("/", 1)[0],
# )
# token_file = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename=decoding_config["tokensfile"].rsplit("/", 1)[1],
# subfolder=decoding_config["tokensfile"].rsplit("/", 1)[0],
# )
# lexicon_file = None
# if decoding_config["lexiconfile"] is not None:
# lexicon_file = hf_hub_download(
# repo_id="facebook/mms-cclms",
# filename=decoding_config["lexiconfile"].rsplit("/", 1)[1],
# subfolder=decoding_config["lexiconfile"].rsplit("/", 1)[0],
# )
# beam_search_decoder = ctc_decoder(
# lexicon=lexicon_file,
# tokens=token_file,
# lm=lm_file,
# nbest=1,
# beam_size=500,
# beam_size_token=50,
# lm_weight=float(decoding_config["lmweight"]),
# word_score=float(decoding_config["wordscore"]),
# sil_score=float(decoding_config["silweight"]),
# blank_token="<s>",
# )
def _select_device():
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # ``torch.backends.mps`` is absent on older torch builds, hence the getattr.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return torch.device("mps")
    return torch.device("cpu")


def transcribe(
    audio_source=None, microphone=None, file_upload=None, lang="eng (English)"
):
    """Transcribe one audio file with the MMS CTC model using greedy decoding.

    Args:
        audio_source: Selected input mode. When its string form contains
            "upload" (case-insensitive), ``file_upload`` is used; otherwise
            ``microphone`` is used.
        microphone: Path to a recorded clip, or a dict with a ``"name"`` key
            holding that path (the form it takes when running on examples).
        file_upload: Path to an uploaded audio file.
        lang: Language label whose first whitespace-separated token is the
            ISO code passed to the tokenizer/adapter, e.g. ``"eng (English)"``.

    Returns:
        The decoded transcription, or a string starting with ``"ERROR:"`` when
        no audio file or no language was provided.
    """
    if isinstance(microphone, dict):
        # HACK: microphone variable is a dict when running on examples
        microphone = microphone["name"]

    use_upload = "upload" in str(audio_source or "").lower()
    audio_fp = file_upload if use_upload else microphone
    if audio_fp is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    # Validate the language before doing any expensive audio/model work;
    # an empty label previously crashed with IndexError.
    lang_parts = (lang or "").split()
    if not lang_parts:
        return "ERROR: You have to select a language"
    lang_code = lang_parts[0]

    audio_samples = librosa.load(audio_fp, sr=ASR_SAMPLING_RATE, mono=True)[0]

    # NOTE(review): these calls appear to mutate the shared module-level
    # processor/model, so concurrent requests for different languages may
    # interfere with each other — confirm the serving setup serializes calls.
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    device = _select_device()
    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy decoding of the single item in the batch. Language-model
    # beam-search decoding is currently disabled (its setup is commented out).
    ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(ids)
# Example input rows; each appears to follow ``transcribe``'s positional
# parameters: (audio_source, microphone, file_upload, lang).
ASR_EXAMPLES = [
    [None, "assets/english.mp3", None, "eng (English)"],
    # [None, "assets/tamil.mp3", None, "tam (Tamil)"],
    # [None, "assets/burmese.mp3", None, "mya (Burmese)"],
]

# User-facing note; it contains a Markdown link, so it is presumably rendered
# as Markdown by the UI (confirm against the caller).
ASR_NOTE = """
The above demo doesn't use beam-search decoding with a language model.
Check out the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
"""