# Copyright      2022-2023  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Fetch sherpa-onnx models from the Hugging Face Hub and build recognizers.

The public entry points are :func:`get_pretrained_model`,
:func:`get_punct_model`, :func:`get_vad` and the ``language_to_models``
mapping, which lists the model keys available per language.
"""

from functools import lru_cache

import sherpa_onnx
from huggingface_hub import hf_hub_download

# Sample rate (Hz) passed to the recognizers and the VAD below.
sample_rate = 16000


def _get_nn_model_filename(
    repo_id: str,
    filename: str,
    subfolder: str = "exp",
) -> str:
    """Download ``subfolder/filename`` from ``repo_id``; return the local path.

    Args:
      repo_id: Hugging Face Hub repository id.
      filename: Name of the file inside ``subfolder``.
      subfolder: Directory inside the repository ("." for the repo root).
    """
    nn_model_filename = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
    )
    return nn_model_filename


# Public alias of the generic download helper.
get_file = _get_nn_model_filename


def _get_bpe_model_filename(
    repo_id: str,
    filename: str = "bpe.model",
    subfolder: str = "data/lang_bpe_500",
) -> str:
    """Download a BPE model file from ``repo_id``; return the local path."""
    bpe_model_filename = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
    )
    return bpe_model_filename


def _get_token_filename(
    repo_id: str,
    filename: str = "tokens.txt",
    subfolder: str = "data/lang_char",
) -> str:
    """Download a token table from ``repo_id``; return the local path."""
    token_filename = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
    )
    return token_filename


@lru_cache(maxsize=10)
def _get_whisper_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build a Whisper recognizer from a key such as ``"whisper-tiny.en"``.

    Everything after the first ``-`` is the model name (e.g. ``tiny.en`` or
    ``distil-small.en``). The int8 encoder/decoder and the token table are
    fetched from ``csukuangfj/sherpa-onnx-whisper-<name>``.

    Raises:
      AssertionError: if the name is not a supported Whisper variant.
    """
    # Split only on the first "-": the distil names contain a "-" themselves,
    # and splitting on every "-" turned "whisper-distil-small.en" into
    # "distil", which then failed the check below.
    name = repo_id.split("-", maxsplit=1)[1]
    assert name in (
        "tiny.en",
        "base.en",
        "small.en",
        "medium.en",
        "distil-small.en",
        "distil-medium.en",
    ), repo_id
    full_repo_id = "csukuangfj/sherpa-onnx-whisper-" + name

    encoder = _get_nn_model_filename(
        repo_id=full_repo_id,
        filename=f"{name}-encoder.int8.onnx",
        subfolder=".",
    )

    decoder = _get_nn_model_filename(
        repo_id=full_repo_id,
        filename=f"{name}-decoder.int8.onnx",
        subfolder=".",
    )

    tokens = _get_token_filename(
        repo_id=full_repo_id, subfolder=".", filename=f"{name}-tokens.txt"
    )

    recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
        encoder=encoder,
        decoder=decoder,
        tokens=tokens,
        num_threads=2,
        tail_paddings=2000,
    )

    return recognizer


@lru_cache(maxsize=10)
def _get_paraformer_zh_pre_trained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build the int8 Chinese(+English) Paraformer recognizer."""
    assert repo_id in [
        "csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28",
    ], repo_id

    nn_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="model.int8.onnx",
        subfolder=".",
    )

    tokens = _get_token_filename(repo_id=repo_id, subfolder=".")

    recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=nn_model,
        tokens=tokens,
        num_threads=2,
        sample_rate=sample_rate,
        feature_dim=80,
        decoding_method="greedy_search",
        debug=False,
    )

    return recognizer


@lru_cache(maxsize=5)
def _get_chinese_dialect_models(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build the int8 TeleSpeech CTC recognizer for Chinese dialects."""
    assert repo_id in [
        "csukuangfj/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04",
    ], repo_id

    nn_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="model.int8.onnx",
        subfolder=".",
    )

    tokens = _get_token_filename(repo_id=repo_id, subfolder=".")

    recognizer = sherpa_onnx.OfflineRecognizer.from_telespeech_ctc(
        model=nn_model,
        tokens=tokens,
        num_threads=2,
    )

    return recognizer


