happyme531 commited on
Commit
8d28ba8
·
verified ·
1 Parent(s): f806722

Upload 9 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ output.wav filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,131 @@
1
- ---
2
- license: agpl-3.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: agpl-3.0
3
+ language:
4
+ - en
5
+ - zh
6
+ - ja
7
+ - ko
8
+ base_model: lovemefan/SenseVoice-onnx
9
+ tags:
10
+ - rknn
11
+ ---
12
+
13
+ # SenseVoiceSmall-RKNN2
14
+
15
+ ### (English README see below)
16
+
17
+ SenseVoice是具有音频理解能力的音频基础模型, 包括语音识别(ASR)、语种识别(LID)、语音情感识别(SER)和声学事件分类(AEC)或声学事件检测(AED)。
18
+
19
+ 当前SenseVoice-small支持中、粤、英、日、韩语的多语言语音识别,情感识别和事件检测能力,具有极低的推理延迟。
20
+
21
+ - 推理速度(RKNN2):RK3588上单核NPU推理速度约20倍 (每秒识别20秒的音频), 比官方rknn-model-zoo中提供的whisper约快6倍.
22
+ - 内存占用(RKNN2):约1.1GB
23
+
24
+ ## 使用方法
25
+
26
+ 1. 克隆项目到本地
27
+
28
+ 2. 安装依赖
29
+
30
+ ```bash
31
+ pip install kaldi_native_fbank onnxruntime sentencepiece soundfile pyyaml numpy<2
32
+ ```
33
+
34
+ 你还需要手动安装rknn-toolkit2-lite2.
35
+
36
+ 3. 运行
37
+
38
+ ```bash
39
+ python ./sensevoice_rknn.py --audio_file output.wav
40
+ ```
41
+
42
+ 如果使用自己的音频文件测试发现识别不正常,你可能需要提前将它转换为16kHz, 16bit, 单声道的wav格式。
43
+
44
+ ```bash
45
+ ffmpeg -i input.mp3 -f wav -acodec pcm_s16le -ac 1 -ar 16000 output.wav
46
+ ```
47
+
48
+ ## RKNN模型转换
49
+
50
+ 你需要提前安装rknn-toolkit2 v2.1.0或更高版本.
51
+
52
+ 1. 下载或转换onnx模型
53
+
54
+ 可以从 https://huggingface.co/lovemefan/SenseVoice-onnx 下载到onnx模型.
55
+ 应该也可以根据 https://github.com/FunAudioLLM/SenseVoice 中的文档从Pytorch模型转换得到onnx模型.
56
+
57
+ 模型文件应该命名为'sense-voice-encoder.onnx', 放在转换脚本所在目录.
58
+
59
+ 2. 转换为rknn模型
60
+ ```bash
61
+ python convert_rknn.py
62
+ ```
63
+
64
+ ## 已知问题
65
+
66
+ - RKNN2使用fp16推理时可能会出现溢出,导致结果为inf,可以尝试修改输入数据的缩放比例来解决.
67
+ 在`sensevoice_rknn.py`中将`SPEECH_SCALE`设置为更小的值.
68
+
69
+ ## 参考
70
+ - [FunAudioLLM/SenseVoiceSmall](https://huggingface.co/FunAudioLLM/SenseVoiceSmall)
71
+ - [lovemefan/SenseVoice-python](https://github.com/lovemefan/SenseVoice-python)
72
+
73
+ # English README
74
+
75
+ # SenseVoiceSmall-RKNN2
76
+
77
+ SenseVoice is an audio foundation model with audio understanding capabilities, including Automatic Speech Recognition (ASR), Language Identification (LID), Speech Emotion Recognition (SER), and Acoustic Event Classification (AEC) or Acoustic Event Detection (AED).
78
+
79
+ Currently, SenseVoice-small supports multilingual speech recognition, emotion recognition, and event detection for Chinese, Cantonese, English, Japanese, and Korean, with extremely low inference latency.
80
+
81
+ - Inference speed (RKNN2): About 20x real-time on a single NPU core of RK3588 (processing 20 seconds of audio per second), approximately 6 times faster than the official whisper model provided in the rknn-model-zoo.
82
+ - Memory usage (RKNN2): About 1.1GB
83
+
84
+ ## Usage
85
+
86
+ 1. Clone the project to your local machine
87
+
88
+ 2. Install dependencies
89
+
90
+ ```bash
91
+ pip install kaldi_native_fbank onnxruntime sentencepiece soundfile pyyaml numpy<2
92
+ ```
93
+
94
+ You also need to manually install rknn-toolkit2-lite2.
95
+
96
+ 3. Run
97
+
98
+ ```bash
99
+ python ./sensevoice_rknn.py --audio_file output.wav
100
+ ```
101
+
102
+ If you find that recognition is not working correctly when testing with your own audio files, you may need to convert them to 16kHz, 16-bit, mono WAV format in advance.
103
+
104
+ ```bash
105
+ ffmpeg -i input.mp3 -f wav -acodec pcm_s16le -ac 1 -ar 16000 output.wav
106
+ ```
107
+
108
+ ## RKNN Model Conversion
109
+
110
+ You need to install rknn-toolkit2 v2.1.0 or higher in advance.
111
+
112
+ 1. Download or convert the ONNX model
113
+
114
+ You can download the ONNX model from https://huggingface.co/lovemefan/SenseVoice-onnx.
115
+ It should also be possible to convert from a PyTorch model to an ONNX model according to the documentation at https://github.com/FunAudioLLM/SenseVoice.
116
+
117
+ The model file should be named 'sense-voice-encoder.onnx' and placed in the same directory as the conversion script.
118
+
119
+ 2. Convert to RKNN model
120
+ ```bash
121
+ python convert_rknn.py
122
+ ```
123
+
124
+ ## Known Issues
125
+
126
+ - When using fp16 inference with RKNN2, overflow may occur, resulting in inf values. You can try modifying the scaling ratio of the input data to resolve this.
127
+ Set `SPEECH_SCALE` to a smaller value in `sensevoice_rknn.py`.
128
+
129
+ ## References
130
+ - [FunAudioLLM/SenseVoiceSmall](https://huggingface.co/FunAudioLLM/SenseVoiceSmall)
131
+ - [lovemefan/SenseVoice-python](https://github.com/lovemefan/SenseVoice-python)
chn_jpn_yue_eng_ko_spectok.bpe.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa87f86064c3730d799ddf7af3c04659151102cba548bce325cf06ba4da4e6a8
3
+ size 377341
convert_rknn.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import os
5
+ from rknn.api import RKNN
6
+ from math import exp
7
+ from sys import exit
8
+ import argparse
9
+ import onnxscript
10
+ from onnxscript.rewriter import pattern
11
+ import onnx.numpy_helper as onh
12
+ import numpy as np
13
+ import onnx
14
+ import onnxruntime as ort
15
+ from rknn.utils import onnx_edit
16
+
17
+ os.chdir(os.path.dirname(os.path.abspath(__file__)))
18
+
19
+ speech_length = 171
20
+
21
+ def convert_encoder():
22
+ rknn = RKNN(verbose=True)
23
+
24
+ ONNX_MODEL=f"sense-voice-encoder.onnx"
25
+ RKNN_MODEL=ONNX_MODEL.replace(".onnx",".rknn")
26
+ DATASET="dataset.txt"
27
+ QUANTIZE=False
28
+
29
+ #开局先给我来个大惊喜,rknn做第一步常量折叠的时候就会在这个子图里报错,所以要单独拿出来先跑一遍
30
+ #然后把这个子图的输出结果保存下来喂给rknn
31
+ onnx.utils.extract_model(ONNX_MODEL, "extract_model.onnx", ['speech_lengths'], ['/make_pad_mask/Cast_2_output_0'])
32
+ sess = ort.InferenceSession("extract_model.onnx", providers=['CPUExecutionProvider'])
33
+ extract_result = sess.run(None, {"speech_lengths": np.array([speech_length], dtype=np.int64)})[0]
34
+
35
+ # 删掉模型最后的多余transpose, 速度从365ms提升到350ms
36
+ ret = onnx_edit(model = ONNX_MODEL,
37
+ export_path = ONNX_MODEL.replace(".onnx", "_edited.onnx"),
38
+ # # 1, len, 25055 -> 1, 25055, 1, len # 这个是坏的, 我真服了,
39
+ # outputs_transform = {'encoder_out': 'a,b,c->a,c,1,b'},
40
+ outputs_transform = {'encoder_out': 'a,b,c->a,c,b'},
41
+ )
42
+ ONNX_MODEL = ONNX_MODEL.replace(".onnx", "_edited.onnx")
43
+
44
+ # pre-process config
45
+ print('--> Config model')
46
+ rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3)
47
+ print('done')
48
+
49
+ # Load ONNX model
50
+ print("--> Loading model")
51
+ ret = rknn.load_onnx(
52
+ model=ONNX_MODEL,
53
+ inputs=["speech", "/make_pad_mask/Cast_2_output_0"],
54
+ input_size_list=[[1, speech_length, 560], [extract_result.shape[0], extract_result.shape[1]]],
55
+ input_initial_val=[None, extract_result],
56
+ # outputs=["output"]
57
+ )
58
+
59
+ if ret != 0:
60
+ print('Load model failed!')
61
+ exit(ret)
62
+ print('done')
63
+
64
+ # Build model
65
+ print('--> Building model')
66
+ ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
67
+ if ret != 0:
68
+ print('Build model failed!')
69
+ exit(ret)
70
+ print('done')
71
+
72
+ # export
73
+ print('--> Export RKNN model')
74
+ ret = rknn.export_rknn(RKNN_MODEL)
75
+ if ret != 0:
76
+ print('Export RKNN model failed!')
77
+ exit(ret)
78
+ print('done')
79
+
80
+ # usage: python convert_rknn.py encoder|all
81
+
82
+ if __name__ == "__main__":
83
+ parser = argparse.ArgumentParser()
84
+ parser.add_argument("model", type=str, help="model to convert", choices=["encoder", "all"], nargs='?')
85
+ args = parser.parse_args()
86
+ if args.model is None:
87
+ args.model = "all"
88
+
89
+ if args.model == "encoder":
90
+ convert_encoder()
91
+ elif args.model == "all":
92
+ convert_encoder()
93
+ else:
94
+ print(f"Unknown model: {args.model}")
95
+ exit(1)
embedding.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83cf1fc5680fdf6d7edb411be5ce351cad4eca03b29a5bf5050aa19dfcc12267
3
+ size 35968
fsmn-am.mvn ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <Nnet>
2
+ <Splice> 400 400
3
+ [ 0 ]
4
+ <AddShift> 400 400
5
+ <LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
6
+ <Rescale> 400 400
7
+ <LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
8
+ </Nnet>
fsmn-config.yaml ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ WavFrontend:
2
+ frontend_conf:
3
+ fs: 16000
4
+ window: hamming
5
+ n_mels: 80
6
+ frame_length: 25
7
+ frame_shift: 10
8
+ dither: 0.0
9
+ lfr_m: 5
10
+ lfr_n: 1
11
+
12
+ FSMN:
13
+ use_cuda: False
14
+ CUDAExecutionProvider:
15
+ device_id: 0
16
+ arena_extend_strategy: kNextPowerOfTwo
17
+ cudnn_conv_algo_search: EXHAUSTIVE
18
+ do_copy_in_default_stream: true
19
+ encoder_conf:
20
+ input_dim: 400
21
+ input_affine_dim: 140
22
+ fsmn_layers: 4
23
+ linear_dim: 250
24
+ proj_dim: 128
25
+ lorder: 20
26
+ rorder: 0
27
+ lstride: 1
28
+ rstride: 0
29
+ output_affine_dim: 140
30
+ output_dim: 248
31
+
32
+ vadPostArgs:
33
+ sample_rate: 16000
34
+ detect_mode: 1
35
+ snr_mode: 0
36
+ max_end_silence_time: 800
37
+ max_start_silence_time: 3000
38
+ do_start_point_detection: True
39
+ do_end_point_detection: True
40
+ window_size_ms: 200
41
+ sil_to_speech_time_thres: 150
42
+ speech_to_sil_time_thres: 150
43
+ speech_2_noise_ratio: 1.0
44
+ do_extend: 1
45
+ lookback_time_start_point: 200
46
+ lookahead_time_end_point: 100
47
+ max_single_segment_time: 10000
48
+ snr_thres: -100.0
49
+ noise_frame_num_used_for_snr: 100
50
+ decibel_thres: -100.0
51
+ speech_noise_thres: 0.6
52
+ fe_prior_thres: 0.0001
53
+ silence_pdf_num: 1
54
+ sil_pdf_ids: [ 0 ]
55
+ speech_noise_thresh_low: -0.1
56
+ speech_noise_thresh_high: 0.3
57
+ output_frame_probs: False
58
+ frame_in_ms: 10
59
+ frame_length_ms: 25
fsmnvad-offline.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4bbd68b11519e916b6871ff6f8df15e2100936b256be9cb104cd63fb7c859965
3
+ size 1725472
output.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6f02d2c58b9a8a294a306ccb60bdf667587d74984915a3ec87a6de5e04bb020
3
+ size 1289994
sensevoice_rknn.py ADDED
@@ -0,0 +1,1402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # File: onnx/fsmn_vad_ort_session.py
2
+ # ```py
3
+
4
+ # -*- coding:utf-8 -*-
5
+ # @FileName :fsmn_vad_ort_session.py.py
6
+ # @Time :2024/8/31 16:45
7
+ # @Author :lovemefan
8
+ # @Email :lovemefan@outlook.com
9
+
10
+ import argparse
11
+ import logging
12
+ import math
13
+ import os
14
+ import time
15
+ import warnings
16
+ from enum import Enum
17
+ from pathlib import Path
18
+ from typing import Any, Dict, List, Tuple, Union
19
+
20
+ import kaldi_native_fbank as knf
21
+ import numpy as np
22
+ import sentencepiece as spm
23
+ import soundfile as sf
24
+ import yaml
25
+ from onnxruntime import (GraphOptimizationLevel, InferenceSession,
26
+ SessionOptions, get_available_providers, get_device)
27
+ from rknnlite.api.rknn_lite import RKNNLite
28
+
29
+ RKNN_INPUT_LEN = 171
30
+
31
+ SPEECH_SCALE = 1/2 # 因为是fp16推理,如果中间结果太大可能会溢出变inf,所以需要缩放一下
32
+
33
+ class VadOrtInferRuntimeSession:
34
+ def __init__(self, config, root_dir: Path):
35
+ sess_opt = SessionOptions()
36
+ sess_opt.log_severity_level = 4
37
+ sess_opt.enable_cpu_mem_arena = False
38
+ sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
39
+
40
+ cuda_ep = "CUDAExecutionProvider"
41
+ cpu_ep = "CPUExecutionProvider"
42
+ cpu_provider_options = {
43
+ "arena_extend_strategy": "kSameAsRequested",
44
+ }
45
+
46
+ EP_list = []
47
+ if (
48
+ config["use_cuda"]
49
+ and get_device() == "GPU"
50
+ and cuda_ep in get_available_providers()
51
+ ):
52
+ EP_list = [(cuda_ep, config[cuda_ep])]
53
+ EP_list.append((cpu_ep, cpu_provider_options))
54
+
55
+ config["model_path"] = root_dir / str(config["model_path"])
56
+ self._verify_model(config["model_path"])
57
+ logging.info(f"Loading onnx model at {str(config['model_path'])}")
58
+ self.session = InferenceSession(
59
+ str(config["model_path"]), sess_options=sess_opt, providers=EP_list
60
+ )
61
+
62
+ if config["use_cuda"] and cuda_ep not in self.session.get_providers():
63
+ logging.warning(
64
+ f"{cuda_ep} is not available for current env, "
65
+ f"the inference part is automatically shifted to be "
66
+ f"executed under {cpu_ep}.\n "
67
+ "Please ensure the installed onnxruntime-gpu version"
68
+ " matches your cuda and cudnn version, "
69
+ "you can check their relations from the offical web site: "
70
+ "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
71
+ RuntimeWarning,
72
+ )
73
+
74
+ def __call__(
75
+ self, input_content: List[Union[np.ndarray, np.ndarray]]
76
+ ) -> np.ndarray:
77
+ if isinstance(input_content, list):
78
+ input_dict = {
79
+ "speech": input_content[0],
80
+ "in_cache0": input_content[1],
81
+ "in_cache1": input_content[2],
82
+ "in_cache2": input_content[3],
83
+ "in_cache3": input_content[4],
84
+ }
85
+ else:
86
+ input_dict = {"speech": input_content}
87
+
88
+ return self.session.run(None, input_dict)
89
+
90
+ def get_input_names(
91
+ self,
92
+ ):
93
+ return [v.name for v in self.session.get_inputs()]
94
+
95
+ def get_output_names(
96
+ self,
97
+ ):
98
+ return [v.name for v in self.session.get_outputs()]
99
+
100
+ def get_character_list(self, key: str = "character"):
101
+ return self.meta_dict[key].splitlines()
102
+
103
+ def have_key(self, key: str = "character") -> bool:
104
+ self.meta_dict = self.session.get_modelmeta().custom_metadata_map
105
+ if key in self.meta_dict.keys():
106
+ return True
107
+ return False
108
+
109
+ @staticmethod
110
+ def _verify_model(model_path):
111
+ model_path = Path(model_path)
112
+ if not model_path.exists():
113
+ raise FileNotFoundError(f"{model_path} does not exists.")
114
+ if not model_path.is_file():
115
+ raise FileExistsError(f"{model_path} is not a file.")
116
+
117
+ # ```
118
+
119
+ # File: onnx/sense_voice_ort_session.py
120
+ # ```py
121
+ # -*- coding:utf-8 -*-
122
+ # @FileName :sense_voice_onnxruntime.py
123
+ # @Time :2024/7/17 20:53
124
+ # @Author :lovemefan
125
+ # @Email :lovemefan@outlook.com
126
+
127
+
128
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
129
+ logging.basicConfig(format=formatter, level=logging.INFO)
130
+
131
+
132
+ class OrtInferRuntimeSession:
133
+ def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
134
+ device_id = str(device_id)
135
+ sess_opt = SessionOptions()
136
+ sess_opt.intra_op_num_threads = intra_op_num_threads
137
+ sess_opt.log_severity_level = 4
138
+ sess_opt.enable_cpu_mem_arena = False
139
+ sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
140
+
141
+ cuda_ep = "CUDAExecutionProvider"
142
+ cuda_provider_options = {
143
+ "device_id": device_id,
144
+ "arena_extend_strategy": "kNextPowerOfTwo",
145
+ "cudnn_conv_algo_search": "EXHAUSTIVE",
146
+ "do_copy_in_default_stream": "true",
147
+ }
148
+ cpu_ep = "CPUExecutionProvider"
149
+ cpu_provider_options = {
150
+ "arena_extend_strategy": "kSameAsRequested",
151
+ }
152
+
153
+ EP_list = []
154
+ if (
155
+ device_id != "-1"
156
+ and get_device() == "GPU"
157
+ and cuda_ep in get_available_providers()
158
+ ):
159
+ EP_list = [(cuda_ep, cuda_provider_options)]
160
+ EP_list.append((cpu_ep, cpu_provider_options))
161
+
162
+ self._verify_model(model_file)
163
+
164
+ self.session = InferenceSession(
165
+ model_file, sess_options=sess_opt, providers=EP_list
166
+ )
167
+
168
+ # delete binary of model file to save memory
169
+ del model_file
170
+
171
+ if device_id != "-1" and cuda_ep not in self.session.get_providers():
172
+ warnings.warn(
173
+ f"{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n"
174
+ "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
175
+ "you can check their relations from the offical web site: "
176
+ "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
177
+ RuntimeWarning,
178
+ )
179
+
180
+ def __call__(self, input_content) -> np.ndarray:
181
+ input_dict = dict(zip(self.get_input_names(), input_content))
182
+ try:
183
+ result = self.session.run(self.get_output_names(), input_dict)
184
+ return result
185
+ except Exception as e:
186
+ print(e)
187
+ raise RuntimeError(f"ONNXRuntime inferece failed. ") from e
188
+
189
+ def get_input_names(
190
+ self,
191
+ ):
192
+ return [v.name for v in self.session.get_inputs()]
193
+
194
+ def get_output_names(
195
+ self,
196
+ ):
197
+ return [v.name for v in self.session.get_outputs()]
198
+
199
+ def get_character_list(self, key: str = "character"):
200
+ return self.meta_dict[key].splitlines()
201
+
202
+ def have_key(self, key: str = "character") -> bool:
203
+ self.meta_dict = self.session.get_modelmeta().custom_metadata_map
204
+ if key in self.meta_dict.keys():
205
+ return True
206
+ return False
207
+
208
+ @staticmethod
209
+ def _verify_model(model_path):
210
+ model_path = Path(model_path)
211
+ if not model_path.exists():
212
+ raise FileNotFoundError(f"{model_path} does not exists.")
213
+ if not model_path.is_file():
214
+ raise FileExistsError(f"{model_path} is not a file.")
215
+
216
+
217
+ def log_softmax(x: np.ndarray) -> np.ndarray:
218
+ # Subtract the maximum value in each row for numerical stability
219
+ x_max = np.max(x, axis=-1, keepdims=True)
220
+ # Calculate the softmax of x
221
+ softmax = np.exp(x - x_max)
222
+ softmax_sum = np.sum(softmax, axis=-1, keepdims=True)
223
+ softmax = softmax / softmax_sum
224
+ # Calculate the log of the softmax values
225
+ return np.log(softmax)
226
+
227
+
228
+ class SenseVoiceInferenceSession:
229
+ def __init__(
230
+ self,
231
+ embedding_model_file,
232
+ encoder_model_file,
233
+ bpe_model_file,
234
+ device_id=-1,
235
+ intra_op_num_threads=4,
236
+ ):
237
+ logging.info(f"Loading model from {embedding_model_file}")
238
+
239
+ self.embedding = np.load(embedding_model_file)
240
+ logging.info(f"Loading model {encoder_model_file}")
241
+ start = time.time()
242
+ self.encoder = RKNNLite(verbose=False)
243
+ self.encoder.load_rknn(encoder_model_file)
244
+ self.encoder.init_runtime()
245
+
246
+ logging.info(
247
+ f"Loading {encoder_model_file} takes {time.time() - start:.2f} seconds"
248
+ )
249
+ self.blank_id = 0
250
+ self.sp = spm.SentencePieceProcessor()
251
+ self.sp.load(bpe_model_file)
252
+
253
+ def __call__(self, speech, language: int, use_itn: bool) -> np.ndarray:
254
+ language_query = self.embedding[[[language]]]
255
+
256
+ # 14 means with itn, 15 means without itn
257
+ text_norm_query = self.embedding[[[14 if use_itn else 15]]]
258
+ event_emo_query = self.embedding[[[1, 2]]]
259
+
260
+ # scale the speech
261
+ speech = speech * SPEECH_SCALE
262
+
263
+ input_content = np.concatenate(
264
+ [
265
+ language_query,
266
+ event_emo_query,
267
+ text_norm_query,
268
+ speech,
269
+ ],
270
+ axis=1,
271
+ ).astype(np.float32)
272
+ print(input_content.shape)
273
+ # pad [1, len, ...] to [1, RKNN_INPUT_LEN, ... ]
274
+ input_content = np.pad(input_content, ((0, 0), (0, RKNN_INPUT_LEN - input_content.shape[1]), (0, 0)))
275
+ print("padded shape:", input_content.shape)
276
+ start_time = time.time()
277
+ encoder_out = self.encoder.inference(inputs=[input_content])[0]
278
+ end_time = time.time()
279
+ print(f"encoder inference time: {end_time - start_time:.2f} seconds")
280
+ # print(encoder_out)
281
+ def unique_consecutive(arr):
282
+ if len(arr) == 0:
283
+ return arr
284
+ # Create a boolean mask where True indicates the element is different from the previous one
285
+ mask = np.append([True], arr[1:] != arr[:-1])
286
+ out = arr[mask]
287
+ out = out[out != self.blank_id]
288
+ return out.tolist()
289
+
290
+ #现在shape变成了1, n_vocab, n_seq. 这里axis需要改一下
291
+ # hypos = unique_consecutive(encoder_out[0].argmax(axis=-1))
292
+ hypos = unique_consecutive(encoder_out[0].argmax(axis=0))
293
+ text = self.sp.DecodeIds(hypos)
294
+ return text
295
+
296
+ # ```
297
+
298
+ # File: utils/frontend.py
299
+ # ```py
300
+ # -*- coding:utf-8 -*-
301
+ # @FileName :frontend.py
302
+ # @Time :2024/7/18 09:39
303
+ # @Author :lovemefan
304
+ # @Email :lovemefan@outlook.com
305
+
306
+ class WavFrontend:
307
+ """Conventional frontend structure for ASR."""
308
+
309
+ def __init__(
310
+ self,
311
+ cmvn_file: str = None,
312
+ fs: int = 16000,
313
+ window: str = "hamming",
314
+ n_mels: int = 80,
315
+ frame_length: int = 25,
316
+ frame_shift: int = 10,
317
+ lfr_m: int = 7,
318
+ lfr_n: int = 6,
319
+ dither: float = 0,
320
+ **kwargs,
321
+ ) -> None:
322
+ opts = knf.FbankOptions()
323
+ opts.frame_opts.samp_freq = fs
324
+ opts.frame_opts.dither = dither
325
+ opts.frame_opts.window_type = window
326
+ opts.frame_opts.frame_shift_ms = float(frame_shift)
327
+ opts.frame_opts.frame_length_ms = float(frame_length)
328
+ opts.mel_opts.num_bins = n_mels
329
+ opts.energy_floor = 0
330
+ opts.frame_opts.snip_edges = True
331
+ opts.mel_opts.debug_mel = False
332
+ self.opts = opts
333
+
334
+ self.lfr_m = lfr_m
335
+ self.lfr_n = lfr_n
336
+ self.cmvn_file = cmvn_file
337
+
338
+ if self.cmvn_file:
339
+ self.cmvn = self.load_cmvn()
340
+ self.fbank_fn = None
341
+ self.fbank_beg_idx = 0
342
+ self.reset_status()
343
+
344
+ def reset_status(self):
345
+ self.fbank_fn = knf.OnlineFbank(self.opts)
346
+ self.fbank_beg_idx = 0
347
+
348
+ def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
349
+ waveform = waveform * (1 << 15)
350
+ self.fbank_fn = knf.OnlineFbank(self.opts)
351
+ self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
352
+ frames = self.fbank_fn.num_frames_ready
353
+ mat = np.empty([frames, self.opts.mel_opts.num_bins])
354
+ for i in range(frames):
355
+ mat[i, :] = self.fbank_fn.get_frame(i)
356
+ feat = mat.astype(np.float32)
357
+ feat_len = np.array(mat.shape[0]).astype(np.int32)
358
+ return feat, feat_len
359
+
360
+ def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
361
+ if self.lfr_m != 1 or self.lfr_n != 1:
362
+ feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
363
+
364
+ if self.cmvn_file:
365
+ feat = self.apply_cmvn(feat)
366
+
367
+ feat_len = np.array(feat.shape[0]).astype(np.int32)
368
+ return feat, feat_len
369
+
370
+ def load_audio(self, filename: str) -> Tuple[np.ndarray, int]:
371
+ data, sample_rate = sf.read(
372
+ filename,
373
+ always_2d=True,
374
+ dtype="float32",
375
+ )
376
+ assert (
377
+ sample_rate == 16000
378
+ ), f"Only 16000 Hz is supported, but got {sample_rate}Hz"
379
+ self.sample_rate = sample_rate
380
+ data = data[:, 0] # use only the first channel
381
+ samples = np.ascontiguousarray(data)
382
+
383
+ return samples, sample_rate
384
+
385
+ @staticmethod
386
+ def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
387
+ LFR_inputs = []
388
+
389
+ T = inputs.shape[0]
390
+ T_lfr = int(np.ceil(T / lfr_n))
391
+ left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
392
+ inputs = np.vstack((left_padding, inputs))
393
+ T = T + (lfr_m - 1) // 2
394
+ for i in range(T_lfr):
395
+ if lfr_m <= T - i * lfr_n:
396
+ LFR_inputs.append(
397
+ (inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)
398
+ )
399
+ else:
400
+ # process last LFR frame
401
+ num_padding = lfr_m - (T - i * lfr_n)
402
+ frame = inputs[i * lfr_n :].reshape(-1)
403
+ for _ in range(num_padding):
404
+ frame = np.hstack((frame, inputs[-1]))
405
+
406
+ LFR_inputs.append(frame)
407
+ LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
408
+ return LFR_outputs
409
+
410
+ def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
411
+ """
412
+ Apply CMVN with mvn data
413
+ """
414
+ frame, dim = inputs.shape
415
+ means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
416
+ vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
417
+ inputs = (inputs + means) * vars
418
+ return inputs
419
+
420
+ def get_features(self, inputs: Union[str, np.ndarray]) -> Tuple[np.ndarray, int]:
421
+ if isinstance(inputs, str):
422
+ inputs, _ = self.load_audio(inputs)
423
+
424
+ fbank, _ = self.fbank(inputs)
425
+ feats = self.apply_cmvn(self.apply_lfr(fbank, self.lfr_m, self.lfr_n))
426
+ return feats
427
+
428
+ def load_cmvn(
429
+ self,
430
+ ) -> np.ndarray:
431
+ with open(self.cmvn_file, "r", encoding="utf-8") as f:
432
+ lines = f.readlines()
433
+
434
+ means_list = []
435
+ vars_list = []
436
+ for i in range(len(lines)):
437
+ line_item = lines[i].split()
438
+ if line_item[0] == "<AddShift>":
439
+ line_item = lines[i + 1].split()
440
+ if line_item[0] == "<LearnRateCoef>":
441
+ add_shift_line = line_item[3 : (len(line_item) - 1)]
442
+ means_list = list(add_shift_line)
443
+ continue
444
+ elif line_item[0] == "<Rescale>":
445
+ line_item = lines[i + 1].split()
446
+ if line_item[0] == "<LearnRateCoef>":
447
+ rescale_line = line_item[3 : (len(line_item) - 1)]
448
+ vars_list = list(rescale_line)
449
+ continue
450
+
451
+ means = np.array(means_list).astype(np.float64)
452
+ vars = np.array(vars_list).astype(np.float64)
453
+ cmvn = np.array([means, vars])
454
+ return cmvn
455
+
456
+ # ```
457
+
458
+ # File: utils/fsmn_vad.py
459
+ # ```py
460
+ # -*- coding:utf-8 -*-
461
+ # @FileName :fsmn_vad.py
462
+ # @Time :2024/8/31 16:50
463
+ # @Author :lovemefan
464
+ # @Email :lovemefan@outlook.com
465
+
466
+
467
+
468
+ def read_yaml(yaml_path: Union[str, Path]) -> Dict:
469
+ if not Path(yaml_path).exists():
470
+ raise FileExistsError(f"The {yaml_path} does not exist.")
471
+
472
+ with open(str(yaml_path), "rb") as f:
473
+ data = yaml.load(f, Loader=yaml.Loader)
474
+ return data
475
+
476
+
477
+ class VadStateMachine(Enum):
478
+ kVadInStateStartPointNotDetected = 1
479
+ kVadInStateInSpeechSegment = 2
480
+ kVadInStateEndPointDetected = 3
481
+
482
+
483
+ class FrameState(Enum):
484
+ kFrameStateInvalid = -1
485
+ kFrameStateSpeech = 1
486
+ kFrameStateSil = 0
487
+
488
+
489
+ # final voice/unvoice state per frame
490
+ class AudioChangeState(Enum):
491
+ kChangeStateSpeech2Speech = 0
492
+ kChangeStateSpeech2Sil = 1
493
+ kChangeStateSil2Sil = 2
494
+ kChangeStateSil2Speech = 3
495
+ kChangeStateNoBegin = 4
496
+ kChangeStateInvalid = 5
497
+
498
+
499
+ class VadDetectMode(Enum):
500
+ kVadSingleUtteranceDetectMode = 0
501
+ kVadMutipleUtteranceDetectMode = 1
502
+
503
+
504
+ class VADXOptions:
505
+ def __init__(
506
+ self,
507
+ sample_rate: int = 16000,
508
+ detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
509
+ snr_mode: int = 0,
510
+ max_end_silence_time: int = 800,
511
+ max_start_silence_time: int = 3000,
512
+ do_start_point_detection: bool = True,
513
+ do_end_point_detection: bool = True,
514
+ window_size_ms: int = 200,
515
+ sil_to_speech_time_thres: int = 150,
516
+ speech_to_sil_time_thres: int = 150,
517
+ speech_2_noise_ratio: float = 1.0,
518
+ do_extend: int = 1,
519
+ lookback_time_start_point: int = 200,
520
+ lookahead_time_end_point: int = 100,
521
+ max_single_segment_time: int = 60000,
522
+ nn_eval_block_size: int = 8,
523
+ dcd_block_size: int = 4,
524
+ snr_thres: int = -100.0,
525
+ noise_frame_num_used_for_snr: int = 100,
526
+ decibel_thres: int = -100.0,
527
+ speech_noise_thres: float = 0.6,
528
+ fe_prior_thres: float = 1e-4,
529
+ silence_pdf_num: int = 1,
530
+ sil_pdf_ids: List[int] = [0],
531
+ speech_noise_thresh_low: float = -0.1,
532
+ speech_noise_thresh_high: float = 0.3,
533
+ output_frame_probs: bool = False,
534
+ frame_in_ms: int = 10,
535
+ frame_length_ms: int = 25,
536
+ ):
537
+ self.sample_rate = sample_rate
538
+ self.detect_mode = detect_mode
539
+ self.snr_mode = snr_mode
540
+ self.max_end_silence_time = max_end_silence_time
541
+ self.max_start_silence_time = max_start_silence_time
542
+ self.do_start_point_detection = do_start_point_detection
543
+ self.do_end_point_detection = do_end_point_detection
544
+ self.window_size_ms = window_size_ms
545
+ self.sil_to_speech_time_thres = sil_to_speech_time_thres
546
+ self.speech_to_sil_time_thres = speech_to_sil_time_thres
547
+ self.speech_2_noise_ratio = speech_2_noise_ratio
548
+ self.do_extend = do_extend
549
+ self.lookback_time_start_point = lookback_time_start_point
550
+ self.lookahead_time_end_point = lookahead_time_end_point
551
+ self.max_single_segment_time = max_single_segment_time
552
+ self.nn_eval_block_size = nn_eval_block_size
553
+ self.dcd_block_size = dcd_block_size
554
+ self.snr_thres = snr_thres
555
+ self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
556
+ self.decibel_thres = decibel_thres
557
+ self.speech_noise_thres = speech_noise_thres
558
+ self.fe_prior_thres = fe_prior_thres
559
+ self.silence_pdf_num = silence_pdf_num
560
+ self.sil_pdf_ids = sil_pdf_ids
561
+ self.speech_noise_thresh_low = speech_noise_thresh_low
562
+ self.speech_noise_thresh_high = speech_noise_thresh_high
563
+ self.output_frame_probs = output_frame_probs
564
+ self.frame_in_ms = frame_in_ms
565
+ self.frame_length_ms = frame_length_ms
566
+
567
+
568
+ class E2EVadSpeechBufWithDoa(object):
569
+ def __init__(self):
570
+ self.start_ms = 0
571
+ self.end_ms = 0
572
+ self.buffer = []
573
+ self.contain_seg_start_point = False
574
+ self.contain_seg_end_point = False
575
+ self.doa = 0
576
+
577
+ def reset(self):
578
+ self.start_ms = 0
579
+ self.end_ms = 0
580
+ self.buffer = []
581
+ self.contain_seg_start_point = False
582
+ self.contain_seg_end_point = False
583
+ self.doa = 0
584
+
585
+
586
+ class E2EVadFrameProb(object):
587
+ def __init__(self):
588
+ self.noise_prob = 0.0
589
+ self.speech_prob = 0.0
590
+ self.score = 0.0
591
+ self.frame_id = 0
592
+ self.frm_state = 0
593
+
594
+
595
+ class WindowDetector(object):
596
+ def __init__(
597
+ self,
598
+ window_size_ms: int,
599
+ sil_to_speech_time: int,
600
+ speech_to_sil_time: int,
601
+ frame_size_ms: int,
602
+ ):
603
+ self.window_size_ms = window_size_ms
604
+ self.sil_to_speech_time = sil_to_speech_time
605
+ self.speech_to_sil_time = speech_to_sil_time
606
+ self.frame_size_ms = frame_size_ms
607
+
608
+ self.win_size_frame = int(window_size_ms / frame_size_ms)
609
+ self.win_sum = 0
610
+ self.win_state = [0] * self.win_size_frame # 初始化窗
611
+
612
+ self.cur_win_pos = 0
613
+ self.pre_frame_state = FrameState.kFrameStateSil
614
+ self.cur_frame_state = FrameState.kFrameStateSil
615
+ self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
616
+ self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
617
+
618
+ self.voice_last_frame_count = 0
619
+ self.noise_last_frame_count = 0
620
+ self.hydre_frame_count = 0
621
+
622
+ def reset(self) -> None:
623
+ self.cur_win_pos = 0
624
+ self.win_sum = 0
625
+ self.win_state = [0] * self.win_size_frame
626
+ self.pre_frame_state = FrameState.kFrameStateSil
627
+ self.cur_frame_state = FrameState.kFrameStateSil
628
+ self.voice_last_frame_count = 0
629
+ self.noise_last_frame_count = 0
630
+ self.hydre_frame_count = 0
631
+
632
+ def get_win_size(self) -> int:
633
+ return int(self.win_size_frame)
634
+
635
+ def detect_one_frame(
636
+ self, frameState: FrameState, frame_count: int
637
+ ) -> AudioChangeState:
638
+ cur_frame_state = FrameState.kFrameStateSil
639
+ if frameState == FrameState.kFrameStateSpeech:
640
+ cur_frame_state = 1
641
+ elif frameState == FrameState.kFrameStateSil:
642
+ cur_frame_state = 0
643
+ else:
644
+ return AudioChangeState.kChangeStateInvalid
645
+ self.win_sum -= self.win_state[self.cur_win_pos]
646
+ self.win_sum += cur_frame_state
647
+ self.win_state[self.cur_win_pos] = cur_frame_state
648
+ self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
649
+
650
+ if (
651
+ self.pre_frame_state == FrameState.kFrameStateSil
652
+ and self.win_sum >= self.sil_to_speech_frmcnt_thres
653
+ ):
654
+ self.pre_frame_state = FrameState.kFrameStateSpeech
655
+ return AudioChangeState.kChangeStateSil2Speech
656
+
657
+ if (
658
+ self.pre_frame_state == FrameState.kFrameStateSpeech
659
+ and self.win_sum <= self.speech_to_sil_frmcnt_thres
660
+ ):
661
+ self.pre_frame_state = FrameState.kFrameStateSil
662
+ return AudioChangeState.kChangeStateSpeech2Sil
663
+
664
+ if self.pre_frame_state == FrameState.kFrameStateSil:
665
+ return AudioChangeState.kChangeStateSil2Sil
666
+ if self.pre_frame_state == FrameState.kFrameStateSpeech:
667
+ return AudioChangeState.kChangeStateSpeech2Speech
668
+ return AudioChangeState.kChangeStateInvalid
669
+
670
+ def frame_size_ms(self) -> int:
671
+ return int(self.frame_size_ms)
672
+
673
+
674
+ class E2EVadModel:
675
+ def __init__(self, config, vad_post_args: Dict[str, Any], root_dir: Path):
676
+ super(E2EVadModel, self).__init__()
677
+ self.vad_opts = VADXOptions(**vad_post_args)
678
+ self.windows_detector = WindowDetector(
679
+ self.vad_opts.window_size_ms,
680
+ self.vad_opts.sil_to_speech_time_thres,
681
+ self.vad_opts.speech_to_sil_time_thres,
682
+ self.vad_opts.frame_in_ms,
683
+ )
684
+ self.model = VadOrtInferRuntimeSession(config, root_dir)
685
+ self.all_reset_detection()
686
+
687
+ def all_reset_detection(self):
688
+ # init variables
689
+ self.is_final = False
690
+ self.data_buf_start_frame = 0
691
+ self.frm_cnt = 0
692
+ self.latest_confirmed_speech_frame = 0
693
+ self.lastest_confirmed_silence_frame = -1
694
+ self.continous_silence_frame_count = 0
695
+ self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
696
+ self.confirmed_start_frame = -1
697
+ self.confirmed_end_frame = -1
698
+ self.number_end_time_detected = 0
699
+ self.sil_frame = 0
700
+ self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
701
+ self.noise_average_decibel = -100.0
702
+ self.pre_end_silence_detected = False
703
+ self.next_seg = True
704
+
705
+ self.output_data_buf = []
706
+ self.output_data_buf_offset = 0
707
+ self.frame_probs = []
708
+ self.max_end_sil_frame_cnt_thresh = (
709
+ self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
710
+ )
711
+ self.speech_noise_thres = self.vad_opts.speech_noise_thres
712
+ self.scores = None
713
+ self.scores_offset = 0
714
+ self.max_time_out = False
715
+ self.decibel = []
716
+ self.decibel_offset = 0
717
+ self.data_buf_size = 0
718
+ self.data_buf_all_size = 0
719
+ self.waveform = None
720
+ self.reset_detection()
721
+
722
+ def reset_detection(self):
723
+ self.continous_silence_frame_count = 0
724
+ self.latest_confirmed_speech_frame = 0
725
+ self.lastest_confirmed_silence_frame = -1
726
+ self.confirmed_start_frame = -1
727
+ self.confirmed_end_frame = -1
728
+ self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
729
+ self.windows_detector.reset()
730
+ self.sil_frame = 0
731
+ self.frame_probs = []
732
+
733
+ def compute_decibel(self) -> None:
734
+ frame_sample_length = int(
735
+ self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
736
+ )
737
+ frame_shift_length = int(
738
+ self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
739
+ )
740
+ if self.data_buf_all_size == 0:
741
+ self.data_buf_all_size = len(self.waveform[0])
742
+ self.data_buf_size = self.data_buf_all_size
743
+ else:
744
+ self.data_buf_all_size += len(self.waveform[0])
745
+
746
+ for offset in range(
747
+ 0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length
748
+ ):
749
+ self.decibel.append(
750
+ 10
751
+ * np.log10(
752
+ np.square(
753
+ self.waveform[0][offset : offset + frame_sample_length]
754
+ ).sum()
755
+ + 1e-6
756
+ )
757
+ )
758
+
759
+ def compute_scores(self, feats: np.ndarray) -> None:
760
+ scores = self.model(feats)
761
+ self.vad_opts.nn_eval_block_size = scores[0].shape[1]
762
+ self.frm_cnt += scores[0].shape[1] # count total frames
763
+ if isinstance(feats, list):
764
+ # return B * T * D
765
+ feats = feats[0]
766
+
767
+ assert (
768
+ scores[0].shape[1] == feats.shape[1]
769
+ ), "The shape between feats and scores does not match"
770
+
771
+ self.scores = scores[0] # the first calculation
772
+ self.scores_offset += self.scores.shape[1]
773
+
774
+ return scores[1:]
775
+
776
+ def pop_data_buf_till_frame(self, frame_idx: int) -> None: # need check again
777
+ while self.data_buf_start_frame < frame_idx:
778
+ if self.data_buf_size >= int(
779
+ self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
780
+ ):
781
+ self.data_buf_start_frame += 1
782
+ self.data_buf_size = (
783
+ self.data_buf_all_size
784
+ - self.data_buf_start_frame
785
+ * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
786
+ )
787
+
788
+ def pop_data_to_output_buf(
789
+ self,
790
+ start_frm: int,
791
+ frm_cnt: int,
792
+ first_frm_is_start_point: bool,
793
+ last_frm_is_end_point: bool,
794
+ end_point_is_sent_end: bool,
795
+ ) -> None:
796
+ self.pop_data_buf_till_frame(start_frm)
797
+ expected_sample_number = int(
798
+ frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
799
+ )
800
+ if last_frm_is_end_point:
801
+ extra_sample = max(
802
+ 0,
803
+ int(
804
+ self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
805
+ - self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
806
+ ),
807
+ )
808
+ expected_sample_number += int(extra_sample)
809
+ if end_point_is_sent_end:
810
+ expected_sample_number = max(expected_sample_number, self.data_buf_size)
811
+ if self.data_buf_size < expected_sample_number:
812
+ logging.error("error in calling pop data_buf\n")
813
+
814
+ if len(self.output_data_buf) == 0 or first_frm_is_start_point:
815
+ self.output_data_buf.append(E2EVadSpeechBufWithDoa())
816
+ self.output_data_buf[-1].reset()
817
+ self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
818
+ self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
819
+ self.output_data_buf[-1].doa = 0
820
+ cur_seg = self.output_data_buf[-1]
821
+ if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
822
+ logging.error("warning\n")
823
+ out_pos = len(cur_seg.buffer) # cur_seg.buff现在没做任何操作
824
+ data_to_pop = 0
825
+ if end_point_is_sent_end:
826
+ data_to_pop = expected_sample_number
827
+ else:
828
+ data_to_pop = int(
829
+ frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
830
+ )
831
+ if data_to_pop > self.data_buf_size:
832
+ logging.error("VAD data_to_pop is bigger than self.data_buf.size()!!!\n")
833
+ data_to_pop = self.data_buf_size
834
+ expected_sample_number = self.data_buf_size
835
+
836
+ cur_seg.doa = 0
837
+ for sample_cpy_out in range(0, data_to_pop):
838
+ # cur_seg.buffer[out_pos ++] = data_buf_.back();
839
+ out_pos += 1
840
+ for sample_cpy_out in range(data_to_pop, expected_sample_number):
841
+ # cur_seg.buffer[out_pos++] = data_buf_.back()
842
+ out_pos += 1
843
+ if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
844
+ logging.error("Something wrong with the VAD algorithm\n")
845
+ self.data_buf_start_frame += frm_cnt
846
+ cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
847
+ if first_frm_is_start_point:
848
+ cur_seg.contain_seg_start_point = True
849
+ if last_frm_is_end_point:
850
+ cur_seg.contain_seg_end_point = True
851
+
852
+ def on_silence_detected(self, valid_frame: int):
853
+ self.lastest_confirmed_silence_frame = valid_frame
854
+ if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
855
+ self.pop_data_buf_till_frame(valid_frame)
856
+ # silence_detected_callback_
857
+ # pass
858
+
859
+ def on_voice_detected(self, valid_frame: int) -> None:
860
+ self.latest_confirmed_speech_frame = valid_frame
861
+ self.pop_data_to_output_buf(valid_frame, 1, False, False, False)
862
+
863
+ def on_voice_start(self, start_frame: int, fake_result: bool = False) -> None:
864
+ if self.vad_opts.do_start_point_detection:
865
+ pass
866
+ if self.confirmed_start_frame != -1:
867
+ logging.error("not reset vad properly\n")
868
+ else:
869
+ self.confirmed_start_frame = start_frame
870
+
871
+ if (
872
+ not fake_result
873
+ and self.vad_state_machine
874
+ == VadStateMachine.kVadInStateStartPointNotDetected
875
+ ):
876
+ self.pop_data_to_output_buf(
877
+ self.confirmed_start_frame, 1, True, False, False
878
+ )
879
+
880
+ def on_voice_end(
881
+ self, end_frame: int, fake_result: bool, is_last_frame: bool
882
+ ) -> None:
883
+ for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
884
+ self.on_voice_detected(t)
885
+ if self.vad_opts.do_end_point_detection:
886
+ pass
887
+ if self.confirmed_end_frame != -1:
888
+ logging.error("not reset vad properly\n")
889
+ else:
890
+ self.confirmed_end_frame = end_frame
891
+ if not fake_result:
892
+ self.sil_frame = 0
893
+ self.pop_data_to_output_buf(
894
+ self.confirmed_end_frame, 1, False, True, is_last_frame
895
+ )
896
+ self.number_end_time_detected += 1
897
+
898
+ def maybe_on_voice_end_last_frame(
899
+ self, is_final_frame: bool, cur_frm_idx: int
900
+ ) -> None:
901
+ if is_final_frame:
902
+ self.on_voice_end(cur_frm_idx, False, True)
903
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
904
+
905
+ def get_latency(self) -> int:
906
+ return int(self.latency_frm_num_at_start_point() * self.vad_opts.frame_in_ms)
907
+
908
+ def latency_frm_num_at_start_point(self) -> int:
909
+ vad_latency = self.windows_detector.get_win_size()
910
+ if self.vad_opts.do_extend:
911
+ vad_latency += int(
912
+ self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms
913
+ )
914
+ return vad_latency
915
+
916
+ def get_frame_state(self, t: int) -> FrameState:
917
+ frame_state = FrameState.kFrameStateInvalid
918
+ cur_decibel = self.decibel[t - self.decibel_offset]
919
+ cur_snr = cur_decibel - self.noise_average_decibel
920
+ # for each frame, calc log posterior probability of each state
921
+ if cur_decibel < self.vad_opts.decibel_thres:
922
+ frame_state = FrameState.kFrameStateSil
923
+ self.detect_one_frame(frame_state, t, False)
924
+ return frame_state
925
+
926
+ sum_score = 0.0
927
+ noise_prob = 0.0
928
+ assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
929
+ if len(self.sil_pdf_ids) > 0:
930
+ assert len(self.scores) == 1 # 只支持batch_size = 1的测试
931
+ sil_pdf_scores = [
932
+ self.scores[0][t - self.scores_offset][sil_pdf_id]
933
+ for sil_pdf_id in self.sil_pdf_ids
934
+ ]
935
+ sum_score = sum(sil_pdf_scores)
936
+ noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
937
+ total_score = 1.0
938
+ sum_score = total_score - sum_score
939
+ speech_prob = math.log(sum_score)
940
+ if self.vad_opts.output_frame_probs:
941
+ frame_prob = E2EVadFrameProb()
942
+ frame_prob.noise_prob = noise_prob
943
+ frame_prob.speech_prob = speech_prob
944
+ frame_prob.score = sum_score
945
+ frame_prob.frame_id = t
946
+ self.frame_probs.append(frame_prob)
947
+ if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
948
+ if (
949
+ cur_snr >= self.vad_opts.snr_thres
950
+ and cur_decibel >= self.vad_opts.decibel_thres
951
+ ):
952
+ frame_state = FrameState.kFrameStateSpeech
953
+ else:
954
+ frame_state = FrameState.kFrameStateSil
955
+ else:
956
+ frame_state = FrameState.kFrameStateSil
957
+ if self.noise_average_decibel < -99.9:
958
+ self.noise_average_decibel = cur_decibel
959
+ else:
960
+ self.noise_average_decibel = (
961
+ cur_decibel
962
+ + self.noise_average_decibel
963
+ * (self.vad_opts.noise_frame_num_used_for_snr - 1)
964
+ ) / self.vad_opts.noise_frame_num_used_for_snr
965
+
966
+ return frame_state
967
+
968
+ def infer_offline(
969
+ self,
970
+ feats: np.ndarray,
971
+ waveform: np.ndarray,
972
+ in_cache: Dict[str, np.ndarray] = dict(),
973
+ is_final: bool = False,
974
+ ) -> Tuple[List[List[List[int]]], Dict[str, np.ndarray]]:
975
+ self.waveform = waveform
976
+ self.compute_decibel()
977
+
978
+ self.compute_scores(feats)
979
+ if not is_final:
980
+ self.detect_common_frames()
981
+ else:
982
+ self.detect_last_frames()
983
+ segments = []
984
+ for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
985
+ segment_batch = []
986
+ if len(self.output_data_buf) > 0:
987
+ for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
988
+ if (
989
+ not self.output_data_buf[i].contain_seg_start_point
990
+ or not self.output_data_buf[i].contain_seg_end_point
991
+ ):
992
+ continue
993
+ segment = [
994
+ self.output_data_buf[i].start_ms,
995
+ self.output_data_buf[i].end_ms,
996
+ ]
997
+ segment_batch.append(segment)
998
+ self.output_data_buf_offset += 1 # need update this parameter
999
+ if segment_batch:
1000
+ segments.append(segment_batch)
1001
+
1002
+ if is_final:
1003
+ # reset class variables and clear the dict for the next query
1004
+ self.all_reset_detection()
1005
+ return segments, in_cache
1006
+
1007
+ def infer_online(
1008
+ self,
1009
+ feats: np.ndarray,
1010
+ waveform: np.ndarray,
1011
+ in_cache: list = None,
1012
+ is_final: bool = False,
1013
+ max_end_sil: int = 800,
1014
+ ) -> Tuple[List[List[List[int]]], Dict[str, np.ndarray]]:
1015
+ feats = [feats]
1016
+ if in_cache is None:
1017
+ in_cache = []
1018
+
1019
+ self.max_end_sil_frame_cnt_thresh = (
1020
+ max_end_sil - self.vad_opts.speech_to_sil_time_thres
1021
+ )
1022
+ self.waveform = waveform # compute decibel for each frame
1023
+ feats.extend(in_cache)
1024
+ in_cache = self.compute_scores(feats)
1025
+ self.compute_decibel()
1026
+
1027
+ if is_final:
1028
+ self.detect_last_frames()
1029
+ else:
1030
+ self.detect_common_frames()
1031
+
1032
+ segments = []
1033
+ # only support batch_size = 1 now
1034
+ for batch_num in range(0, feats[0].shape[0]):
1035
+ if len(self.output_data_buf) > 0:
1036
+ for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
1037
+ if not self.output_data_buf[i].contain_seg_start_point:
1038
+ continue
1039
+ if (
1040
+ not self.next_seg
1041
+ and not self.output_data_buf[i].contain_seg_end_point
1042
+ ):
1043
+ continue
1044
+ start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
1045
+ if self.output_data_buf[i].contain_seg_end_point:
1046
+ end_ms = self.output_data_buf[i].end_ms
1047
+ self.next_seg = True
1048
+ self.output_data_buf_offset += 1
1049
+ else:
1050
+ end_ms = -1
1051
+ self.next_seg = False
1052
+ segments.append([start_ms, end_ms])
1053
+
1054
+ return segments, in_cache
1055
+
1056
+ def get_frames_state(
1057
+ self,
1058
+ feats: np.ndarray,
1059
+ waveform: np.ndarray,
1060
+ in_cache: list = None,
1061
+ is_final: bool = False,
1062
+ max_end_sil: int = 800,
1063
+ ):
1064
+ feats = [feats]
1065
+ states = []
1066
+ if in_cache is None:
1067
+ in_cache = []
1068
+
1069
+ self.max_end_sil_frame_cnt_thresh = (
1070
+ max_end_sil - self.vad_opts.speech_to_sil_time_thres
1071
+ )
1072
+ self.waveform = waveform # compute decibel for each frame
1073
+ feats.extend(in_cache)
1074
+ in_cache = self.compute_scores(feats)
1075
+ self.compute_decibel()
1076
+
1077
+ if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
1078
+ return states
1079
+
1080
+ for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
1081
+ frame_state = FrameState.kFrameStateInvalid
1082
+ frame_state = self.get_frame_state(self.frm_cnt - 1 - i)
1083
+ states.append(frame_state)
1084
+ if i == 0 and is_final:
1085
+ logging.info("last frame detected")
1086
+ self.detect_one_frame(frame_state, self.frm_cnt - 1, True)
1087
+ else:
1088
+ self.detect_one_frame(frame_state, self.frm_cnt - 1 - i, False)
1089
+
1090
+ return states
1091
+
1092
+ def detect_common_frames(self) -> int:
1093
+ if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
1094
+ return 0
1095
+ for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
1096
+ frame_state = FrameState.kFrameStateInvalid
1097
+ frame_state = self.get_frame_state(self.frm_cnt - 1 - i)
1098
+ # print(f"cur frame: {self.frm_cnt - 1 - i}, state is {frame_state}")
1099
+ self.detect_one_frame(frame_state, self.frm_cnt - 1 - i, False)
1100
+
1101
+ self.decibel = self.decibel[self.vad_opts.nn_eval_block_size - 1 :]
1102
+ self.decibel_offset = self.frm_cnt - 1 - i
1103
+ return 0
1104
+
1105
+ def detect_last_frames(self) -> int:
1106
+ if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
1107
+ return 0
1108
+ for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
1109
+ frame_state = FrameState.kFrameStateInvalid
1110
+ frame_state = self.get_frame_state(self.frm_cnt - 1 - i)
1111
+ if i != 0:
1112
+ self.detect_one_frame(frame_state, self.frm_cnt - 1 - i, False)
1113
+ else:
1114
+ self.detect_one_frame(frame_state, self.frm_cnt - 1, True)
1115
+
1116
+ return 0
1117
+
1118
+ def detect_one_frame(
1119
+ self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool
1120
+ ) -> None:
1121
+ tmp_cur_frm_state = FrameState.kFrameStateInvalid
1122
+ if cur_frm_state == FrameState.kFrameStateSpeech:
1123
+ if math.fabs(1.0) > float(self.vad_opts.fe_prior_thres):
1124
+ tmp_cur_frm_state = FrameState.kFrameStateSpeech
1125
+ else:
1126
+ tmp_cur_frm_state = FrameState.kFrameStateSil
1127
+ elif cur_frm_state == FrameState.kFrameStateSil:
1128
+ tmp_cur_frm_state = FrameState.kFrameStateSil
1129
+ state_change = self.windows_detector.detect_one_frame(
1130
+ tmp_cur_frm_state, cur_frm_idx
1131
+ )
1132
+ frm_shift_in_ms = self.vad_opts.frame_in_ms
1133
+ if AudioChangeState.kChangeStateSil2Speech == state_change:
1134
+ self.continous_silence_frame_count = 0
1135
+ self.pre_end_silence_detected = False
1136
+
1137
+ if (
1138
+ self.vad_state_machine
1139
+ == VadStateMachine.kVadInStateStartPointNotDetected
1140
+ ):
1141
+ start_frame = max(
1142
+ self.data_buf_start_frame,
1143
+ cur_frm_idx - self.latency_frm_num_at_start_point(),
1144
+ )
1145
+ self.on_voice_start(start_frame)
1146
+ self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
1147
+ for t in range(start_frame + 1, cur_frm_idx + 1):
1148
+ self.on_voice_detected(t)
1149
+ elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
1150
+ for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
1151
+ self.on_voice_detected(t)
1152
+ if (
1153
+ cur_frm_idx - self.confirmed_start_frame + 1
1154
+ > self.vad_opts.max_single_segment_time / frm_shift_in_ms
1155
+ ):
1156
+ self.on_voice_end(cur_frm_idx, False, False)
1157
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
1158
+ elif not is_final_frame:
1159
+ self.on_voice_detected(cur_frm_idx)
1160
+ else:
1161
+ self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx)
1162
+ else:
1163
+ pass
1164
+ elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
1165
+ self.continous_silence_frame_count = 0
1166
+ if (
1167
+ self.vad_state_machine
1168
+ == VadStateMachine.kVadInStateStartPointNotDetected
1169
+ ):
1170
+ pass
1171
+ elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
1172
+ if (
1173
+ cur_frm_idx - self.confirmed_start_frame + 1
1174
+ > self.vad_opts.max_single_segment_time / frm_shift_in_ms
1175
+ ):
1176
+ self.on_voice_end(cur_frm_idx, False, False)
1177
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
1178
+ elif not is_final_frame:
1179
+ self.on_voice_detected(cur_frm_idx)
1180
+ else:
1181
+ self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx)
1182
+ else:
1183
+ pass
1184
+ elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
1185
+ self.continous_silence_frame_count = 0
1186
+ if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
1187
+ if (
1188
+ cur_frm_idx - self.confirmed_start_frame + 1
1189
+ > self.vad_opts.max_single_segment_time / frm_shift_in_ms
1190
+ ):
1191
+ self.max_time_out = True
1192
+ self.on_voice_end(cur_frm_idx, False, False)
1193
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
1194
+ elif not is_final_frame:
1195
+ self.on_voice_detected(cur_frm_idx)
1196
+ else:
1197
+ self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx)
1198
+ else:
1199
+ pass
1200
+ elif AudioChangeState.kChangeStateSil2Sil == state_change:
1201
+ self.continous_silence_frame_count += 1
1202
+ if (
1203
+ self.vad_state_machine
1204
+ == VadStateMachine.kVadInStateStartPointNotDetected
1205
+ ):
1206
+ # silence timeout, return zero length decision
1207
+ if (
1208
+ (
1209
+ self.vad_opts.detect_mode
1210
+ == VadDetectMode.kVadSingleUtteranceDetectMode.value
1211
+ )
1212
+ and (
1213
+ self.continous_silence_frame_count * frm_shift_in_ms
1214
+ > self.vad_opts.max_start_silence_time
1215
+ )
1216
+ ) or (is_final_frame and self.number_end_time_detected == 0):
1217
+ for t in range(
1218
+ self.lastest_confirmed_silence_frame + 1, cur_frm_idx
1219
+ ):
1220
+ self.on_silence_detected(t)
1221
+ self.on_voice_start(0, True)
1222
+ self.on_voice_end(0, True, False)
1223
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
1224
+ else:
1225
+ if cur_frm_idx >= self.latency_frm_num_at_start_point():
1226
+ self.on_silence_detected(
1227
+ cur_frm_idx - self.latency_frm_num_at_start_point()
1228
+ )
1229
+ elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
1230
+ if (
1231
+ self.continous_silence_frame_count * frm_shift_in_ms
1232
+ >= self.max_end_sil_frame_cnt_thresh
1233
+ ):
1234
+ lookback_frame = int(
1235
+ self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms
1236
+ )
1237
+ if self.vad_opts.do_extend:
1238
+ lookback_frame -= int(
1239
+ self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
1240
+ )
1241
+ lookback_frame -= 1
1242
+ lookback_frame = max(0, lookback_frame)
1243
+ self.on_voice_end(cur_frm_idx - lookback_frame, False, False)
1244
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
1245
+ elif (
1246
+ cur_frm_idx - self.confirmed_start_frame + 1
1247
+ > self.vad_opts.max_single_segment_time / frm_shift_in_ms
1248
+ ):
1249
+ self.on_voice_end(cur_frm_idx, False, False)
1250
+ self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
1251
+ elif self.vad_opts.do_extend and not is_final_frame:
1252
+ if self.continous_silence_frame_count <= int(
1253
+ self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
1254
+ ):
1255
+ self.on_voice_detected(cur_frm_idx)
1256
+ else:
1257
+ self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx)
1258
+ else:
1259
+ pass
1260
+
1261
+ if (
1262
+ self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected
1263
+ and self.vad_opts.detect_mode
1264
+ == VadDetectMode.kVadMutipleUtteranceDetectMode.value
1265
+ ):
1266
+ self.reset_detection()
1267
+
1268
+
1269
+ class FSMNVad(object):
1270
+ def __init__(self, config_dir: str):
1271
+ config_dir = Path(config_dir)
1272
+ self.config = read_yaml(config_dir / "fsmn-config.yaml")
1273
+ self.frontend = WavFrontend(
1274
+ cmvn_file=config_dir / "fsmn-am.mvn",
1275
+ **self.config["WavFrontend"]["frontend_conf"],
1276
+ )
1277
+ self.config["FSMN"]["model_path"] = config_dir / "fsmnvad-offline.onnx"
1278
+
1279
+ self.vad = E2EVadModel(
1280
+ self.config["FSMN"], self.config["vadPostArgs"], config_dir
1281
+ )
1282
+
1283
+ def set_parameters(self, mode):
1284
+ pass
1285
+
1286
+ def extract_feature(self, waveform):
1287
+ fbank, _ = self.frontend.fbank(waveform)
1288
+ feats, feats_len = self.frontend.lfr_cmvn(fbank)
1289
+ return feats.astype(np.float32), feats_len
1290
+
1291
+ def is_speech(self, buf, sample_rate=16000):
1292
+ assert sample_rate == 16000, "only support 16k sample rate"
1293
+
1294
+ def segments_offline(self, waveform_path: Union[str, Path, np.ndarray]):
1295
+ """get sements of audio"""
1296
+
1297
+ if isinstance(waveform_path, np.ndarray):
1298
+ waveform = waveform_path
1299
+ else:
1300
+ if not os.path.exists(waveform_path):
1301
+ raise FileExistsError(f"{waveform_path} is not exist.")
1302
+ if os.path.isfile(waveform_path):
1303
+ logging.info(f"load audio {waveform_path}")
1304
+ waveform, _sample_rate = sf.read(
1305
+ waveform_path,
1306
+ dtype="float32",
1307
+ )
1308
+ else:
1309
+ raise FileNotFoundError(str(Path))
1310
+ assert (
1311
+ _sample_rate == 16000
1312
+ ), f"only support 16k sample rate, current sample rate is {_sample_rate}"
1313
+
1314
+ feats, feats_len = self.extract_feature(waveform)
1315
+ waveform = waveform[None, ...]
1316
+ segments_part, in_cache = self.vad.infer_offline(
1317
+ feats[None, ...], waveform, is_final=True
1318
+ )
1319
+ return segments_part[0]
1320
+
1321
+ # ```
1322
+
1323
+ # File: sense_voice.py
1324
+ # ```py
1325
+ # -*- coding:utf-8 -*-
1326
+ # @FileName :sense_voice.py.py
1327
+ # @Time :2024/7/18 15:40
1328
+ # @Author :lovemefan
1329
+ # @Email :lovemefan@outlook.com
1330
+
1331
+ languages = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
1332
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
1333
+ logging.basicConfig(format=formatter, level=logging.INFO)
1334
+
1335
+ def main():
1336
+ arg_parser = argparse.ArgumentParser(description="Sense Voice")
1337
+ arg_parser.add_argument("-a", "--audio_file", required=True, type=str, help="Model")
1338
+ download_model_path = os.path.dirname(__file__)
1339
+ arg_parser.add_argument(
1340
+ "-dp",
1341
+ "--download_path",
1342
+ default=download_model_path,
1343
+ type=str,
1344
+ help="dir path of resource downloaded",
1345
+ )
1346
+ arg_parser.add_argument("-d", "--device", default=-1, type=int, help="Device")
1347
+ arg_parser.add_argument(
1348
+ "-n", "--num_threads", default=4, type=int, help="Num threads"
1349
+ )
1350
+ arg_parser.add_argument(
1351
+ "-l",
1352
+ "--language",
1353
+ choices=languages.keys(),
1354
+ default="auto",
1355
+ type=str,
1356
+ help="Language",
1357
+ )
1358
+ arg_parser.add_argument("--use_itn", action="store_true", help="Use ITN")
1359
+ args = arg_parser.parse_args()
1360
+
1361
+ front = WavFrontend(os.path.join(download_model_path, "am.mvn"))
1362
+
1363
+ model = SenseVoiceInferenceSession(
1364
+ os.path.join(download_model_path, "embedding.npy"),
1365
+ os.path.join(
1366
+ download_model_path,
1367
+ "sense-voice-encoder.rknn",
1368
+ ),
1369
+ os.path.join(download_model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model"),
1370
+ args.device,
1371
+ args.num_threads,
1372
+ )
1373
+ waveform, _sample_rate = sf.read(
1374
+ args.audio_file,
1375
+ dtype="float32",
1376
+ always_2d=True
1377
+ )
1378
+
1379
+ logging.info(f"Audio {args.audio_file} is {len(waveform) / _sample_rate} seconds, {waveform.shape[1]} channel")
1380
+ # load vad model
1381
+ start = time.time()
1382
+ vad = FSMNVad(download_model_path)
1383
+ for channel_id, channel_data in enumerate(waveform.T):
1384
+ segments = vad.segments_offline(channel_data)
1385
+ results = ""
1386
+ for part in segments:
1387
+ audio_feats = front.get_features(channel_data[part[0] * 16 : part[1] * 16])
1388
+ asr_result = model(
1389
+ audio_feats[None, ...],
1390
+ language=languages[args.language],
1391
+ use_itn=args.use_itn,
1392
+ )
1393
+ logging.info(f"[Channel {channel_id}] [{part[0] / 1000}s - {part[1] / 1000}s] {asr_result}")
1394
+ vad.vad.all_reset_detection()
1395
+ decoding_time = time.time() - start
1396
+ logging.info(f"Decoder audio takes {decoding_time} seconds")
1397
+ logging.info(f"The RTF is {decoding_time/(waveform.shape[1] * len(waveform) / _sample_rate)}.")
1398
+
1399
+
1400
+ if __name__ == "__main__":
1401
+ main()
1402
+