csukuangfj
commited on
Commit
•
06b4245
1
Parent(s):
994c238
add chinese models
Browse files
model.py
CHANGED
@@ -192,7 +192,9 @@ def get_vad() -> sherpa_onnx.VoiceActivityDetector:
|
|
192 |
|
193 |
@lru_cache(maxsize=10)
|
194 |
def get_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
|
195 |
-
if repo_id in
|
|
|
|
|
196 |
return english_models[repo_id](repo_id)
|
197 |
elif repo_id in chinese_english_mixed_models:
|
198 |
return chinese_english_mixed_models[repo_id](repo_id)
|
@@ -202,6 +204,49 @@ def get_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
|
|
202 |
raise ValueError(f"Unsupported repo_id: {repo_id}")
|
203 |
|
204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
english_models = {
|
206 |
"whisper-tiny.en": _get_whisper_model,
|
207 |
"whisper-base.en": _get_whisper_model,
|
@@ -218,6 +263,7 @@ russian_models = {
|
|
218 |
}
|
219 |
|
220 |
language_to_models = {
|
|
|
221 |
"English": list(english_models.keys()),
|
222 |
"Chinese+English": list(chinese_english_mixed_models.keys()),
|
223 |
"Russian": list(russian_models.keys()),
|
|
|
192 |
|
193 |
@lru_cache(maxsize=10)
|
194 |
def get_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
|
195 |
+
if repo_id in chinese_models:
|
196 |
+
return chinese_models[repo_id](repo_id)
|
197 |
+
elif repo_id in english_models:
|
198 |
return english_models[repo_id](repo_id)
|
199 |
elif repo_id in chinese_english_mixed_models:
|
200 |
return chinese_english_mixed_models[repo_id](repo_id)
|
|
|
204 |
raise ValueError(f"Unsupported repo_id: {repo_id}")
|
205 |
|
206 |
|
207 |
+
def _get_wenetspeech_pre_trained_model(repo_id):
|
208 |
+
assert repo_id in (
|
209 |
+
"csukuangfj/sherpa-onnx-conformer-zh-stateless2-2023-05-23",
|
210 |
+
), repo_id
|
211 |
+
|
212 |
+
encoder_model = _get_nn_model_filename(
|
213 |
+
repo_id=repo_id,
|
214 |
+
filename="encoder-epoch-99-avg-1.onnx",
|
215 |
+
subfolder=".",
|
216 |
+
)
|
217 |
+
|
218 |
+
decoder_model = _get_nn_model_filename(
|
219 |
+
repo_id=repo_id,
|
220 |
+
filename="decoder-epoch-99-avg-1.onnx",
|
221 |
+
subfolder=".",
|
222 |
+
)
|
223 |
+
|
224 |
+
joiner_model = _get_nn_model_filename(
|
225 |
+
repo_id=repo_id,
|
226 |
+
filename="joiner-epoch-99-avg-1.onnx",
|
227 |
+
subfolder=".",
|
228 |
+
)
|
229 |
+
|
230 |
+
tokens = _get_token_filename(repo_id=repo_id, subfolder=".")
|
231 |
+
|
232 |
+
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
|
233 |
+
tokens=tokens,
|
234 |
+
encoder=encoder_model,
|
235 |
+
decoder=decoder_model,
|
236 |
+
joiner=joiner_model,
|
237 |
+
num_threads=2,
|
238 |
+
sample_rate=16000,
|
239 |
+
feature_dim=80,
|
240 |
+
decoding_method="greedy_search",
|
241 |
+
)
|
242 |
+
|
243 |
+
return recognizer
|
244 |
+
|
245 |
+
|
246 |
+
chinese_models = {
|
247 |
+
"csukuangfj/sherpa-onnx-conformer-zh-stateless2-2023-05-23": _get_wenetspeech_pre_trained_model, # noqa
|
248 |
+
}
|
249 |
+
|
250 |
english_models = {
|
251 |
"whisper-tiny.en": _get_whisper_model,
|
252 |
"whisper-base.en": _get_whisper_model,
|
|
|
263 |
}
|
264 |
|
265 |
language_to_models = {
|
266 |
+
"Chinese": list(chinese_models),
|
267 |
"English": list(english_models.keys()),
|
268 |
"Chinese+English": list(chinese_english_mixed_models.keys()),
|
269 |
"Russian": list(russian_models.keys()),
|