@lru_cache(maxsize=10)
def _get_russian_pre_trained_model_ctc(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build the int8 NeMo GigaAM CTC recognizer for Russian."""
    assert repo_id in (
        "csukuangfj/sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24",
    ), repo_id

    model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="model.int8.onnx",
        subfolder=".",
    )

    tokens = _get_token_filename(repo_id=repo_id, subfolder=".")

    recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
        model=model,
        tokens=tokens,
        num_threads=2,
    )

    return recognizer


@lru_cache(maxsize=10)
def _get_russian_pre_trained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build a Russian transducer recognizer (Vosk or NeMo GigaAM).

    The repositories differ in model directory, encoder filename, transducer
    type and the folder holding ``tokens.txt``; those are selected per repo.
    """
    assert repo_id in (
        "alphacep/vosk-model-ru",
        "alphacep/vosk-model-small-ru",
        "csukuangfj/sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24",
    ), repo_id

    if repo_id == "alphacep/vosk-model-ru":
        model_dir = "am-onnx"
        encoder = "encoder.onnx"
        model_type = "transducer"
        tokens_dir = "lang"
    elif repo_id == "alphacep/vosk-model-small-ru":
        model_dir = "am"
        encoder = "encoder.onnx"
        model_type = "transducer"
        tokens_dir = "lang"
    elif repo_id == "csukuangfj/sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24":
        model_dir = "."
        encoder = "encoder.int8.onnx"
        model_type = "nemo_transducer"
        tokens_dir = "."

    encoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename=encoder,
        subfolder=model_dir,
    )

    decoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="decoder.onnx",
        subfolder=model_dir,
    )

    joiner_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="joiner.onnx",
        subfolder=model_dir,
    )

    tokens = _get_token_filename(repo_id=repo_id, subfolder=tokens_dir)

    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens,
        encoder=encoder_model,
        decoder=decoder_model,
        joiner=joiner_model,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
        model_type=model_type,
    )

    return recognizer


@lru_cache(maxsize=2)
def get_punct_model() -> sherpa_onnx.OfflinePunctuation:
    """Return the (cached) CT-Transformer zh/en punctuation model."""
    model = _get_nn_model_filename(
        repo_id="csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12",
        filename="model.onnx",
        subfolder=".",
    )
    config = sherpa_onnx.OfflinePunctuationConfig(
        model=sherpa_onnx.OfflinePunctuationModelConfig(ct_transformer=model),
    )

    punct = sherpa_onnx.OfflinePunctuation(config)
    return punct


def get_vad() -> sherpa_onnx.VoiceActivityDetector:
    """Return a new Silero VAD (v5) detector.

    Unlike the model getters, this is not cached: every call constructs a
    fresh ``VoiceActivityDetector`` (with a 180-second buffer).
    """
    vad_model = _get_nn_model_filename(
        repo_id="csukuangfj/vad",
        filename="silero_vad_v5.onnx",
        subfolder=".",
    )

    config = sherpa_onnx.VadModelConfig()
    config.silero_vad.model = vad_model
    config.silero_vad.min_silence_duration = 0.15
    config.silero_vad.min_speech_duration = 0.25
    config.sample_rate = sample_rate

    vad = sherpa_onnx.VoiceActivityDetector(
        config,
        buffer_size_in_seconds=180,
    )

    return vad


