happyme531
commited on
Upload 9 files
Browse files- .gitattributes +1 -0
- README.md +131 -3
- chn_jpn_yue_eng_ko_spectok.bpe.model +3 -0
- convert_rknn.py +95 -0
- embedding.npy +3 -0
- fsmn-am.mvn +8 -0
- fsmn-config.yaml +59 -0
- fsmnvad-offline.onnx +3 -0
- output.wav +3 -0
- sensevoice_rknn.py +1402 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
output.wav filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,3 +1,131 @@
|
|
1 |
-
---
|
2 |
-
license: agpl-3.0
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: agpl-3.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
- zh
|
6 |
+
- ja
|
7 |
+
- ko
|
8 |
+
base_model: lovemefan/SenseVoice-onnx
|
9 |
+
tags:
|
10 |
+
- rknn
|
11 |
+
---
|
12 |
+
|
13 |
+
# SenseVoiceSmall-RKNN2
|
14 |
+
|
15 |
+
### (English README see below)
|
16 |
+
|
17 |
+
SenseVoice是具有音频理解能力的音频基础模型, 包括语音识别(ASR)、语种识别(LID)、语音情感识别(SER)和声学事件分类(AEC)或声学事件检测(AED)。
|
18 |
+
|
19 |
+
当前SenseVoice-small支持中、粤、英、日、韩语的多语言语音识别,情感识别和事件检测能力,具有极低的推理延迟。
|
20 |
+
|
21 |
+
- 推理速度(RKNN2):RK3588上单核NPU推理速度约20倍 (每秒识别20秒的音频), 比官方rknn-model-zoo中提供的whisper约快6倍.
|
22 |
+
- 内存占用(RKNN2):约1.1GB
|
23 |
+
|
24 |
+
## 使用方法
|
25 |
+
|
26 |
+
1. 克隆项目到本地
|
27 |
+
|
28 |
+
2. 安装依赖
|
29 |
+
|
30 |
+
```bash
|
31 |
+
pip install kaldi_native_fbank onnxruntime sentencepiece soundfile pyyaml numpy<2
|
32 |
+
```
|
33 |
+
|
34 |
+
你还需要手动安装rknn-toolkit2-lite2.
|
35 |
+
|
36 |
+
3. 运行
|
37 |
+
|
38 |
+
```bash
|
39 |
+
python ./sensevoice_rknn.py --audio_file output.wav
|
40 |
+
```
|
41 |
+
|
42 |
+
如果使用自己的音频文件测试发现识别不正常,你可能需要提前将它转换为16kHz, 16bit, 单声道的wav格式。
|
43 |
+
|
44 |
+
```bash
|
45 |
+
ffmpeg -i input.mp3 -f wav -acodec pcm_s16le -ac 1 -ar 16000 output.wav
|
46 |
+
```
|
47 |
+
|
48 |
+
## RKNN模型转换
|
49 |
+
|
50 |
+
你需要提前安装rknn-toolkit2 v2.1.0或更高版本.
|
51 |
+
|
52 |
+
1. 下载或转换onnx模型
|
53 |
+
|
54 |
+
可以从 https://huggingface.co/lovemefan/SenseVoice-onnx 下载到onnx模型.
|
55 |
+
应该也可以根据 https://github.com/FunAudioLLM/SenseVoice 中的文档从Pytorch模型转换得到onnx模型.
|
56 |
+
|
57 |
+
模型文件应该命名为'sense-voice-encoder.onnx', 放在转换脚本所在目录.
|
58 |
+
|
59 |
+
2. 转换为rknn模型
|
60 |
+
```bash
|
61 |
+
python convert_rknn.py
|
62 |
+
```
|
63 |
+
|
64 |
+
## 已知问题
|
65 |
+
|
66 |
+
- RKNN2使用fp16推理时可能会出现溢出,导致结果为inf,可以尝试修改输入数据的缩放比例来解决.
|
67 |
+
在`sensevoice_rknn.py`中将`SPEECH_SCALE`设置为更小的值.
|
68 |
+
|
69 |
+
## 参考
|
70 |
+
- [FunAudioLLM/SenseVoiceSmall](https://huggingface.co/FunAudioLLM/SenseVoiceSmall)
|
71 |
+
- [lovemefan/SenseVoice-python](https://github.com/lovemefan/SenseVoice-python)
|
72 |
+
|
73 |
+
# English README
|
74 |
+
|
75 |
+
# SenseVoiceSmall-RKNN2
|
76 |
+
|
77 |
+
SenseVoice is an audio foundation model with audio understanding capabilities, including Automatic Speech Recognition (ASR), Language Identification (LID), Speech Emotion Recognition (SER), and Acoustic Event Classification (AEC) or Acoustic Event Detection (AED).
|
78 |
+
|
79 |
+
Currently, SenseVoice-small supports multilingual speech recognition, emotion recognition, and event detection for Chinese, Cantonese, English, Japanese, and Korean, with extremely low inference latency.
|
80 |
+
|
81 |
+
- Inference speed (RKNN2): About 20x real-time on a single NPU core of RK3588 (processing 20 seconds of audio per second), approximately 6 times faster than the official whisper model provided in the rknn-model-zoo.
|
82 |
+
- Memory usage (RKNN2): About 1.1GB
|
83 |
+
|
84 |
+
## Usage
|
85 |
+
|
86 |
+
1. Clone the project to your local machine
|
87 |
+
|
88 |
+
2. Install dependencies
|
89 |
+
|
90 |
+
```bash
|
91 |
+
pip install kaldi_native_fbank onnxruntime sentencepiece soundfile pyyaml numpy<2
|
92 |
+
```
|
93 |
+
|
94 |
+
You also need to manually install rknn-toolkit2-lite2.
|
95 |
+
|
96 |
+
3. Run
|
97 |
+
|
98 |
+
```bash
|
99 |
+
python ./sensevoice_rknn.py --audio_file output.wav
|
100 |
+
```
|
101 |
+
|
102 |
+
If you find that recognition is not working correctly when testing with your own audio files, you may need to convert them to 16kHz, 16-bit, mono WAV format in advance.
|
103 |
+
|
104 |
+
```bash
|
105 |
+
ffmpeg -i input.mp3 -f wav -acodec pcm_s16le -ac 1 -ar 16000 output.wav
|
106 |
+
```
|
107 |
+
|
108 |
+
## RKNN Model Conversion
|
109 |
+
|
110 |
+
You need to install rknn-toolkit2 v2.1.0 or higher in advance.
|
111 |
+
|
112 |
+
1. Download or convert the ONNX model
|
113 |
+
|
114 |
+
You can download the ONNX model from https://huggingface.co/lovemefan/SenseVoice-onnx.
|
115 |
+
It should also be possible to convert from a PyTorch model to an ONNX model according to the documentation at https://github.com/FunAudioLLM/SenseVoice.
|
116 |
+
|
117 |
+
The model file should be named 'sense-voice-encoder.onnx' and placed in the same directory as the conversion script.
|
118 |
+
|
119 |
+
2. Convert to RKNN model
|
120 |
+
```bash
|
121 |
+
python convert_rknn.py
|
122 |
+
```
|
123 |
+
|
124 |
+
## Known Issues
|
125 |
+
|
126 |
+
- When using fp16 inference with RKNN2, overflow may occur, resulting in inf values. You can try modifying the scaling ratio of the input data to resolve this.
|
127 |
+
Set `SPEECH_SCALE` to a smaller value in `sensevoice_rknn.py`.
|
128 |
+
|
129 |
+
## References
|
130 |
+
- [FunAudioLLM/SenseVoiceSmall](https://huggingface.co/FunAudioLLM/SenseVoiceSmall)
|
131 |
+
- [lovemefan/SenseVoice-python](https://github.com/lovemefan/SenseVoice-python)
|
chn_jpn_yue_eng_ko_spectok.bpe.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aa87f86064c3730d799ddf7af3c04659151102cba548bce325cf06ba4da4e6a8
|
3 |
+
size 377341
|
convert_rknn.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding: utf-8
|
3 |
+
|
4 |
+
import os
|
5 |
+
from rknn.api import RKNN
|
6 |
+
from math import exp
|
7 |
+
from sys import exit
|
8 |
+
import argparse
|
9 |
+
import onnxscript
|
10 |
+
from onnxscript.rewriter import pattern
|
11 |
+
import onnx.numpy_helper as onh
|
12 |
+
import numpy as np
|
13 |
+
import onnx
|
14 |
+
import onnxruntime as ort
|
15 |
+
from rknn.utils import onnx_edit
|
16 |
+
|
17 |
+
os.chdir(os.path.dirname(os.path.abspath(__file__)))
|
18 |
+
|
19 |
+
speech_length = 171
|
20 |
+
|
21 |
+
def convert_encoder():
|
22 |
+
rknn = RKNN(verbose=True)
|
23 |
+
|
24 |
+
ONNX_MODEL=f"sense-voice-encoder.onnx"
|
25 |
+
RKNN_MODEL=ONNX_MODEL.replace(".onnx",".rknn")
|
26 |
+
DATASET="dataset.txt"
|
27 |
+
QUANTIZE=False
|
28 |
+
|
29 |
+
#开局先给我来个大惊喜,rknn做第一步常量折叠的时候就会在这个子图里报错,所以要单独拿出来先跑一遍
|
30 |
+
#然后把这个子图的输出结果保存下来喂给rknn
|
31 |
+
onnx.utils.extract_model(ONNX_MODEL, "extract_model.onnx", ['speech_lengths'], ['/make_pad_mask/Cast_2_output_0'])
|
32 |
+
sess = ort.InferenceSession("extract_model.onnx", providers=['CPUExecutionProvider'])
|
33 |
+
extract_result = sess.run(None, {"speech_lengths": np.array([speech_length], dtype=np.int64)})[0]
|
34 |
+
|
35 |
+
# 删掉模型最后的多余transpose, 速度从365ms提升到350ms
|
36 |
+
ret = onnx_edit(model = ONNX_MODEL,
|
37 |
+
export_path = ONNX_MODEL.replace(".onnx", "_edited.onnx"),
|
38 |
+
# # 1, len, 25055 -> 1, 25055, 1, len # 这个是坏的, 我真服了,
|
39 |
+
# outputs_transform = {'encoder_out': 'a,b,c->a,c,1,b'},
|
40 |
+
outputs_transform = {'encoder_out': 'a,b,c->a,c,b'},
|
41 |
+
)
|
42 |
+
ONNX_MODEL = ONNX_MODEL.replace(".onnx", "_edited.onnx")
|
43 |
+
|
44 |
+
# pre-process config
|
45 |
+
print('--> Config model')
|
46 |
+
rknn.config(quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588', optimization_level=3)
|
47 |
+
print('done')
|
48 |
+
|
49 |
+
# Load ONNX model
|
50 |
+
print("--> Loading model")
|
51 |
+
ret = rknn.load_onnx(
|
52 |
+
model=ONNX_MODEL,
|
53 |
+
inputs=["speech", "/make_pad_mask/Cast_2_output_0"],
|
54 |
+
input_size_list=[[1, speech_length, 560], [extract_result.shape[0], extract_result.shape[1]]],
|
55 |
+
input_initial_val=[None, extract_result],
|
56 |
+
# outputs=["output"]
|
57 |
+
)
|
58 |
+
|
59 |
+
if ret != 0:
|
60 |
+
print('Load model failed!')
|
61 |
+
exit(ret)
|
62 |
+
print('done')
|
63 |
+
|
64 |
+
# Build model
|
65 |
+
print('--> Building model')
|
66 |
+
ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
|
67 |
+
if ret != 0:
|
68 |
+
print('Build model failed!')
|
69 |
+
exit(ret)
|
70 |
+
print('done')
|
71 |
+
|
72 |
+
# export
|
73 |
+
print('--> Export RKNN model')
|
74 |
+
ret = rknn.export_rknn(RKNN_MODEL)
|
75 |
+
if ret != 0:
|
76 |
+
print('Export RKNN model failed!')
|
77 |
+
exit(ret)
|
78 |
+
print('done')
|
79 |
+
|
80 |
+
# usage: python convert_rknn.py encoder|all
|
81 |
+
|
82 |
+
if __name__ == "__main__":
|
83 |
+
parser = argparse.ArgumentParser()
|
84 |
+
parser.add_argument("model", type=str, help="model to convert", choices=["encoder", "all"], nargs='?')
|
85 |
+
args = parser.parse_args()
|
86 |
+
if args.model is None:
|
87 |
+
args.model = "all"
|
88 |
+
|
89 |
+
if args.model == "encoder":
|
90 |
+
convert_encoder()
|
91 |
+
elif args.model == "all":
|
92 |
+
convert_encoder()
|
93 |
+
else:
|
94 |
+
print(f"Unknown model: {args.model}")
|
95 |
+
exit(1)
|
embedding.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:83cf1fc5680fdf6d7edb411be5ce351cad4eca03b29a5bf5050aa19dfcc12267
|
3 |
+
size 35968
|
fsmn-am.mvn
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<Nnet>
|
2 |
+
<Splice> 400 400
|
3 |
+
[ 0 ]
|
4 |
+
<AddShift> 400 400
|
5 |
+
<LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
|
6 |
+
<Rescale> 400 400
|
7 |
+
<LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
|
8 |
+
</Nnet>
|
fsmn-config.yaml
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
WavFrontend:
|
2 |
+
frontend_conf:
|
3 |
+
fs: 16000
|
4 |
+
window: hamming
|
5 |
+
n_mels: 80
|
6 |
+
frame_length: 25
|
7 |
+
frame_shift: 10
|
8 |
+
dither: 0.0
|
9 |
+
lfr_m: 5
|
10 |
+
lfr_n: 1
|
11 |
+
|
12 |
+
FSMN:
|
13 |
+
use_cuda: False
|
14 |
+
CUDAExecutionProvider:
|
15 |
+
device_id: 0
|
16 |
+
arena_extend_strategy: kNextPowerOfTwo
|
17 |
+
cudnn_conv_algo_search: EXHAUSTIVE
|
18 |
+
do_copy_in_default_stream: true
|
19 |
+
encoder_conf:
|
20 |
+
input_dim: 400
|
21 |
+
input_affine_dim: 140
|
22 |
+
fsmn_layers: 4
|
23 |
+
linear_dim: 250
|
24 |
+
proj_dim: 128
|
25 |
+
lorder: 20
|
26 |
+
rorder: 0
|
27 |
+
lstride: 1
|
28 |
+
rstride: 0
|
29 |
+
output_affine_dim: 140
|
30 |
+
output_dim: 248
|
31 |
+
|
32 |
+
vadPostArgs:
|
33 |
+
sample_rate: 16000
|
34 |
+
detect_mode: 1
|
35 |
+
snr_mode: 0
|
36 |
+
max_end_silence_time: 800
|
37 |
+
max_start_silence_time: 3000
|
38 |
+
do_start_point_detection: True
|
39 |
+
do_end_point_detection: True
|
40 |
+
window_size_ms: 200
|
41 |
+
sil_to_speech_time_thres: 150
|
42 |
+
speech_to_sil_time_thres: 150
|
43 |
+
speech_2_noise_ratio: 1.0
|
44 |
+
do_extend: 1
|
45 |
+
lookback_time_start_point: 200
|
46 |
+
lookahead_time_end_point: 100
|
47 |
+
max_single_segment_time: 10000
|
48 |
+
snr_thres: -100.0
|
49 |
+
noise_frame_num_used_for_snr: 100
|
50 |
+
decibel_thres: -100.0
|
51 |
+
speech_noise_thres: 0.6
|
52 |
+
fe_prior_thres: 0.0001
|
53 |
+
silence_pdf_num: 1
|
54 |
+
sil_pdf_ids: [ 0 ]
|
55 |
+
speech_noise_thresh_low: -0.1
|
56 |
+
speech_noise_thresh_high: 0.3
|
57 |
+
output_frame_probs: False
|
58 |
+
frame_in_ms: 10
|
59 |
+
frame_length_ms: 25
|
fsmnvad-offline.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4bbd68b11519e916b6871ff6f8df15e2100936b256be9cb104cd63fb7c859965
|
3 |
+
size 1725472
|
output.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b6f02d2c58b9a8a294a306ccb60bdf667587d74984915a3ec87a6de5e04bb020
|
3 |
+
size 1289994
|
sensevoice_rknn.py
ADDED
@@ -0,0 +1,1402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# File: onnx/fsmn_vad_ort_session.py
|
2 |
+
# ```py
|
3 |
+
|
4 |
+
# -*- coding:utf-8 -*-
|
5 |
+
# @FileName :fsmn_vad_ort_session.py.py
|
6 |
+
# @Time :2024/8/31 16:45
|
7 |
+
# @Author :lovemefan
|
8 |
+
# @Email :lovemefan@outlook.com
|
9 |
+
|
10 |
+
import argparse
|
11 |
+
import logging
|
12 |
+
import math
|
13 |
+
import os
|
14 |
+
import time
|
15 |
+
import warnings
|
16 |
+
from enum import Enum
|
17 |
+
from pathlib import Path
|
18 |
+
from typing import Any, Dict, List, Tuple, Union
|
19 |
+
|
20 |
+
import kaldi_native_fbank as knf
|
21 |
+
import numpy as np
|
22 |
+
import sentencepiece as spm
|
23 |
+
import soundfile as sf
|
24 |
+
import yaml
|
25 |
+
from onnxruntime import (GraphOptimizationLevel, InferenceSession,
|
26 |
+
SessionOptions, get_available_providers, get_device)
|
27 |
+
from rknnlite.api.rknn_lite import RKNNLite
|
28 |
+
|
29 |
+
RKNN_INPUT_LEN = 171
|
30 |
+
|
31 |
+
SPEECH_SCALE = 1/2 # 因为是fp16推理,如果中间结果太大可能会溢出变inf,所以需要缩放一下
|
32 |
+
|
33 |
+
class VadOrtInferRuntimeSession:
|
34 |
+
def __init__(self, config, root_dir: Path):
|
35 |
+
sess_opt = SessionOptions()
|
36 |
+
sess_opt.log_severity_level = 4
|
37 |
+
sess_opt.enable_cpu_mem_arena = False
|
38 |
+
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
39 |
+
|
40 |
+
cuda_ep = "CUDAExecutionProvider"
|
41 |
+
cpu_ep = "CPUExecutionProvider"
|
42 |
+
cpu_provider_options = {
|
43 |
+
"arena_extend_strategy": "kSameAsRequested",
|
44 |
+
}
|
45 |
+
|
46 |
+
EP_list = []
|
47 |
+
if (
|
48 |
+
config["use_cuda"]
|
49 |
+
and get_device() == "GPU"
|
50 |
+
and cuda_ep in get_available_providers()
|
51 |
+
):
|
52 |
+
EP_list = [(cuda_ep, config[cuda_ep])]
|
53 |
+
EP_list.append((cpu_ep, cpu_provider_options))
|
54 |
+
|
55 |
+
config["model_path"] = root_dir / str(config["model_path"])
|
56 |
+
self._verify_model(config["model_path"])
|
57 |
+
logging.info(f"Loading onnx model at {str(config['model_path'])}")
|
58 |
+
self.session = InferenceSession(
|
59 |
+
str(config["model_path"]), sess_options=sess_opt, providers=EP_list
|
60 |
+
)
|
61 |
+
|
62 |
+
if config["use_cuda"] and cuda_ep not in self.session.get_providers():
|
63 |
+
logging.warning(
|
64 |
+
f"{cuda_ep} is not available for current env, "
|
65 |
+
f"the inference part is automatically shifted to be "
|
66 |
+
f"executed under {cpu_ep}.\n "
|
67 |
+
"Please ensure the installed onnxruntime-gpu version"
|
68 |
+
" matches your cuda and cudnn version, "
|
69 |
+
"you can check their relations from the offical web site: "
|
70 |
+
"https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
|
71 |
+
RuntimeWarning,
|
72 |
+
)
|
73 |
+
|
74 |
+
def __call__(
|
75 |
+
self, input_content: List[Union[np.ndarray, np.ndarray]]
|
76 |
+
) -> np.ndarray:
|
77 |
+
if isinstance(input_content, list):
|
78 |
+
input_dict = {
|
79 |
+
"speech": input_content[0],
|
80 |
+
"in_cache0": input_content[1],
|
81 |
+
"in_cache1": input_content[2],
|
82 |
+
"in_cache2": input_content[3],
|
83 |
+
"in_cache3": input_content[4],
|
84 |
+
}
|
85 |
+
else:
|
86 |
+
input_dict = {"speech": input_content}
|
87 |
+
|
88 |
+
return self.session.run(None, input_dict)
|
89 |
+
|
90 |
+
def get_input_names(
|
91 |
+
self,
|
92 |
+
):
|
93 |
+
return [v.name for v in self.session.get_inputs()]
|
94 |
+
|
95 |
+
def get_output_names(
|
96 |
+
self,
|
97 |
+
):
|
98 |
+
return [v.name for v in self.session.get_outputs()]
|
99 |
+
|
100 |
+
def get_character_list(self, key: str = "character"):
|
101 |
+
return self.meta_dict[key].splitlines()
|
102 |
+
|
103 |
+
def have_key(self, key: str = "character") -> bool:
|
104 |
+
self.meta_dict = self.session.get_modelmeta().custom_metadata_map
|
105 |
+
if key in self.meta_dict.keys():
|
106 |
+
return True
|
107 |
+
return False
|
108 |
+
|
109 |
+
@staticmethod
|
110 |
+
def _verify_model(model_path):
|
111 |
+
model_path = Path(model_path)
|
112 |
+
if not model_path.exists():
|
113 |
+
raise FileNotFoundError(f"{model_path} does not exists.")
|
114 |
+
if not model_path.is_file():
|
115 |
+
raise FileExistsError(f"{model_path} is not a file.")
|
116 |
+
|
117 |
+
# ```
|
118 |
+
|
119 |
+
# File: onnx/sense_voice_ort_session.py
|
120 |
+
# ```py
|
121 |
+
# -*- coding:utf-8 -*-
|
122 |
+
# @FileName :sense_voice_onnxruntime.py
|
123 |
+
# @Time :2024/7/17 20:53
|
124 |
+
# @Author :lovemefan
|
125 |
+
# @Email :lovemefan@outlook.com
|
126 |
+
|
127 |
+
|
128 |
+
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
129 |
+
logging.basicConfig(format=formatter, level=logging.INFO)
|
130 |
+
|
131 |
+
|
132 |
+
class OrtInferRuntimeSession:
|
133 |
+
def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
|
134 |
+
device_id = str(device_id)
|
135 |
+
sess_opt = SessionOptions()
|
136 |
+
sess_opt.intra_op_num_threads = intra_op_num_threads
|
137 |
+
sess_opt.log_severity_level = 4
|
138 |
+
sess_opt.enable_cpu_mem_arena = False
|
139 |
+
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
140 |
+
|
141 |
+
cuda_ep = "CUDAExecutionProvider"
|
142 |
+
cuda_provider_options = {
|
143 |
+
"device_id": device_id,
|
144 |
+
"arena_extend_strategy": "kNextPowerOfTwo",
|
145 |
+
"cudnn_conv_algo_search": "EXHAUSTIVE",
|
146 |
+
"do_copy_in_default_stream": "true",
|
147 |
+
}
|
148 |
+
cpu_ep = "CPUExecutionProvider"
|
149 |
+
cpu_provider_options = {
|
150 |
+
"arena_extend_strategy": "kSameAsRequested",
|
151 |
+
}
|
152 |
+
|
153 |
+
EP_list = []
|
154 |
+
if (
|
155 |
+
device_id != "-1"
|
156 |
+
and get_device() == "GPU"
|
157 |
+
and cuda_ep in get_available_providers()
|
158 |
+
):
|
159 |
+
EP_list = [(cuda_ep, cuda_provider_options)]
|
160 |
+
EP_list.append((cpu_ep, cpu_provider_options))
|
161 |
+
|
162 |
+
self._verify_model(model_file)
|
163 |
+
|
164 |
+
self.session = InferenceSession(
|
165 |
+
model_file, sess_options=sess_opt, providers=EP_list
|
166 |
+
)
|
167 |
+
|
168 |
+
# delete binary of model file to save memory
|
169 |
+
del model_file
|
170 |
+
|
171 |
+
if device_id != "-1" and cuda_ep not in self.session.get_providers():
|
172 |
+
warnings.warn(
|
173 |
+
f"{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n"
|
174 |
+
"Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
|
175 |
+
"you can check their relations from the offical web site: "
|
176 |
+
"https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
|
177 |
+
RuntimeWarning,
|
178 |
+
)
|
179 |
+
|
180 |
+
def __call__(self, input_content) -> np.ndarray:
|
181 |
+
input_dict = dict(zip(self.get_input_names(), input_content))
|
182 |
+
try:
|
183 |
+
result = self.session.run(self.get_output_names(), input_dict)
|
184 |
+
return result
|
185 |
+
except Exception as e:
|
186 |
+
print(e)
|
187 |
+
raise RuntimeError(f"ONNXRuntime inferece failed. ") from e
|
188 |
+
|
189 |
+
def get_input_names(
|
190 |
+
self,
|
191 |
+
):
|
192 |
+
return [v.name for v in self.session.get_inputs()]
|
193 |
+
|
194 |
+
def get_output_names(
|
195 |
+
self,
|
196 |
+
):
|
197 |
+
return [v.name for v in self.session.get_outputs()]
|
198 |
+
|
199 |
+
def get_character_list(self, key: str = "character"):
|
200 |
+
return self.meta_dict[key].splitlines()
|
201 |
+
|
202 |
+
def have_key(self, key: str = "character") -> bool:
|
203 |
+
self.meta_dict = self.session.get_modelmeta().custom_metadata_map
|
204 |
+
if key in self.meta_dict.keys():
|
205 |
+
return True
|
206 |
+
return False
|
207 |
+
|
208 |
+
@staticmethod
|
209 |
+
def _verify_model(model_path):
|
210 |
+
model_path = Path(model_path)
|
211 |
+
if not model_path.exists():
|
212 |
+
raise FileNotFoundError(f"{model_path} does not exists.")
|
213 |
+
if not model_path.is_file():
|
214 |
+
raise FileExistsError(f"{model_path} is not a file.")
|
215 |
+
|
216 |
+
|
217 |
+
def log_softmax(x: np.ndarray) -> np.ndarray:
|
218 |
+
# Subtract the maximum value in each row for numerical stability
|
219 |
+
x_max = np.max(x, axis=-1, keepdims=True)
|
220 |
+
# Calculate the softmax of x
|
221 |
+
softmax = np.exp(x - x_max)
|
222 |
+
softmax_sum = np.sum(softmax, axis=-1, keepdims=True)
|
223 |
+
softmax = softmax / softmax_sum
|
224 |
+
# Calculate the log of the softmax values
|
225 |
+
return np.log(softmax)
|
226 |
+
|
227 |
+
|
228 |
+
class SenseVoiceInferenceSession:
|
229 |
+
def __init__(
|
230 |
+
self,
|
231 |
+
embedding_model_file,
|
232 |
+
encoder_model_file,
|
233 |
+
bpe_model_file,
|
234 |
+
device_id=-1,
|
235 |
+
intra_op_num_threads=4,
|
236 |
+
):
|
237 |
+
logging.info(f"Loading model from {embedding_model_file}")
|
238 |
+
|
239 |
+
self.embedding = np.load(embedding_model_file)
|
240 |
+
logging.info(f"Loading model {encoder_model_file}")
|
241 |
+
start = time.time()
|
242 |
+
self.encoder = RKNNLite(verbose=False)
|
243 |
+
self.encoder.load_rknn(encoder_model_file)
|
244 |
+
self.encoder.init_runtime()
|
245 |
+
|
246 |
+
logging.info(
|
247 |
+
f"Loading {encoder_model_file} takes {time.time() - start:.2f} seconds"
|
248 |
+
)
|
249 |
+
self.blank_id = 0
|
250 |
+
self.sp = spm.SentencePieceProcessor()
|
251 |
+
self.sp.load(bpe_model_file)
|
252 |
+
|
253 |
+
def __call__(self, speech, language: int, use_itn: bool) -> np.ndarray:
|
254 |
+
language_query = self.embedding[[[language]]]
|
255 |
+
|
256 |
+
# 14 means with itn, 15 means without itn
|
257 |
+
text_norm_query = self.embedding[[[14 if use_itn else 15]]]
|
258 |
+
event_emo_query = self.embedding[[[1, 2]]]
|
259 |
+
|
260 |
+
# scale the speech
|
261 |
+
speech = speech * SPEECH_SCALE
|
262 |
+
|
263 |
+
input_content = np.concatenate(
|
264 |
+
[
|
265 |
+
language_query,
|
266 |
+
event_emo_query,
|
267 |
+
text_norm_query,
|
268 |
+
speech,
|
269 |
+
],
|
270 |
+
axis=1,
|
271 |
+
).astype(np.float32)
|
272 |
+
print(input_content.shape)
|
273 |
+
# pad [1, len, ...] to [1, RKNN_INPUT_LEN, ... ]
|
274 |
+
input_content = np.pad(input_content, ((0, 0), (0, RKNN_INPUT_LEN - input_content.shape[1]), (0, 0)))
|
275 |
+
print("padded shape:", input_content.shape)
|
276 |
+
start_time = time.time()
|
277 |
+
encoder_out = self.encoder.inference(inputs=[input_content])[0]
|
278 |
+
end_time = time.time()
|
279 |
+
print(f"encoder inference time: {end_time - start_time:.2f} seconds")
|
280 |
+
# print(encoder_out)
|
281 |
+
def unique_consecutive(arr):
|
282 |
+
if len(arr) == 0:
|
283 |
+
return arr
|
284 |
+
# Create a boolean mask where True indicates the element is different from the previous one
|
285 |
+
mask = np.append([True], arr[1:] != arr[:-1])
|
286 |
+
out = arr[mask]
|
287 |
+
out = out[out != self.blank_id]
|
288 |
+
return out.tolist()
|
289 |
+
|
290 |
+
#现在shape变成了1, n_vocab, n_seq. 这里axis需要改一下
|
291 |
+
# hypos = unique_consecutive(encoder_out[0].argmax(axis=-1))
|
292 |
+
hypos = unique_consecutive(encoder_out[0].argmax(axis=0))
|
293 |
+
text = self.sp.DecodeIds(hypos)
|
294 |
+
return text
|
295 |
+
|
296 |
+
# ```
|
297 |
+
|
298 |
+
# File: utils/frontend.py
|
299 |
+
# ```py
|
300 |
+
# -*- coding:utf-8 -*-
|
301 |
+
# @FileName :frontend.py
|
302 |
+
# @Time :2024/7/18 09:39
|
303 |
+
# @Author :lovemefan
|
304 |
+
# @Email :lovemefan@outlook.com
|
305 |
+
|
306 |
+
class WavFrontend:
|
307 |
+
"""Conventional frontend structure for ASR."""
|
308 |
+
|
309 |
+
def __init__(
|
310 |
+
self,
|
311 |
+
cmvn_file: str = None,
|
312 |
+
fs: int = 16000,
|
313 |
+
window: str = "hamming",
|
314 |
+
n_mels: int = 80,
|
315 |
+
frame_length: int = 25,
|
316 |
+
frame_shift: int = 10,
|
317 |
+
lfr_m: int = 7,
|
318 |
+
lfr_n: int = 6,
|
319 |
+
dither: float = 0,
|
320 |
+
**kwargs,
|
321 |
+
) -> None:
|
322 |
+
opts = knf.FbankOptions()
|
323 |
+
opts.frame_opts.samp_freq = fs
|
324 |
+
opts.frame_opts.dither = dither
|
325 |
+
opts.frame_opts.window_type = window
|
326 |
+
opts.frame_opts.frame_shift_ms = float(frame_shift)
|
327 |
+
opts.frame_opts.frame_length_ms = float(frame_length)
|
328 |
+
opts.mel_opts.num_bins = n_mels
|
329 |
+
opts.energy_floor = 0
|
330 |
+
opts.frame_opts.snip_edges = True
|
331 |
+
opts.mel_opts.debug_mel = False
|
332 |
+
self.opts = opts
|
333 |
+
|
334 |
+
self.lfr_m = lfr_m
|
335 |
+
self.lfr_n = lfr_n
|
336 |
+
self.cmvn_file = cmvn_file
|
337 |
+
|
338 |
+
if self.cmvn_file:
|
339 |
+
self.cmvn = self.load_cmvn()
|
340 |
+
self.fbank_fn = None
|
341 |
+
self.fbank_beg_idx = 0
|
342 |
+
self.reset_status()
|
343 |
+
|
344 |
+
def reset_status(self):
|
345 |
+
self.fbank_fn = knf.OnlineFbank(self.opts)
|
346 |
+
self.fbank_beg_idx = 0
|
347 |
+
|
348 |
+
def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
349 |
+
waveform = waveform * (1 << 15)
|
350 |
+
self.fbank_fn = knf.OnlineFbank(self.opts)
|
351 |
+
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
|
352 |
+
frames = self.fbank_fn.num_frames_ready
|
353 |
+
mat = np.empty([frames, self.opts.mel_opts.num_bins])
|
354 |
+
for i in range(frames):
|
355 |
+
mat[i, :] = self.fbank_fn.get_frame(i)
|
356 |
+
feat = mat.astype(np.float32)
|
357 |
+
feat_len = np.array(mat.shape[0]).astype(np.int32)
|
358 |
+
return feat, feat_len
|
359 |
+
|
360 |
+
def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
361 |
+
if self.lfr_m != 1 or self.lfr_n != 1:
|
362 |
+
feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
|
363 |
+
|
364 |
+
if self.cmvn_file:
|
365 |
+
feat = self.apply_cmvn(feat)
|
366 |
+
|
367 |
+
feat_len = np.array(feat.shape[0]).astype(np.int32)
|
368 |
+
return feat, feat_len
|
369 |
+
|
370 |
+
def load_audio(self, filename: str) -> Tuple[np.ndarray, int]:
|
371 |
+
data, sample_rate = sf.read(
|
372 |
+
filename,
|
373 |
+
always_2d=True,
|
374 |
+
dtype="float32",
|
375 |
+
)
|
376 |
+
assert (
|
377 |
+
sample_rate == 16000
|
378 |
+
), f"Only 16000 Hz is supported, but got {sample_rate}Hz"
|
379 |
+
self.sample_rate = sample_rate
|
380 |
+
data = data[:, 0] # use only the first channel
|
381 |
+
samples = np.ascontiguousarray(data)
|
382 |
+
|
383 |
+
return samples, sample_rate
|
384 |
+
|
385 |
+
@staticmethod
|
386 |
+
def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
|
387 |
+
LFR_inputs = []
|
388 |
+
|
389 |
+
T = inputs.shape[0]
|
390 |
+
T_lfr = int(np.ceil(T / lfr_n))
|
391 |
+
left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
|
392 |
+
inputs = np.vstack((left_padding, inputs))
|
393 |
+
T = T + (lfr_m - 1) // 2
|
394 |
+
for i in range(T_lfr):
|
395 |
+
if lfr_m <= T - i * lfr_n:
|
396 |
+
LFR_inputs.append(
|
397 |
+
(inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)
|
398 |
+
)
|
399 |
+
else:
|
400 |
+
# process last LFR frame
|
401 |
+
num_padding = lfr_m - (T - i * lfr_n)
|
402 |
+
frame = inputs[i * lfr_n :].reshape(-1)
|
403 |
+
for _ in range(num_padding):
|
404 |
+
frame = np.hstack((frame, inputs[-1]))
|
405 |
+
|
406 |
+
LFR_inputs.append(frame)
|
407 |
+
LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
|
408 |
+
return LFR_outputs
|
409 |
+
|
410 |
+
def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
|
411 |
+
"""
|
412 |
+
Apply CMVN with mvn data
|
413 |
+
"""
|
414 |
+
frame, dim = inputs.shape
|
415 |
+
means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
|
416 |
+
vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
|
417 |
+
inputs = (inputs + means) * vars
|
418 |
+
return inputs
|
419 |
+
|
420 |
+
def get_features(self, inputs: Union[str, np.ndarray]) -> Tuple[np.ndarray, int]:
|
421 |
+
if isinstance(inputs, str):
|
422 |
+
inputs, _ = self.load_audio(inputs)
|
423 |
+
|
424 |
+
fbank, _ = self.fbank(inputs)
|
425 |
+
feats = self.apply_cmvn(self.apply_lfr(fbank, self.lfr_m, self.lfr_n))
|
426 |
+
return feats
|
427 |
+
|
428 |
+
def load_cmvn(
|
429 |
+
self,
|
430 |
+
) -> np.ndarray:
|
431 |
+
with open(self.cmvn_file, "r", encoding="utf-8") as f:
|
432 |
+
lines = f.readlines()
|
433 |
+
|
434 |
+
means_list = []
|
435 |
+
vars_list = []
|
436 |
+
for i in range(len(lines)):
|
437 |
+
line_item = lines[i].split()
|
438 |
+
if line_item[0] == "<AddShift>":
|
439 |
+
line_item = lines[i + 1].split()
|
440 |
+
if line_item[0] == "<LearnRateCoef>":
|
441 |
+
add_shift_line = line_item[3 : (len(line_item) - 1)]
|
442 |
+
means_list = list(add_shift_line)
|
443 |
+
continue
|
444 |
+
elif line_item[0] == "<Rescale>":
|
445 |
+
line_item = lines[i + 1].split()
|
446 |
+
if line_item[0] == "<LearnRateCoef>":
|
447 |
+
rescale_line = line_item[3 : (len(line_item) - 1)]
|
448 |
+
vars_list = list(rescale_line)
|
449 |
+
continue
|
450 |
+
|
451 |
+
means = np.array(means_list).astype(np.float64)
|
452 |
+
vars = np.array(vars_list).astype(np.float64)
|
453 |
+
cmvn = np.array([means, vars])
|
454 |
+
return cmvn
|
455 |
+
|
456 |
+
# ```
|
457 |
+
|
458 |
+
# File: utils/fsmn_vad.py
|
459 |
+
# ```py
|
460 |
+
# -*- coding:utf-8 -*-
|
461 |
+
# @FileName :fsmn_vad.py
|
462 |
+
# @Time :2024/8/31 16:50
|
463 |
+
# @Author :lovemefan
|
464 |
+
# @Email :lovemefan@outlook.com
|
465 |
+
|
466 |
+
|
467 |
+
|
468 |
+
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
|
469 |
+
if not Path(yaml_path).exists():
|
470 |
+
raise FileExistsError(f"The {yaml_path} does not exist.")
|
471 |
+
|
472 |
+
with open(str(yaml_path), "rb") as f:
|
473 |
+
data = yaml.load(f, Loader=yaml.Loader)
|
474 |
+
return data
|
475 |
+
|
476 |
+
|
477 |
+
class VadStateMachine(Enum):
|
478 |
+
kVadInStateStartPointNotDetected = 1
|
479 |
+
kVadInStateInSpeechSegment = 2
|
480 |
+
kVadInStateEndPointDetected = 3
|
481 |
+
|
482 |
+
|
483 |
+
class FrameState(Enum):
|
484 |
+
kFrameStateInvalid = -1
|
485 |
+
kFrameStateSpeech = 1
|
486 |
+
kFrameStateSil = 0
|
487 |
+
|
488 |
+
|
489 |
+
# final voice/unvoice state per frame
|
490 |
+
class AudioChangeState(Enum):
|
491 |
+
kChangeStateSpeech2Speech = 0
|
492 |
+
kChangeStateSpeech2Sil = 1
|
493 |
+
kChangeStateSil2Sil = 2
|
494 |
+
kChangeStateSil2Speech = 3
|
495 |
+
kChangeStateNoBegin = 4
|
496 |
+
kChangeStateInvalid = 5
|
497 |
+
|
498 |
+
|
499 |
+
class VadDetectMode(Enum):
|
500 |
+
kVadSingleUtteranceDetectMode = 0
|
501 |
+
kVadMutipleUtteranceDetectMode = 1
|
502 |
+
|
503 |
+
|
504 |
+
class VADXOptions:
|
505 |
+
def __init__(
|
506 |
+
self,
|
507 |
+
sample_rate: int = 16000,
|
508 |
+
detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
|
509 |
+
snr_mode: int = 0,
|
510 |
+
max_end_silence_time: int = 800,
|
511 |
+
max_start_silence_time: int = 3000,
|
512 |
+
do_start_point_detection: bool = True,
|
513 |
+
do_end_point_detection: bool = True,
|
514 |
+
window_size_ms: int = 200,
|
515 |
+
sil_to_speech_time_thres: int = 150,
|
516 |
+
speech_to_sil_time_thres: int = 150,
|
517 |
+
speech_2_noise_ratio: float = 1.0,
|
518 |
+
do_extend: int = 1,
|
519 |
+
lookback_time_start_point: int = 200,
|
520 |
+
lookahead_time_end_point: int = 100,
|
521 |
+
max_single_segment_time: int = 60000,
|
522 |
+
nn_eval_block_size: int = 8,
|
523 |
+
dcd_block_size: int = 4,
|
524 |
+
snr_thres: int = -100.0,
|
525 |
+
noise_frame_num_used_for_snr: int = 100,
|
526 |
+
decibel_thres: int = -100.0,
|
527 |
+
speech_noise_thres: float = 0.6,
|
528 |
+
fe_prior_thres: float = 1e-4,
|
529 |
+
silence_pdf_num: int = 1,
|
530 |
+
sil_pdf_ids: List[int] = [0],
|
531 |
+
speech_noise_thresh_low: float = -0.1,
|
532 |
+
speech_noise_thresh_high: float = 0.3,
|
533 |
+
output_frame_probs: bool = False,
|
534 |
+
frame_in_ms: int = 10,
|
535 |
+
frame_length_ms: int = 25,
|
536 |
+
):
|
537 |
+
self.sample_rate = sample_rate
|
538 |
+
self.detect_mode = detect_mode
|
539 |
+
self.snr_mode = snr_mode
|
540 |
+
self.max_end_silence_time = max_end_silence_time
|
541 |
+
self.max_start_silence_time = max_start_silence_time
|
542 |
+
self.do_start_point_detection = do_start_point_detection
|
543 |
+
self.do_end_point_detection = do_end_point_detection
|
544 |
+
self.window_size_ms = window_size_ms
|
545 |
+
self.sil_to_speech_time_thres = sil_to_speech_time_thres
|
546 |
+
self.speech_to_sil_time_thres = speech_to_sil_time_thres
|
547 |
+
self.speech_2_noise_ratio = speech_2_noise_ratio
|
548 |
+
self.do_extend = do_extend
|
549 |
+
self.lookback_time_start_point = lookback_time_start_point
|
550 |
+
self.lookahead_time_end_point = lookahead_time_end_point
|
551 |
+
self.max_single_segment_time = max_single_segment_time
|
552 |
+
self.nn_eval_block_size = nn_eval_block_size
|
553 |
+
self.dcd_block_size = dcd_block_size
|
554 |
+
self.snr_thres = snr_thres
|
555 |
+
self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
|
556 |
+
self.decibel_thres = decibel_thres
|
557 |
+
self.speech_noise_thres = speech_noise_thres
|
558 |
+
self.fe_prior_thres = fe_prior_thres
|
559 |
+
self.silence_pdf_num = silence_pdf_num
|
560 |
+
self.sil_pdf_ids = sil_pdf_ids
|
561 |
+
self.speech_noise_thresh_low = speech_noise_thresh_low
|
562 |
+
self.speech_noise_thresh_high = speech_noise_thresh_high
|
563 |
+
self.output_frame_probs = output_frame_probs
|
564 |
+
self.frame_in_ms = frame_in_ms
|
565 |
+
self.frame_length_ms = frame_length_ms
|
566 |
+
|
567 |
+
|
568 |
+
class E2EVadSpeechBufWithDoa(object):
|
569 |
+
def __init__(self):
|
570 |
+
self.start_ms = 0
|
571 |
+
self.end_ms = 0
|
572 |
+
self.buffer = []
|
573 |
+
self.contain_seg_start_point = False
|
574 |
+
self.contain_seg_end_point = False
|
575 |
+
self.doa = 0
|
576 |
+
|
577 |
+
def reset(self):
|
578 |
+
self.start_ms = 0
|
579 |
+
self.end_ms = 0
|
580 |
+
self.buffer = []
|
581 |
+
self.contain_seg_start_point = False
|
582 |
+
self.contain_seg_end_point = False
|
583 |
+
self.doa = 0
|
584 |
+
|
585 |
+
|
586 |
+
class E2EVadFrameProb(object):
|
587 |
+
def __init__(self):
|
588 |
+
self.noise_prob = 0.0
|
589 |
+
self.speech_prob = 0.0
|
590 |
+
self.score = 0.0
|
591 |
+
self.frame_id = 0
|
592 |
+
self.frm_state = 0
|
593 |
+
|
594 |
+
|
595 |
+
class WindowDetector(object):
|
596 |
+
def __init__(
|
597 |
+
self,
|
598 |
+
window_size_ms: int,
|
599 |
+
sil_to_speech_time: int,
|
600 |
+
speech_to_sil_time: int,
|
601 |
+
frame_size_ms: int,
|
602 |
+
):
|
603 |
+
self.window_size_ms = window_size_ms
|
604 |
+
self.sil_to_speech_time = sil_to_speech_time
|
605 |
+
self.speech_to_sil_time = speech_to_sil_time
|
606 |
+
self.frame_size_ms = frame_size_ms
|
607 |
+
|
608 |
+
self.win_size_frame = int(window_size_ms / frame_size_ms)
|
609 |
+
self.win_sum = 0
|
610 |
+
self.win_state = [0] * self.win_size_frame # 初始化窗
|
611 |
+
|
612 |
+
self.cur_win_pos = 0
|
613 |
+
self.pre_frame_state = FrameState.kFrameStateSil
|
614 |
+
self.cur_frame_state = FrameState.kFrameStateSil
|
615 |
+
self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
|
616 |
+
self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
|
617 |
+
|
618 |
+
self.voice_last_frame_count = 0
|
619 |
+
self.noise_last_frame_count = 0
|
620 |
+
self.hydre_frame_count = 0
|
621 |
+
|
622 |
+
def reset(self) -> None:
|
623 |
+
self.cur_win_pos = 0
|
624 |
+
self.win_sum = 0
|
625 |
+
self.win_state = [0] * self.win_size_frame
|
626 |
+
self.pre_frame_state = FrameState.kFrameStateSil
|
627 |
+
self.cur_frame_state = FrameState.kFrameStateSil
|
628 |
+
self.voice_last_frame_count = 0
|
629 |
+
self.noise_last_frame_count = 0
|
630 |
+
self.hydre_frame_count = 0
|
631 |
+
|
632 |
+
def get_win_size(self) -> int:
|
633 |
+
return int(self.win_size_frame)
|
634 |
+
|
635 |
+
def detect_one_frame(
|
636 |
+
self, frameState: FrameState, frame_count: int
|
637 |
+
) -> AudioChangeState:
|
638 |
+
cur_frame_state = FrameState.kFrameStateSil
|
639 |
+
if frameState == FrameState.kFrameStateSpeech:
|
640 |
+
cur_frame_state = 1
|
641 |
+
elif frameState == FrameState.kFrameStateSil:
|
642 |
+
cur_frame_state = 0
|
643 |
+
else:
|
644 |
+
return AudioChangeState.kChangeStateInvalid
|
645 |
+
self.win_sum -= self.win_state[self.cur_win_pos]
|
646 |
+
self.win_sum += cur_frame_state
|
647 |
+
self.win_state[self.cur_win_pos] = cur_frame_state
|
648 |
+
self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
|
649 |
+
|
650 |
+
if (
|
651 |
+
self.pre_frame_state == FrameState.kFrameStateSil
|
652 |
+
and self.win_sum >= self.sil_to_speech_frmcnt_thres
|
653 |
+
):
|
654 |
+
self.pre_frame_state = FrameState.kFrameStateSpeech
|
655 |
+
return AudioChangeState.kChangeStateSil2Speech
|
656 |
+
|
657 |
+
if (
|
658 |
+
self.pre_frame_state == FrameState.kFrameStateSpeech
|
659 |
+
and self.win_sum <= self.speech_to_sil_frmcnt_thres
|
660 |
+
):
|
661 |
+
self.pre_frame_state = FrameState.kFrameStateSil
|
662 |
+
return AudioChangeState.kChangeStateSpeech2Sil
|
663 |
+
|
664 |
+
if self.pre_frame_state == FrameState.kFrameStateSil:
|
665 |
+
return AudioChangeState.kChangeStateSil2Sil
|
666 |
+
if self.pre_frame_state == FrameState.kFrameStateSpeech:
|
667 |
+
return AudioChangeState.kChangeStateSpeech2Speech
|
668 |
+
return AudioChangeState.kChangeStateInvalid
|
669 |
+
|
670 |
+
def frame_size_ms(self) -> int:
|
671 |
+
return int(self.frame_size_ms)
|
672 |
+
|
673 |
+
|
674 |
+
class E2EVadModel:
|
675 |
+
def __init__(self, config, vad_post_args: Dict[str, Any], root_dir: Path):
|
676 |
+
super(E2EVadModel, self).__init__()
|
677 |
+
self.vad_opts = VADXOptions(**vad_post_args)
|
678 |
+
self.windows_detector = WindowDetector(
|
679 |
+
self.vad_opts.window_size_ms,
|
680 |
+
self.vad_opts.sil_to_speech_time_thres,
|
681 |
+
self.vad_opts.speech_to_sil_time_thres,
|
682 |
+
self.vad_opts.frame_in_ms,
|
683 |
+
)
|
684 |
+
self.model = VadOrtInferRuntimeSession(config, root_dir)
|
685 |
+
self.all_reset_detection()
|
686 |
+
|
687 |
+
def all_reset_detection(self):
|
688 |
+
# init variables
|
689 |
+
self.is_final = False
|
690 |
+
self.data_buf_start_frame = 0
|
691 |
+
self.frm_cnt = 0
|
692 |
+
self.latest_confirmed_speech_frame = 0
|
693 |
+
self.lastest_confirmed_silence_frame = -1
|
694 |
+
self.continous_silence_frame_count = 0
|
695 |
+
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
|
696 |
+
self.confirmed_start_frame = -1
|
697 |
+
self.confirmed_end_frame = -1
|
698 |
+
self.number_end_time_detected = 0
|
699 |
+
self.sil_frame = 0
|
700 |
+
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
|
701 |
+
self.noise_average_decibel = -100.0
|
702 |
+
self.pre_end_silence_detected = False
|
703 |
+
self.next_seg = True
|
704 |
+
|
705 |
+
self.output_data_buf = []
|
706 |
+
self.output_data_buf_offset = 0
|
707 |
+
self.frame_probs = []
|
708 |
+
self.max_end_sil_frame_cnt_thresh = (
|
709 |
+
self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
|
710 |
+
)
|
711 |
+
self.speech_noise_thres = self.vad_opts.speech_noise_thres
|
712 |
+
self.scores = None
|
713 |
+
self.scores_offset = 0
|
714 |
+
self.max_time_out = False
|
715 |
+
self.decibel = []
|
716 |
+
self.decibel_offset = 0
|
717 |
+
self.data_buf_size = 0
|
718 |
+
self.data_buf_all_size = 0
|
719 |
+
self.waveform = None
|
720 |
+
self.reset_detection()
|
721 |
+
|
722 |
+
def reset_detection(self):
|
723 |
+
self.continous_silence_frame_count = 0
|
724 |
+
self.latest_confirmed_speech_frame = 0
|
725 |
+
self.lastest_confirmed_silence_frame = -1
|
726 |
+
self.confirmed_start_frame = -1
|
727 |
+
self.confirmed_end_frame = -1
|
728 |
+
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
|
729 |
+
self.windows_detector.reset()
|
730 |
+
self.sil_frame = 0
|
731 |
+
self.frame_probs = []
|
732 |
+
|
733 |
+
def compute_decibel(self) -> None:
|
734 |
+
frame_sample_length = int(
|
735 |
+
self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
|
736 |
+
)
|
737 |
+
frame_shift_length = int(
|
738 |
+
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
|
739 |
+
)
|
740 |
+
if self.data_buf_all_size == 0:
|
741 |
+
self.data_buf_all_size = len(self.waveform[0])
|
742 |
+
self.data_buf_size = self.data_buf_all_size
|
743 |
+
else:
|
744 |
+
self.data_buf_all_size += len(self.waveform[0])
|
745 |
+
|
746 |
+
for offset in range(
|
747 |
+
0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length
|
748 |
+
):
|
749 |
+
self.decibel.append(
|
750 |
+
10
|
751 |
+
* np.log10(
|
752 |
+
np.square(
|
753 |
+
self.waveform[0][offset : offset + frame_sample_length]
|
754 |
+
).sum()
|
755 |
+
+ 1e-6
|
756 |
+
)
|
757 |
+
)
|
758 |
+
|
759 |
+
def compute_scores(self, feats: np.ndarray) -> None:
|
760 |
+
scores = self.model(feats)
|
761 |
+
self.vad_opts.nn_eval_block_size = scores[0].shape[1]
|
762 |
+
self.frm_cnt += scores[0].shape[1] # count total frames
|
763 |
+
if isinstance(feats, list):
|
764 |
+
# return B * T * D
|
765 |
+
feats = feats[0]
|
766 |
+
|
767 |
+
assert (
|
768 |
+
scores[0].shape[1] == feats.shape[1]
|
769 |
+
), "The shape between feats and scores does not match"
|
770 |
+
|
771 |
+
self.scores = scores[0] # the first calculation
|
772 |
+
self.scores_offset += self.scores.shape[1]
|
773 |
+
|
774 |
+
return scores[1:]
|
775 |
+
|
776 |
+
def pop_data_buf_till_frame(self, frame_idx: int) -> None: # need check again
|
777 |
+
while self.data_buf_start_frame < frame_idx:
|
778 |
+
if self.data_buf_size >= int(
|
779 |
+
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
|
780 |
+
):
|
781 |
+
self.data_buf_start_frame += 1
|
782 |
+
self.data_buf_size = (
|
783 |
+
self.data_buf_all_size
|
784 |
+
- self.data_buf_start_frame
|
785 |
+
* int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
|
786 |
+
)
|
787 |
+
|
788 |
+
def pop_data_to_output_buf(
|
789 |
+
self,
|
790 |
+
start_frm: int,
|
791 |
+
frm_cnt: int,
|
792 |
+
first_frm_is_start_point: bool,
|
793 |
+
last_frm_is_end_point: bool,
|
794 |
+
end_point_is_sent_end: bool,
|
795 |
+
) -> None:
|
796 |
+
self.pop_data_buf_till_frame(start_frm)
|
797 |
+
expected_sample_number = int(
|
798 |
+
frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
|
799 |
+
)
|
800 |
+
if last_frm_is_end_point:
|
801 |
+
extra_sample = max(
|
802 |
+
0,
|
803 |
+
int(
|
804 |
+
self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000
|
805 |
+
- self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000
|
806 |
+
),
|
807 |
+
)
|
808 |
+
expected_sample_number += int(extra_sample)
|
809 |
+
if end_point_is_sent_end:
|
810 |
+
expected_sample_number = max(expected_sample_number, self.data_buf_size)
|
811 |
+
if self.data_buf_size < expected_sample_number:
|
812 |
+
logging.error("error in calling pop data_buf\n")
|
813 |
+
|
814 |
+
if len(self.output_data_buf) == 0 or first_frm_is_start_point:
|
815 |
+
self.output_data_buf.append(E2EVadSpeechBufWithDoa())
|
816 |
+
self.output_data_buf[-1].reset()
|
817 |
+
self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
|
818 |
+
self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
|
819 |
+
self.output_data_buf[-1].doa = 0
|
820 |
+
cur_seg = self.output_data_buf[-1]
|
821 |
+
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
|
822 |
+
logging.error("warning\n")
|
823 |
+
out_pos = len(cur_seg.buffer) # cur_seg.buff现在没做任何操作
|
824 |
+
data_to_pop = 0
|
825 |
+
if end_point_is_sent_end:
|
826 |
+
data_to_pop = expected_sample_number
|
827 |
+
else:
|
828 |
+
data_to_pop = int(
|
829 |
+
frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000
|
830 |
+
)
|
831 |
+
if data_to_pop > self.data_buf_size:
|
832 |
+
logging.error("VAD data_to_pop is bigger than self.data_buf.size()!!!\n")
|
833 |
+
data_to_pop = self.data_buf_size
|
834 |
+
expected_sample_number = self.data_buf_size
|
835 |
+
|
836 |
+
cur_seg.doa = 0
|
837 |
+
for sample_cpy_out in range(0, data_to_pop):
|
838 |
+
# cur_seg.buffer[out_pos ++] = data_buf_.back();
|
839 |
+
out_pos += 1
|
840 |
+
for sample_cpy_out in range(data_to_pop, expected_sample_number):
|
841 |
+
# cur_seg.buffer[out_pos++] = data_buf_.back()
|
842 |
+
out_pos += 1
|
843 |
+
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
|
844 |
+
logging.error("Something wrong with the VAD algorithm\n")
|
845 |
+
self.data_buf_start_frame += frm_cnt
|
846 |
+
cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
|
847 |
+
if first_frm_is_start_point:
|
848 |
+
cur_seg.contain_seg_start_point = True
|
849 |
+
if last_frm_is_end_point:
|
850 |
+
cur_seg.contain_seg_end_point = True
|
851 |
+
|
852 |
+
def on_silence_detected(self, valid_frame: int):
|
853 |
+
self.lastest_confirmed_silence_frame = valid_frame
|
854 |
+
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
|
855 |
+
self.pop_data_buf_till_frame(valid_frame)
|
856 |
+
# silence_detected_callback_
|
857 |
+
# pass
|
858 |
+
|
859 |
+
def on_voice_detected(self, valid_frame: int) -> None:
|
860 |
+
self.latest_confirmed_speech_frame = valid_frame
|
861 |
+
self.pop_data_to_output_buf(valid_frame, 1, False, False, False)
|
862 |
+
|
863 |
+
def on_voice_start(self, start_frame: int, fake_result: bool = False) -> None:
|
864 |
+
if self.vad_opts.do_start_point_detection:
|
865 |
+
pass
|
866 |
+
if self.confirmed_start_frame != -1:
|
867 |
+
logging.error("not reset vad properly\n")
|
868 |
+
else:
|
869 |
+
self.confirmed_start_frame = start_frame
|
870 |
+
|
871 |
+
if (
|
872 |
+
not fake_result
|
873 |
+
and self.vad_state_machine
|
874 |
+
== VadStateMachine.kVadInStateStartPointNotDetected
|
875 |
+
):
|
876 |
+
self.pop_data_to_output_buf(
|
877 |
+
self.confirmed_start_frame, 1, True, False, False
|
878 |
+
)
|
879 |
+
|
880 |
+
def on_voice_end(
|
881 |
+
self, end_frame: int, fake_result: bool, is_last_frame: bool
|
882 |
+
) -> None:
|
883 |
+
for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
|
884 |
+
self.on_voice_detected(t)
|
885 |
+
if self.vad_opts.do_end_point_detection:
|
886 |
+
pass
|
887 |
+
if self.confirmed_end_frame != -1:
|
888 |
+
logging.error("not reset vad properly\n")
|
889 |
+
else:
|
890 |
+
self.confirmed_end_frame = end_frame
|
891 |
+
if not fake_result:
|
892 |
+
self.sil_frame = 0
|
893 |
+
self.pop_data_to_output_buf(
|
894 |
+
self.confirmed_end_frame, 1, False, True, is_last_frame
|
895 |
+
)
|
896 |
+
self.number_end_time_detected += 1
|
897 |
+
|
898 |
+
def maybe_on_voice_end_last_frame(
|
899 |
+
self, is_final_frame: bool, cur_frm_idx: int
|
900 |
+
) -> None:
|
901 |
+
if is_final_frame:
|
902 |
+
self.on_voice_end(cur_frm_idx, False, True)
|
903 |
+
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
904 |
+
|
905 |
+
def get_latency(self) -> int:
|
906 |
+
return int(self.latency_frm_num_at_start_point() * self.vad_opts.frame_in_ms)
|
907 |
+
|
908 |
+
def latency_frm_num_at_start_point(self) -> int:
|
909 |
+
vad_latency = self.windows_detector.get_win_size()
|
910 |
+
if self.vad_opts.do_extend:
|
911 |
+
vad_latency += int(
|
912 |
+
self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms
|
913 |
+
)
|
914 |
+
return vad_latency
|
915 |
+
|
916 |
+
def get_frame_state(self, t: int) -> FrameState:
|
917 |
+
frame_state = FrameState.kFrameStateInvalid
|
918 |
+
cur_decibel = self.decibel[t - self.decibel_offset]
|
919 |
+
cur_snr = cur_decibel - self.noise_average_decibel
|
920 |
+
# for each frame, calc log posterior probability of each state
|
921 |
+
if cur_decibel < self.vad_opts.decibel_thres:
|
922 |
+
frame_state = FrameState.kFrameStateSil
|
923 |
+
self.detect_one_frame(frame_state, t, False)
|
924 |
+
return frame_state
|
925 |
+
|
926 |
+
sum_score = 0.0
|
927 |
+
noise_prob = 0.0
|
928 |
+
assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
|
929 |
+
if len(self.sil_pdf_ids) > 0:
|
930 |
+
assert len(self.scores) == 1 # 只支持batch_size = 1的测试
|
931 |
+
sil_pdf_scores = [
|
932 |
+
self.scores[0][t - self.scores_offset][sil_pdf_id]
|
933 |
+
for sil_pdf_id in self.sil_pdf_ids
|
934 |
+
]
|
935 |
+
sum_score = sum(sil_pdf_scores)
|
936 |
+
noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
|
937 |
+
total_score = 1.0
|
938 |
+
sum_score = total_score - sum_score
|
939 |
+
speech_prob = math.log(sum_score)
|
940 |
+
if self.vad_opts.output_frame_probs:
|
941 |
+
frame_prob = E2EVadFrameProb()
|
942 |
+
frame_prob.noise_prob = noise_prob
|
943 |
+
frame_prob.speech_prob = speech_prob
|
944 |
+
frame_prob.score = sum_score
|
945 |
+
frame_prob.frame_id = t
|
946 |
+
self.frame_probs.append(frame_prob)
|
947 |
+
if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
|
948 |
+
if (
|
949 |
+
cur_snr >= self.vad_opts.snr_thres
|
950 |
+
and cur_decibel >= self.vad_opts.decibel_thres
|
951 |
+
):
|
952 |
+
frame_state = FrameState.kFrameStateSpeech
|
953 |
+
else:
|
954 |
+
frame_state = FrameState.kFrameStateSil
|
955 |
+
else:
|
956 |
+
frame_state = FrameState.kFrameStateSil
|
957 |
+
if self.noise_average_decibel < -99.9:
|
958 |
+
self.noise_average_decibel = cur_decibel
|
959 |
+
else:
|
960 |
+
self.noise_average_decibel = (
|
961 |
+
cur_decibel
|
962 |
+
+ self.noise_average_decibel
|
963 |
+
* (self.vad_opts.noise_frame_num_used_for_snr - 1)
|
964 |
+
) / self.vad_opts.noise_frame_num_used_for_snr
|
965 |
+
|
966 |
+
return frame_state
|
967 |
+
|
968 |
+
def infer_offline(
|
969 |
+
self,
|
970 |
+
feats: np.ndarray,
|
971 |
+
waveform: np.ndarray,
|
972 |
+
in_cache: Dict[str, np.ndarray] = dict(),
|
973 |
+
is_final: bool = False,
|
974 |
+
) -> Tuple[List[List[List[int]]], Dict[str, np.ndarray]]:
|
975 |
+
self.waveform = waveform
|
976 |
+
self.compute_decibel()
|
977 |
+
|
978 |
+
self.compute_scores(feats)
|
979 |
+
if not is_final:
|
980 |
+
self.detect_common_frames()
|
981 |
+
else:
|
982 |
+
self.detect_last_frames()
|
983 |
+
segments = []
|
984 |
+
for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
|
985 |
+
segment_batch = []
|
986 |
+
if len(self.output_data_buf) > 0:
|
987 |
+
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
|
988 |
+
if (
|
989 |
+
not self.output_data_buf[i].contain_seg_start_point
|
990 |
+
or not self.output_data_buf[i].contain_seg_end_point
|
991 |
+
):
|
992 |
+
continue
|
993 |
+
segment = [
|
994 |
+
self.output_data_buf[i].start_ms,
|
995 |
+
self.output_data_buf[i].end_ms,
|
996 |
+
]
|
997 |
+
segment_batch.append(segment)
|
998 |
+
self.output_data_buf_offset += 1 # need update this parameter
|
999 |
+
if segment_batch:
|
1000 |
+
segments.append(segment_batch)
|
1001 |
+
|
1002 |
+
if is_final:
|
1003 |
+
# reset class variables and clear the dict for the next query
|
1004 |
+
self.all_reset_detection()
|
1005 |
+
return segments, in_cache
|
1006 |
+
|
1007 |
+
def infer_online(
|
1008 |
+
self,
|
1009 |
+
feats: np.ndarray,
|
1010 |
+
waveform: np.ndarray,
|
1011 |
+
in_cache: list = None,
|
1012 |
+
is_final: bool = False,
|
1013 |
+
max_end_sil: int = 800,
|
1014 |
+
) -> Tuple[List[List[List[int]]], Dict[str, np.ndarray]]:
|
1015 |
+
feats = [feats]
|
1016 |
+
if in_cache is None:
|
1017 |
+
in_cache = []
|
1018 |
+
|
1019 |
+
self.max_end_sil_frame_cnt_thresh = (
|
1020 |
+
max_end_sil - self.vad_opts.speech_to_sil_time_thres
|
1021 |
+
)
|
1022 |
+
self.waveform = waveform # compute decibel for each frame
|
1023 |
+
feats.extend(in_cache)
|
1024 |
+
in_cache = self.compute_scores(feats)
|
1025 |
+
self.compute_decibel()
|
1026 |
+
|
1027 |
+
if is_final:
|
1028 |
+
self.detect_last_frames()
|
1029 |
+
else:
|
1030 |
+
self.detect_common_frames()
|
1031 |
+
|
1032 |
+
segments = []
|
1033 |
+
# only support batch_size = 1 now
|
1034 |
+
for batch_num in range(0, feats[0].shape[0]):
|
1035 |
+
if len(self.output_data_buf) > 0:
|
1036 |
+
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
|
1037 |
+
if not self.output_data_buf[i].contain_seg_start_point:
|
1038 |
+
continue
|
1039 |
+
if (
|
1040 |
+
not self.next_seg
|
1041 |
+
and not self.output_data_buf[i].contain_seg_end_point
|
1042 |
+
):
|
1043 |
+
continue
|
1044 |
+
start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
|
1045 |
+
if self.output_data_buf[i].contain_seg_end_point:
|
1046 |
+
end_ms = self.output_data_buf[i].end_ms
|
1047 |
+
self.next_seg = True
|
1048 |
+
self.output_data_buf_offset += 1
|
1049 |
+
else:
|
1050 |
+
end_ms = -1
|
1051 |
+
self.next_seg = False
|
1052 |
+
segments.append([start_ms, end_ms])
|
1053 |
+
|
1054 |
+
return segments, in_cache
|
1055 |
+
|
1056 |
+
def get_frames_state(
|
1057 |
+
self,
|
1058 |
+
feats: np.ndarray,
|
1059 |
+
waveform: np.ndarray,
|
1060 |
+
in_cache: list = None,
|
1061 |
+
is_final: bool = False,
|
1062 |
+
max_end_sil: int = 800,
|
1063 |
+
):
|
1064 |
+
feats = [feats]
|
1065 |
+
states = []
|
1066 |
+
if in_cache is None:
|
1067 |
+
in_cache = []
|
1068 |
+
|
1069 |
+
self.max_end_sil_frame_cnt_thresh = (
|
1070 |
+
max_end_sil - self.vad_opts.speech_to_sil_time_thres
|
1071 |
+
)
|
1072 |
+
self.waveform = waveform # compute decibel for each frame
|
1073 |
+
feats.extend(in_cache)
|
1074 |
+
in_cache = self.compute_scores(feats)
|
1075 |
+
self.compute_decibel()
|
1076 |
+
|
1077 |
+
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
|
1078 |
+
return states
|
1079 |
+
|
1080 |
+
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
|
1081 |
+
frame_state = FrameState.kFrameStateInvalid
|
1082 |
+
frame_state = self.get_frame_state(self.frm_cnt - 1 - i)
|
1083 |
+
states.append(frame_state)
|
1084 |
+
if i == 0 and is_final:
|
1085 |
+
logging.info("last frame detected")
|
1086 |
+
self.detect_one_frame(frame_state, self.frm_cnt - 1, True)
|
1087 |
+
else:
|
1088 |
+
self.detect_one_frame(frame_state, self.frm_cnt - 1 - i, False)
|
1089 |
+
|
1090 |
+
return states
|
1091 |
+
|
1092 |
+
def detect_common_frames(self) -> int:
|
1093 |
+
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
|
1094 |
+
return 0
|
1095 |
+
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
|
1096 |
+
frame_state = FrameState.kFrameStateInvalid
|
1097 |
+
frame_state = self.get_frame_state(self.frm_cnt - 1 - i)
|
1098 |
+
# print(f"cur frame: {self.frm_cnt - 1 - i}, state is {frame_state}")
|
1099 |
+
self.detect_one_frame(frame_state, self.frm_cnt - 1 - i, False)
|
1100 |
+
|
1101 |
+
self.decibel = self.decibel[self.vad_opts.nn_eval_block_size - 1 :]
|
1102 |
+
self.decibel_offset = self.frm_cnt - 1 - i
|
1103 |
+
return 0
|
1104 |
+
|
1105 |
+
def detect_last_frames(self) -> int:
|
1106 |
+
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
|
1107 |
+
return 0
|
1108 |
+
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
|
1109 |
+
frame_state = FrameState.kFrameStateInvalid
|
1110 |
+
frame_state = self.get_frame_state(self.frm_cnt - 1 - i)
|
1111 |
+
if i != 0:
|
1112 |
+
self.detect_one_frame(frame_state, self.frm_cnt - 1 - i, False)
|
1113 |
+
else:
|
1114 |
+
self.detect_one_frame(frame_state, self.frm_cnt - 1, True)
|
1115 |
+
|
1116 |
+
return 0
|
1117 |
+
|
1118 |
+
def detect_one_frame(
|
1119 |
+
self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool
|
1120 |
+
) -> None:
|
1121 |
+
tmp_cur_frm_state = FrameState.kFrameStateInvalid
|
1122 |
+
if cur_frm_state == FrameState.kFrameStateSpeech:
|
1123 |
+
if math.fabs(1.0) > float(self.vad_opts.fe_prior_thres):
|
1124 |
+
tmp_cur_frm_state = FrameState.kFrameStateSpeech
|
1125 |
+
else:
|
1126 |
+
tmp_cur_frm_state = FrameState.kFrameStateSil
|
1127 |
+
elif cur_frm_state == FrameState.kFrameStateSil:
|
1128 |
+
tmp_cur_frm_state = FrameState.kFrameStateSil
|
1129 |
+
state_change = self.windows_detector.detect_one_frame(
|
1130 |
+
tmp_cur_frm_state, cur_frm_idx
|
1131 |
+
)
|
1132 |
+
frm_shift_in_ms = self.vad_opts.frame_in_ms
|
1133 |
+
if AudioChangeState.kChangeStateSil2Speech == state_change:
|
1134 |
+
self.continous_silence_frame_count = 0
|
1135 |
+
self.pre_end_silence_detected = False
|
1136 |
+
|
1137 |
+
if (
|
1138 |
+
self.vad_state_machine
|
1139 |
+
== VadStateMachine.kVadInStateStartPointNotDetected
|
1140 |
+
):
|
1141 |
+
start_frame = max(
|
1142 |
+
self.data_buf_start_frame,
|
1143 |
+
cur_frm_idx - self.latency_frm_num_at_start_point(),
|
1144 |
+
)
|
1145 |
+
self.on_voice_start(start_frame)
|
1146 |
+
self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
|
1147 |
+
for t in range(start_frame + 1, cur_frm_idx + 1):
|
1148 |
+
self.on_voice_detected(t)
|
1149 |
+
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
|
1150 |
+
for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
|
1151 |
+
self.on_voice_detected(t)
|
1152 |
+
if (
|
1153 |
+
cur_frm_idx - self.confirmed_start_frame + 1
|
1154 |
+
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
|
1155 |
+
):
|
1156 |
+
self.on_voice_end(cur_frm_idx, False, False)
|
1157 |
+
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
1158 |
+
elif not is_final_frame:
|
1159 |
+
self.on_voice_detected(cur_frm_idx)
|
1160 |
+
else:
|
1161 |
+
self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx)
|
1162 |
+
else:
|
1163 |
+
pass
|
1164 |
+
elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
|
1165 |
+
self.continous_silence_frame_count = 0
|
1166 |
+
if (
|
1167 |
+
self.vad_state_machine
|
1168 |
+
== VadStateMachine.kVadInStateStartPointNotDetected
|
1169 |
+
):
|
1170 |
+
pass
|
1171 |
+
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
|
1172 |
+
if (
|
1173 |
+
cur_frm_idx - self.confirmed_start_frame + 1
|
1174 |
+
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
|
1175 |
+
):
|
1176 |
+
self.on_voice_end(cur_frm_idx, False, False)
|
1177 |
+
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
1178 |
+
elif not is_final_frame:
|
1179 |
+
self.on_voice_detected(cur_frm_idx)
|
1180 |
+
else:
|
1181 |
+
self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx)
|
1182 |
+
else:
|
1183 |
+
pass
|
1184 |
+
elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
|
1185 |
+
self.continous_silence_frame_count = 0
|
1186 |
+
if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
|
1187 |
+
if (
|
1188 |
+
cur_frm_idx - self.confirmed_start_frame + 1
|
1189 |
+
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
|
1190 |
+
):
|
1191 |
+
self.max_time_out = True
|
1192 |
+
self.on_voice_end(cur_frm_idx, False, False)
|
1193 |
+
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
1194 |
+
elif not is_final_frame:
|
1195 |
+
self.on_voice_detected(cur_frm_idx)
|
1196 |
+
else:
|
1197 |
+
self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx)
|
1198 |
+
else:
|
1199 |
+
pass
|
1200 |
+
elif AudioChangeState.kChangeStateSil2Sil == state_change:
|
1201 |
+
self.continous_silence_frame_count += 1
|
1202 |
+
if (
|
1203 |
+
self.vad_state_machine
|
1204 |
+
== VadStateMachine.kVadInStateStartPointNotDetected
|
1205 |
+
):
|
1206 |
+
# silence timeout, return zero length decision
|
1207 |
+
if (
|
1208 |
+
(
|
1209 |
+
self.vad_opts.detect_mode
|
1210 |
+
== VadDetectMode.kVadSingleUtteranceDetectMode.value
|
1211 |
+
)
|
1212 |
+
and (
|
1213 |
+
self.continous_silence_frame_count * frm_shift_in_ms
|
1214 |
+
> self.vad_opts.max_start_silence_time
|
1215 |
+
)
|
1216 |
+
) or (is_final_frame and self.number_end_time_detected == 0):
|
1217 |
+
for t in range(
|
1218 |
+
self.lastest_confirmed_silence_frame + 1, cur_frm_idx
|
1219 |
+
):
|
1220 |
+
self.on_silence_detected(t)
|
1221 |
+
self.on_voice_start(0, True)
|
1222 |
+
self.on_voice_end(0, True, False)
|
1223 |
+
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
1224 |
+
else:
|
1225 |
+
if cur_frm_idx >= self.latency_frm_num_at_start_point():
|
1226 |
+
self.on_silence_detected(
|
1227 |
+
cur_frm_idx - self.latency_frm_num_at_start_point()
|
1228 |
+
)
|
1229 |
+
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
|
1230 |
+
if (
|
1231 |
+
self.continous_silence_frame_count * frm_shift_in_ms
|
1232 |
+
>= self.max_end_sil_frame_cnt_thresh
|
1233 |
+
):
|
1234 |
+
lookback_frame = int(
|
1235 |
+
self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms
|
1236 |
+
)
|
1237 |
+
if self.vad_opts.do_extend:
|
1238 |
+
lookback_frame -= int(
|
1239 |
+
self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
|
1240 |
+
)
|
1241 |
+
lookback_frame -= 1
|
1242 |
+
lookback_frame = max(0, lookback_frame)
|
1243 |
+
self.on_voice_end(cur_frm_idx - lookback_frame, False, False)
|
1244 |
+
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
1245 |
+
elif (
|
1246 |
+
cur_frm_idx - self.confirmed_start_frame + 1
|
1247 |
+
> self.vad_opts.max_single_segment_time / frm_shift_in_ms
|
1248 |
+
):
|
1249 |
+
self.on_voice_end(cur_frm_idx, False, False)
|
1250 |
+
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
|
1251 |
+
elif self.vad_opts.do_extend and not is_final_frame:
|
1252 |
+
if self.continous_silence_frame_count <= int(
|
1253 |
+
self.vad_opts.lookahead_time_end_point / frm_shift_in_ms
|
1254 |
+
):
|
1255 |
+
self.on_voice_detected(cur_frm_idx)
|
1256 |
+
else:
|
1257 |
+
self.maybe_on_voice_end_last_frame(is_final_frame, cur_frm_idx)
|
1258 |
+
else:
|
1259 |
+
pass
|
1260 |
+
|
1261 |
+
if (
|
1262 |
+
self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected
|
1263 |
+
and self.vad_opts.detect_mode
|
1264 |
+
== VadDetectMode.kVadMutipleUtteranceDetectMode.value
|
1265 |
+
):
|
1266 |
+
self.reset_detection()
|
1267 |
+
|
1268 |
+
|
1269 |
+
class FSMNVad(object):
|
1270 |
+
def __init__(self, config_dir: str):
|
1271 |
+
config_dir = Path(config_dir)
|
1272 |
+
self.config = read_yaml(config_dir / "fsmn-config.yaml")
|
1273 |
+
self.frontend = WavFrontend(
|
1274 |
+
cmvn_file=config_dir / "fsmn-am.mvn",
|
1275 |
+
**self.config["WavFrontend"]["frontend_conf"],
|
1276 |
+
)
|
1277 |
+
self.config["FSMN"]["model_path"] = config_dir / "fsmnvad-offline.onnx"
|
1278 |
+
|
1279 |
+
self.vad = E2EVadModel(
|
1280 |
+
self.config["FSMN"], self.config["vadPostArgs"], config_dir
|
1281 |
+
)
|
1282 |
+
|
1283 |
+
def set_parameters(self, mode):
|
1284 |
+
pass
|
1285 |
+
|
1286 |
+
def extract_feature(self, waveform):
|
1287 |
+
fbank, _ = self.frontend.fbank(waveform)
|
1288 |
+
feats, feats_len = self.frontend.lfr_cmvn(fbank)
|
1289 |
+
return feats.astype(np.float32), feats_len
|
1290 |
+
|
1291 |
+
def is_speech(self, buf, sample_rate=16000):
|
1292 |
+
assert sample_rate == 16000, "only support 16k sample rate"
|
1293 |
+
|
1294 |
+
def segments_offline(self, waveform_path: Union[str, Path, np.ndarray]):
|
1295 |
+
"""get sements of audio"""
|
1296 |
+
|
1297 |
+
if isinstance(waveform_path, np.ndarray):
|
1298 |
+
waveform = waveform_path
|
1299 |
+
else:
|
1300 |
+
if not os.path.exists(waveform_path):
|
1301 |
+
raise FileExistsError(f"{waveform_path} is not exist.")
|
1302 |
+
if os.path.isfile(waveform_path):
|
1303 |
+
logging.info(f"load audio {waveform_path}")
|
1304 |
+
waveform, _sample_rate = sf.read(
|
1305 |
+
waveform_path,
|
1306 |
+
dtype="float32",
|
1307 |
+
)
|
1308 |
+
else:
|
1309 |
+
raise FileNotFoundError(str(Path))
|
1310 |
+
assert (
|
1311 |
+
_sample_rate == 16000
|
1312 |
+
), f"only support 16k sample rate, current sample rate is {_sample_rate}"
|
1313 |
+
|
1314 |
+
feats, feats_len = self.extract_feature(waveform)
|
1315 |
+
waveform = waveform[None, ...]
|
1316 |
+
segments_part, in_cache = self.vad.infer_offline(
|
1317 |
+
feats[None, ...], waveform, is_final=True
|
1318 |
+
)
|
1319 |
+
return segments_part[0]
|
1320 |
+
|
1321 |
+
# ```
|
1322 |
+
|
1323 |
+
# File: sense_voice.py
|
1324 |
+
# ```py
|
1325 |
+
# -*- coding:utf-8 -*-
|
1326 |
+
# @FileName :sense_voice.py.py
|
1327 |
+
# @Time :2024/7/18 15:40
|
1328 |
+
# @Author :lovemefan
|
1329 |
+
# @Email :lovemefan@outlook.com
|
1330 |
+
|
1331 |
+
languages = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
|
1332 |
+
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
1333 |
+
logging.basicConfig(format=formatter, level=logging.INFO)
|
1334 |
+
|
1335 |
+
def main():
|
1336 |
+
arg_parser = argparse.ArgumentParser(description="Sense Voice")
|
1337 |
+
arg_parser.add_argument("-a", "--audio_file", required=True, type=str, help="Model")
|
1338 |
+
download_model_path = os.path.dirname(__file__)
|
1339 |
+
arg_parser.add_argument(
|
1340 |
+
"-dp",
|
1341 |
+
"--download_path",
|
1342 |
+
default=download_model_path,
|
1343 |
+
type=str,
|
1344 |
+
help="dir path of resource downloaded",
|
1345 |
+
)
|
1346 |
+
arg_parser.add_argument("-d", "--device", default=-1, type=int, help="Device")
|
1347 |
+
arg_parser.add_argument(
|
1348 |
+
"-n", "--num_threads", default=4, type=int, help="Num threads"
|
1349 |
+
)
|
1350 |
+
arg_parser.add_argument(
|
1351 |
+
"-l",
|
1352 |
+
"--language",
|
1353 |
+
choices=languages.keys(),
|
1354 |
+
default="auto",
|
1355 |
+
type=str,
|
1356 |
+
help="Language",
|
1357 |
+
)
|
1358 |
+
arg_parser.add_argument("--use_itn", action="store_true", help="Use ITN")
|
1359 |
+
args = arg_parser.parse_args()
|
1360 |
+
|
1361 |
+
front = WavFrontend(os.path.join(download_model_path, "am.mvn"))
|
1362 |
+
|
1363 |
+
model = SenseVoiceInferenceSession(
|
1364 |
+
os.path.join(download_model_path, "embedding.npy"),
|
1365 |
+
os.path.join(
|
1366 |
+
download_model_path,
|
1367 |
+
"sense-voice-encoder.rknn",
|
1368 |
+
),
|
1369 |
+
os.path.join(download_model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model"),
|
1370 |
+
args.device,
|
1371 |
+
args.num_threads,
|
1372 |
+
)
|
1373 |
+
waveform, _sample_rate = sf.read(
|
1374 |
+
args.audio_file,
|
1375 |
+
dtype="float32",
|
1376 |
+
always_2d=True
|
1377 |
+
)
|
1378 |
+
|
1379 |
+
logging.info(f"Audio {args.audio_file} is {len(waveform) / _sample_rate} seconds, {waveform.shape[1]} channel")
|
1380 |
+
# load vad model
|
1381 |
+
start = time.time()
|
1382 |
+
vad = FSMNVad(download_model_path)
|
1383 |
+
for channel_id, channel_data in enumerate(waveform.T):
|
1384 |
+
segments = vad.segments_offline(channel_data)
|
1385 |
+
results = ""
|
1386 |
+
for part in segments:
|
1387 |
+
audio_feats = front.get_features(channel_data[part[0] * 16 : part[1] * 16])
|
1388 |
+
asr_result = model(
|
1389 |
+
audio_feats[None, ...],
|
1390 |
+
language=languages[args.language],
|
1391 |
+
use_itn=args.use_itn,
|
1392 |
+
)
|
1393 |
+
logging.info(f"[Channel {channel_id}] [{part[0] / 1000}s - {part[1] / 1000}s] {asr_result}")
|
1394 |
+
vad.vad.all_reset_detection()
|
1395 |
+
decoding_time = time.time() - start
|
1396 |
+
logging.info(f"Decoder audio takes {decoding_time} seconds")
|
1397 |
+
logging.info(f"The RTF is {decoding_time/(waveform.shape[1] * len(waveform) / _sample_rate)}.")
|
1398 |
+
|
1399 |
+
|
1400 |
+
if __name__ == "__main__":
|
1401 |
+
main()
|
1402 |
+
|