File size: 2,374 Bytes
c968fc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic },
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
# ```
import re
def read_lists(list_file):
    """Read a text file into a list of whitespace-stripped lines.

    Args:
        list_file: path of the text file to read; it is decoded as UTF-8.

    Returns:
        A list with one string per line of the file, in file order. Leading
        and trailing whitespace (including the line terminator) is removed
        from every line. Blank lines are NOT dropped; they appear as empty
        strings.
    """
    with open(list_file, "r", encoding="utf8") as fin:
        return [line.strip() for line in fin]
class BadSymbolFormat(Exception):
    """A non-linguistic symbol is not written as {xxx}, <xxx> or [xxx]."""


# A well-formed symbol is one bracket pair -- [...], <...> or {...} -- around
# a non-empty body that does not itself contain that kind of bracket.
_NON_LANG_SYM_PATTERN = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")


def read_non_lang_symbols(non_lang_sym_path):
    """Read and validate the non-linguistic symbols listed in a file.

    The file holds one symbol per line, for example::

        {NOISE}
        {BRK}
        ...

    Args:
        non_lang_sym_path: path of the non-linguistic symbol file, or None
            when there are no such symbols.

    Returns:
        None if ``non_lang_sym_path`` is None, otherwise the list of symbols
        in file order.

    Raises:
        BadSymbolFormat: a line of the file is not entirely of the form
            {xxx}, <xxx> or [xxx]. This includes blank lines, because every
            line is validated.
    """
    if non_lang_sym_path is None:
        return None

    syms = read_lists(non_lang_sym_path)
    for sym in syms:
        # fullmatch: the whole line must be the symbol, with no extra text.
        if _NON_LANG_SYM_PATTERN.fullmatch(sym) is None:
            raise BadSymbolFormat(
                "Non-linguistic symbols should be "
                "formatted in {xxx}/<xxx>/[xxx], consider"
                " modify '%s' to meet the requirment. "
                "More details can be found in discussions here : "
                "https://github.com/wenet-e2e/wenet/pull/819" % (sym)
            )
    return syms
def read_symbol_table(symbol_table_file):
    """Load a symbol table file into a ``{symbol: id}`` dictionary.

    Every line of the file must hold exactly two whitespace-separated
    fields: the symbol followed by its integer id.

    Args:
        symbol_table_file: path of the symbol table file; it is decoded as
            UTF-8.

    Returns:
        A dict mapping each symbol (str) to its id (int). If a symbol occurs
        on several lines, the last occurrence wins.

    Raises:
        AssertionError: a line does not consist of exactly two fields (blank
            lines included).
        ValueError: the second field of a line is not an integer literal.
    """
    symbol_table = {}
    with open(symbol_table_file, "r", encoding="utf8") as fin:
        for line_no, line in enumerate(fin, start=1):
            fields = line.split()
            if len(fields) != 2:
                # Raised explicitly instead of via `assert` so the check is
                # not stripped under `python -O`; the exception type is kept
                # as AssertionError for backward compatibility.
                raise AssertionError(
                    "%s:%d: expected '<symbol> <id>', got %r"
                    % (symbol_table_file, line_no, line)
                )
            symbol, index = fields
            symbol_table[symbol] = int(index)
    return symbol_table
|