@lru_cache(maxsize=10)
def get_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Return the recognizer registered under ``repo_id``.

    The model tables are searched in a fixed order and the first one that
    contains ``repo_id`` wins. Its loader is then called as ``loader(repo_id)``.

    Raises:
      ValueError: if no table contains ``repo_id``.
    """
    # Order matters only for keys present in several tables (the Paraformer
    # model is registered both as Chinese and as Chinese+English).
    for table in (
        chinese_models,
        chinese_dialect_models,
        english_models,
        chinese_english_mixed_models,
        russian_models,
        korean_models,
        thai_models,
        japanese_models,
        zh_en_ko_ja_yue_models,
    ):
        if repo_id in table:
            return table[repo_id](repo_id)

    raise ValueError(f"Unsupported repo_id: {repo_id}")


def _get_wenetspeech_pre_trained_model(repo_id):
    """Build the WenetSpeech conformer transducer recognizer for Chinese."""
    assert repo_id in (
        "csukuangfj/sherpa-onnx-conformer-zh-stateless2-2023-05-23",
    ), repo_id

    encoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="encoder-epoch-99-avg-1.onnx",
        subfolder=".",
    )

    decoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="decoder-epoch-99-avg-1.onnx",
        subfolder=".",
    )

    joiner_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="joiner-epoch-99-avg-1.onnx",
        subfolder=".",
    )

    tokens = _get_token_filename(repo_id=repo_id, subfolder=".")

    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens,
        encoder=encoder_model,
        decoder=decoder_model,
        joiner=joiner_model,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
        decoding_method="greedy_search",
    )

    return recognizer


def _get_multi_zh_hans_pre_trained_model(repo_id):
    """Build the multi-dataset zipformer transducer recognizer for Chinese."""
    assert repo_id in ("zrjin/sherpa-onnx-zipformer-multi-zh-hans-2023-9-2",), repo_id

    encoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="encoder-epoch-20-avg-1.onnx",
        subfolder=".",
    )

    decoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="decoder-epoch-20-avg-1.onnx",
        subfolder=".",
    )

    joiner_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="joiner-epoch-20-avg-1.onnx",
        subfolder=".",
    )

    tokens = _get_token_filename(repo_id=repo_id, subfolder=".")

    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens,
        encoder=encoder_model,
        decoder=decoder_model,
        joiner=joiner_model,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
        decoding_method="greedy_search",
    )

    return recognizer


def _get_english_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build the multi-dataset pruned_transducer_stateless7 English recognizer."""
    assert (
        repo_id
        == "yfyeung/icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04"
    ), repo_id

    encoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="encoder-epoch-30-avg-4.onnx",
        subfolder="exp",
    )

    decoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="decoder-epoch-30-avg-4.onnx",
        subfolder="exp",
    )

    joiner_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="joiner-epoch-30-avg-4.onnx",
        subfolder="exp",
    )

    tokens = _get_token_filename(repo_id=repo_id, subfolder="lang_bpe_500")

    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens,
        encoder=encoder_model,
        decoder=decoder_model,
        joiner=joiner_model,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
        decoding_method="greedy_search",
    )

    return recognizer


@lru_cache(maxsize=10)
def _get_korean_pre_trained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build the zipformer transducer recognizer for Korean (int8 encoder)."""
    assert repo_id in ("k2-fsa/sherpa-onnx-zipformer-korean-2024-06-24",), repo_id

    encoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="encoder-epoch-99-avg-1.int8.onnx",
        subfolder=".",
    )

    decoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="decoder-epoch-99-avg-1.onnx",
        subfolder=".",
    )

    joiner_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="joiner-epoch-99-avg-1.onnx",
        subfolder=".",
    )

    tokens = _get_token_filename(repo_id=repo_id, subfolder=".")

    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens,
        encoder=encoder_model,
        decoder=decoder_model,
        joiner=joiner_model,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
    )

    return recognizer


@lru_cache(maxsize=10)
def _get_japanese_pre_trained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build the ReazonSpeech transducer recognizer for Japanese (int8 encoder)."""
    assert repo_id in ("reazon-research/reazonspeech-k2-v2",), repo_id

    encoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="encoder-epoch-99-avg-1.int8.onnx",
        subfolder=".",
    )

    decoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="decoder-epoch-99-avg-1.onnx",
        subfolder=".",
    )

    joiner_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="joiner-epoch-99-avg-1.onnx",
        subfolder=".",
    )

    tokens = _get_token_filename(repo_id=repo_id, subfolder=".")

    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens,
        encoder=encoder_model,
        decoder=decoder_model,
        joiner=joiner_model,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
    )

    return recognizer


@lru_cache(maxsize=10)
def _get_yifan_thai_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Build the GigaSpeech2 zipformer transducer recognizer for Thai."""
    assert repo_id in (
        "yfyeung/icefall-asr-gigaspeech2-th-zipformer-2024-06-20",
    ), repo_id

    encoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="encoder-epoch-12-avg-5.int8.onnx",
        subfolder="exp",
    )

    decoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="decoder-epoch-12-avg-5.onnx",
        subfolder="exp",
    )

    joiner_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="joiner-epoch-12-avg-5.int8.onnx",
        subfolder="exp",
    )

    tokens = _get_token_filename(repo_id=repo_id, subfolder="data/lang_bpe_2000")

    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens,
        encoder=encoder_model,
        decoder=decoder_model,
        joiner=joiner_model,
        num_threads=2,
        sample_rate=16000,
        feature_dim=80,
    )

    return recognizer


@lru_cache(maxsize=10)
def _get_sense_voice_pre_trained_model(
    repo_id: str,
    decoding_method: str = "greedy_search",
    num_active_paths: int = 4,
) -> sherpa_onnx.OfflineRecognizer:
    """Build the int8 SenseVoice recognizer (zh/en/ja/ko/yue) with ITN enabled.

    Args:
      repo_id: Must be the SenseVoice repository listed below.
      decoding_method: Currently unused; greedy search is always used.
      num_active_paths: Currently unused.

    ``decoding_method`` and ``num_active_paths`` have defaults so that this
    loader can be called as ``loader(repo_id)`` from
    :func:`get_pretrained_model`. Without defaults, that call raised
    ``TypeError``.
    """
    assert repo_id in [
        "csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17",
    ], repo_id

    nn_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="model.int8.onnx",
        subfolder=".",
    )

    tokens = _get_token_filename(repo_id=repo_id, subfolder=".")

    recognizer = sherpa_onnx.OfflineRecognizer.from_sense_voice(
        model=nn_model,
        tokens=tokens,
        num_threads=2,
        sample_rate=sample_rate,
        feature_dim=80,
        decoding_method="greedy_search",
        debug=True,
        use_itn=True,
    )

    return recognizer


# Each table maps a model key to the loader that builds its recognizer;
# get_pretrained_model() calls the loader with the key as its only argument.
chinese_dialect_models = {
    "csukuangfj/sherpa-onnx-telespeech-ctc-int8-zh-2024-06-04": _get_chinese_dialect_models,
}

zh_en_ko_ja_yue_models = {
    "csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17": _get_sense_voice_pre_trained_model,
}

chinese_models = {
    "csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28": _get_paraformer_zh_pre_trained_model,
    "csukuangfj/sherpa-onnx-conformer-zh-stateless2-2023-05-23": _get_wenetspeech_pre_trained_model,  # noqa
    "zrjin/sherpa-onnx-zipformer-multi-zh-hans-2023-9-2": _get_multi_zh_hans_pre_trained_model,  # noqa
}

english_models = {
    "whisper-tiny.en": _get_whisper_model,
    "whisper-base.en": _get_whisper_model,
    "whisper-small.en": _get_whisper_model,
    "whisper-distil-small.en": _get_whisper_model,
    "whisper-medium.en": _get_whisper_model,
    "whisper-distil-medium.en": _get_whisper_model,
    "yfyeung/icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04": _get_english_model,  # noqa
}

chinese_english_mixed_models = {
    "csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28": _get_paraformer_zh_pre_trained_model,
}

korean_models = {
    "k2-fsa/sherpa-onnx-zipformer-korean-2024-06-24": _get_korean_pre_trained_model,
}

russian_models = {
    "csukuangfj/sherpa-onnx-nemo-transducer-giga-am-russian-2024-10-24": _get_russian_pre_trained_model,
    "csukuangfj/sherpa-onnx-nemo-ctc-giga-am-russian-2024-10-24": _get_russian_pre_trained_model_ctc,
    "alphacep/vosk-model-ru": _get_russian_pre_trained_model,
    "alphacep/vosk-model-small-ru": _get_russian_pre_trained_model,
}

thai_models = {
    "yfyeung/icefall-asr-gigaspeech2-th-zipformer-2024-06-20": _get_yifan_thai_pretrained_model,
}

japanese_models = {
    "reazon-research/reazonspeech-k2-v2": _get_japanese_pre_trained_model,
}

# Display label -> list of model keys for that language.
# NOTE(review): "Cantoes" below is a typo for "Cantonese". It is kept verbatim
# because the label is a lookup key that callers may reference literally;
# confirm before renaming.
language_to_models = {
    "超多种中文方言": list(chinese_dialect_models.keys()),
    "Chinese+English": list(chinese_english_mixed_models.keys()),
    "Chinese+English+Korean+Japanese+Cantoes(中英韩日粤语)": list(
        zh_en_ko_ja_yue_models.keys()
    ),
    "Chinese": list(chinese_models.keys()),
    "English": list(english_models.keys()),
    "Russian": list(russian_models.keys()),
    "Korean": list(korean_models.keys()),
    "Thai": list(thai_models.keys()),
    "Japanese": list(japanese_models.keys()),
}