Fazhong Liu committed
Commit 9a70c5d · Parent(s): 53d9d2d
init
- .gitattributes +35 -35
- .gitignore +1 -0
- app.py +73 -0
- src/classifier.py +97 -0
- src/direction_detection.py +248 -0
- src/generate_array_feature.py +235 -0
- src/main_functions.py +153 -0
- src/main_gui.py +218 -0
- translate/cmd_judge.py +227 -0
- translate/test.py +414 -0
- translate/train_man.py +160 -0
- translate/train_name.py +162 -0
- translate/wav2com.py +0 -0
- translate/wav2npy.py +118 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+# *.7z filter=lfs diff=lfs merge=lfs -text
+# *.arrow filter=lfs diff=lfs merge=lfs -text
+# *.bin filter=lfs diff=lfs merge=lfs -text
+# *.bz2 filter=lfs diff=lfs merge=lfs -text
+# *.ckpt filter=lfs diff=lfs merge=lfs -text
+# *.ftz filter=lfs diff=lfs merge=lfs -text
+# *.gz filter=lfs diff=lfs merge=lfs -text
+# *.h5 filter=lfs diff=lfs merge=lfs -text
+# *.joblib filter=lfs diff=lfs merge=lfs -text
+# *.lfs.* filter=lfs diff=lfs merge=lfs -text
+# *.mlmodel filter=lfs diff=lfs merge=lfs -text
+# *.model filter=lfs diff=lfs merge=lfs -text
+# *.msgpack filter=lfs diff=lfs merge=lfs -text
+# *.npy filter=lfs diff=lfs merge=lfs -text
+# *.npz filter=lfs diff=lfs merge=lfs -text
+# *.onnx filter=lfs diff=lfs merge=lfs -text
+# *.ot filter=lfs diff=lfs merge=lfs -text
+# *.parquet filter=lfs diff=lfs merge=lfs -text
+# *.pb filter=lfs diff=lfs merge=lfs -text
+# *.pickle filter=lfs diff=lfs merge=lfs -text
+# *.pkl filter=lfs diff=lfs merge=lfs -text
+# *.pt filter=lfs diff=lfs merge=lfs -text
+# *.pth filter=lfs diff=lfs merge=lfs -text
+# *.rar filter=lfs diff=lfs merge=lfs -text
+# *.safetensors filter=lfs diff=lfs merge=lfs -text
+# saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+# *.tar.* filter=lfs diff=lfs merge=lfs -text
+# *.tar filter=lfs diff=lfs merge=lfs -text
+# *.tflite filter=lfs diff=lfs merge=lfs -text
+# *.tgz filter=lfs diff=lfs merge=lfs -text
+# *.wasm filter=lfs diff=lfs merge=lfs -text
+# *.xz filter=lfs diff=lfs merge=lfs -text
+# *.zip filter=lfs diff=lfs merge=lfs -text
+# *.zst filter=lfs diff=lfs merge=lfs -text
+# *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
+*.wav
app.py
ADDED
@@ -0,0 +1,73 @@
+import whisper
+from pydub import AudioSegment
+import gradio as gr
+
+def convert_6ch_wav_to_stereo(input_file_path, output_file_path):
+    sound = AudioSegment.from_file(input_file_path, format="wav")
+    if sound.channels != 6:
+        raise ValueError("The input file does not have 6 channels.")
+    front_left = sound.split_to_mono()[0]
+    front_right = sound.split_to_mono()[1]
+    center = sound.split_to_mono()[2]
+    back_left = sound.split_to_mono()[4]
+    back_right = sound.split_to_mono()[5]
+    center = center - 6          # attenuate center and surrounds by 6 dB before mixing
+    back_left = back_left - 6
+    back_right = back_right - 6
+    stereo_left = front_left.overlay(center).overlay(back_left)
+    stereo_right = front_right.overlay(center).overlay(back_right)
+    stereo_sound = AudioSegment.from_mono_audiosegments(stereo_left, stereo_right)
+    stereo_sound.export(output_file_path, format="wav")
+
+
+def judge_command(file_path):
+    whisper_model = whisper.load_model("large", device="cpu")
+    out_path = './out.wav'
+    convert_6ch_wav_to_stereo(file_path, out_path)
+    result = whisper_model.transcribe(out_path, language="en")
+    text_result = result['text']
+    print(text_result)
+    return text_result
+
+
+def handle_audio_transcription(file_path):
+    try:
+        text_result = judge_command(file_path)
+        message = "Transcription successful!"
+    except Exception as e:
+        message = str(e)
+        text_result = ""
+    # fixed: combine status and text into one string, since the click
+    # handler below is wired to a single output textbox
+    return f"{message}\n{text_result}"
+
+with gr.Blocks() as audio_transcription_page:
+
+    gr.Markdown(
+        '''
+        This space transcribes the spoken words from an audio file to text.
+        ## How to use this Space?
+        - Upload a '.wav' file.
+        - The transcription of the audio will be shown after you click the transcribe button.
+        '''
+    )
+
+    with gr.Row():
+        with gr.Column():
+            audio_file = gr.File(
+                file_types=[".wav"],
+                label="Upload a '.wav' file",
+            )
+            info = gr.Textbox(
+                value="",
+                label="Log",
+                placeholder="Transcription results will appear here...",
+            )
+            transcribe_button = gr.Button("Transcribe")
+
+    transcribe_button.click(
+        handle_audio_transcription,
+        [audio_file],
+        [info]
+    )
+
+if __name__ == "__main__":
+    audio_transcription_page.launch(debug=True)
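A minimal local smoke test for the two helpers above (a sketch: it assumes app.py can be imported without launching the UI, that ffmpeg is installed for pydub, and that a 6-channel recording exists; the file names are placeholders):

    from app import convert_6ch_wav_to_stereo, judge_command

    convert_6ch_wav_to_stereo("input_6ch.wav", "downmixed.wav")  # raises ValueError unless the input has 6 channels
    print(judge_command("input_6ch.wav"))  # downmixes to ./out.wav, then transcribes with Whisper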
src/classifier.py
ADDED
@@ -0,0 +1,97 @@
+"""
+Annotated version (comments translated from Chinese).
+"""
+#%% Import the required packages
+import numpy as np
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Dense, Dropout
+from tensorflow.keras.losses import binary_crossentropy
+from tensorflow.keras.optimizers import Adam
+from sklearn.metrics import roc_curve
+from scipy.interpolate import interp1d
+from scipy.optimize import brentq
+import matplotlib.pyplot as plt
+from scipy.io.wavfile import read
+from sklearn.preprocessing import normalize
+from generate_array_feature import mald_feature, get_filelist
+import time
+
+
+#%% Define the classifier model
+# This block defines the model:
+# its batch_size, feature length, and so on.
+batch_size = 10
+feature_len = 110
+loss_function = binary_crossentropy
+no_epochs = 150
+optimizer = Adam()
+verbosity = 1
+model = Sequential()
+model.add(Dense(64, input_dim=feature_len, activation='relu'))
+model.add(Dropout(0.2))
+model.add(Dense(32, activation='relu'))
+model.add(Dropout(0.2))
+model.add(Dense(16, activation='relu'))
+model.add(Dense(1, activation='sigmoid'))
+model.compile(loss=loss_function, optimizer=optimizer, metrics=['accuracy'])
+# The basic classifier parameters are now set; next, load a pre-trained model from an hdf5 file.
+model.load_weights(r"/home/fazhong/Github/czx/model.hdf5")
+# The model is loaded from train2.hdf5,
+# which was trained on data2.npy,
+# so there is no overlap with the data1.npy data.
+
+
+#%% Load the audio
+
+
+data_npy = np.load('./data.npy', allow_pickle=True)
+labels_npy = np.load('./labels.npy', allow_pickle=True)
+
+data = data_npy.tolist()
+labels_org = labels_npy.tolist()
+labels = []
+for x in labels_org:
+    labels.append(x[0])
+
+
+voice = []
+# voice holds the waveforms extracted from a set of wav audio files
+X = []  # X is the feature ~ data[0]
+y = []  # y is the normal (1) or attack (0) ~ data[1]
+
+# for file_path in name_all:
+#     file_name = file_path.split("\\")[-1]
+#     # define the normal or attack in variable cur_y
+#     if 'normal' in file_name:
+#         cur_y = 1  # normal case
+#     elif 'attack' in file_name:
+#         cur_y = 0
+#     # split the file name
+#     # read the data
+#     rate, data = read(file_path)
+#     voice += [list(data)]
+
+#     X += [list(mald_feature(rate, data))]
+#     print(list(mald_feature(rate, data)))
+#     # the feature extraction from wav files is done by generate_array_feature.py
+#     # X is the feature vector; its dimension is 110
+#     y += [cur_y]
+#     # y is the label: 1 means a normal sample, 0 an attack sample
+
+
+X = data
+Y = labels
+# normalization
+norm_X = normalize(X, axis=0, norm='max')
+
+X = np.asarray(norm_X)
+y = np.asarray(Y)  # fixed: use the loaded labels, not the empty list above
+
+#%% Run the prediction
+scores = model.evaluate(X, y)  # overall evaluation
+y_pred = np.round(model.predict(X))  # per-sample predictions
+print(y_pred)
+acc = 0
+for i in range(len(y)):
+    if y_pred[i] == y[i]:  # fixed: compare element-wise, not against the whole array
+        acc += 1
+print(acc / len(y))
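For reference, the commented-out block above implies how the data.npy / labels.npy pair could be produced from a folder of wav files; a hedged sketch (the ./voice folder and the 'normal'/'attack' file-naming convention are assumptions drawn from that block):

    import numpy as np
    from scipy.io.wavfile import read
    from generate_array_feature import mald_feature, get_filelist

    X, y = [], []
    for path in get_filelist("./voice"):          # hypothetical data folder
        rate, wav = read(path)
        X.append(list(mald_feature(rate, wav)))   # 110-dimensional feature vector
        y.append([1 if "normal" in path else 0])  # 1 = normal, 0 = attack
    np.save("data.npy", np.asarray(X, dtype=object))
    np.save("labels.npy", np.asarray(y, dtype=object))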
src/direction_detection.py
ADDED
@@ -0,0 +1,248 @@
+"""
+This python script is used for direction detection.
+We design the direction detection for 3 wav file types, which have
+4, 6 and 8 channels.
+"""
+import os
+import numpy as np
+from scipy.io import wavfile
+from scipy.signal import butter, lfilter, freqz
+from scipy import signal
+import matplotlib
+import matplotlib.pyplot as plt
+
+offsetVector = []
+
+"""
+The functions butter_highpass, butter_highpass_filter and calculateResidues are shared.
+The offset logic is used by getAngle_for_eight.
+"""
+def butter_highpass(cutoff, fs, order=5):
+    nyq = 0.5 * fs
+    normal_cutoff = cutoff / nyq
+    b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
+    return b, a
+
+
+def butter_highpass_filter(data, cutoff, fs, order=5):
+    b, a = butter_highpass(cutoff, fs, order=order)
+    y = signal.filtfilt(b, a, data)
+    return y
+
+
+def calculateResidues(Chan1, Chan2, fs):
+    S1 = butter_highpass_filter(Chan1, 100, fs, 7)
+    S2 = butter_highpass_filter(Chan2, 100, fs, 7)
+
+    index1 = -1
+    index2 = -1
+    index = -1
+
+    for i in range(len(S1)):
+        if S1[i] > 0.03:
+            index1 = i
+            break
+
+    for i in range(len(S2)):
+        if S2[i] > 0.03:
+            index2 = i
+            break
+
+    if (index1 < index2):
+        index = index1
+    else:
+        index = index2
+
+    residues = np.mean(np.square(S1[index:index + 401] - S2[index:index + 401]))
+    # offsetVector.append( index1 )
+
+    return residues
+
+
+def do_iac(signal, pairs, fs):
+    # signal = data / 32767
+    residuesVector = []
+
+    for offset in [5, -5]:
+
+        # Compute the overall cancellation error for this angle
+        iterator = 0
+        residues = 0
+        for mic1, mic2 in pairs:
+
+            Chan1 = signal[:, mic1]
+            Chan2 = signal[:, mic2]
+
+            S1 = Chan1  # butter_highpass_filter(Chan1 , 100 , fs , 7)
+            S2 = Chan2  # butter_highpass_filter(Chan2 , 100 , fs , 7)
+
+            index = -1
+            for i in range(len(S1)):
+                if (S1[i] > 0.003 and i > 40):
+                    index = i
+                    break
+
+            if (iterator == 0 or iterator == 4):
+                a = S1[index - 15:index + 15]
+                b = S2[index - 15:index + 15]
+                residues += np.square(np.subtract(a, b))
+            elif (iterator == 1 or iterator == 3):
+                a = S1[index - 15 + offset // 2:index + 15 + offset // 2]
+                b = S2[index - 15:index + 15]
+                residues += np.square(np.subtract(a, b))
+            elif (iterator == 2):
+                a = S1[index - 15 + offset:index + 15 + offset]
+                b = S2[index - 15:index + 15]
+                residues += np.square(np.subtract(a, b))
+            elif (iterator == 5 or iterator == 7):
+                a = S1[index - 15 - offset // 2:index + 15 - offset // 2]
+                b = S2[index - 15:index + 15]
+                residues += np.square(np.subtract(a, b))
+            elif (iterator == 6):
+                a = S1[index - 15 - offset:index + 15 - offset]
+                b = S2[index - 15:index + 15]
+                residues += np.square(np.subtract(a, b))
+
+            iterator += 1
+
+        residuesVector.append(np.mean(residues))
+
+    return residuesVector[0] < residuesVector[1]
+
+
+def calculateResidues_eight(Chan1, Chan2, fs):
+    S1 = Chan1  # butter_highpass_filter(Chan1 , 100 , fs , 7 )
+    S2 = Chan2  # butter_highpass_filter(Chan2 , 100 , fs , 7 )
+
+    index1 = -1
+    index2 = -1
+    index = -1
+
+    for i in range(len(S1)):
+        if S1[i] > 0.01:
+            index1 = i
+            break
+
+    for i in range(len(S2)):
+        if S2[i] > 0.01:
+            index2 = i
+            break
+
+    if (index1 < index2):
+        index = index1
+    else:
+        index = index2
+
+    residues = np.mean(np.square(S1[index:index + 401] - S2[index:index + 401]))
+
+    return residues
+
+
+def getAngle_for_eight(data, fs):
+    signal = data / 32767
+    for i in range(8):
+        column = butter_highpass_filter(signal[:, i], 100, fs, 7)
+        signal[:, i] = column
+
+    pairs = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 0)]
+    smallestResidues = 100
+    closestPair = (0, 0)
+    offsetIndex = -1
+
+    for iter in range(8):
+
+        chan1 = signal[:, pairs[iter][0]]
+        chan2 = signal[:, pairs[iter][1]]
+
+        residues = calculateResidues_eight(chan1, chan2, fs)
+
+        if (residues < smallestResidues):
+            smallestResidues = residues
+            closestPair = (pairs[iter])
+            offsetIndex = iter
+
+    if do_iac(signal, pairs, fs) == True:
+        d1 = abs(offsetIndex - 4)
+        d2 = abs((offsetIndex + 4) % 8 - 4)
+        if (d1 < d2):
+            pass
+        else:
+            closestPair = pairs[(offsetIndex + 4) % 8]
+
+    mics = (closestPair[0] + 1, closestPair[1] + 1)
+
+    return mics
+
+
+def getAngle_for_six(data, fs):
+    signal = data / 32767
+    pairs = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0)]
+    smallestResidues = 100
+    closestPair = (0, 0)
+    offsetIndex = -1
+
+    for iter in range(6):
+
+        chan1 = signal[:, pairs[iter][0]]
+        chan2 = signal[:, pairs[iter][1]]
+
+        residues = calculateResidues(chan1, chan2, fs)
+
+        if (residues < smallestResidues):
+            smallestResidues = residues
+            closestPair = (pairs[iter])
+            offsetIndex = iter
+
+    """ if (offsetVector[offsetIndex] > offsetVector[(offsetIndex+3)%6] ):
+        closestPair = pairs[(offsetIndex+3)%6] """
+
+    mics = (closestPair[0] + 1, closestPair[1] + 1)
+    # print(offsetVector)
+
+    return mics
+
+
+def getAngle_for_four(data, fs):
+    signal = data / 32767
+    pairs = [(0, 1), (1, 2), (2, 3), (3, 0)]
+    smallestResidues = 100
+    closestPair = (0, 0)
+    offsetIndex = -1
+
+    for iter in range(4):
+
+        chan1 = signal[:, pairs[iter][0]]
+        chan2 = signal[:, pairs[iter][1]]
+
+        residues = calculateResidues(chan1, chan2, fs)
+
+        if (residues < smallestResidues):
+            smallestResidues = residues
+            closestPair = (pairs[iter])
+            offsetIndex = iter
+
+    """ if (offsetVector[offsetIndex] > offsetVector[(offsetIndex+3)%6] ):
+        closestPair = pairs[(offsetIndex+3)%6] """
+
+    mics = (closestPair[0] + 1, closestPair[1] + 1)
+    # print(offsetVector)
+
+    return mics
+
+
+def getDirection_Pair(closestPair, num_chan):
+    """
+    :param closestPair: the closest (1-based) pair, such as (1, 2)
+    :param num_chan: number of channels, such as 8
+    :return: with the parameters above, [4, 0, 1, 5]
+    """
+    pairs = [0, 0, 0, 0]
+    pairs[1] = closestPair[0] - 1
+    pairs[2] = closestPair[1] - 1
+    pairs[0] = (pairs[1] - int(num_chan/2)) % num_chan
+    pairs[3] = (pairs[2] + int(num_chan/2)) % num_chan
+
+    return pairs
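A short usage sketch for the detection pipeline above (assuming an 8-channel 16-bit wav; the path is a placeholder):

    import numpy as np
    from scipy.io import wavfile
    from direction_detection import getAngle_for_eight, getDirection_Pair

    rate, data = wavfile.read("array_8ch.wav")           # shape (samples, 8), int16
    mics = getAngle_for_eight(data.astype(float), rate)  # 1-based pair closest to the source
    print(getDirection_Pair(mics, 8))                    # e.g. mics=(1, 2) -> [4, 0, 1, 5]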
src/generate_array_feature.py
ADDED
@@ -0,0 +1,235 @@
+'''
+This is the main ArrayID feature building script
+
+revised: April 04, 2021
+
+'''
+
+import glob
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+from scipy.io.wavfile import read
+from scipy.fftpack import fft, ifft, fftfreq
+from scipy import signal
+import random
+from librosa.core import lpc
+import librosa.feature
+import csv
+from sklearn.preprocessing import normalize
+from direction_detection import *
+
+
+##############################################
+# HELPER FUNCTIONS
+
+# converts hz to indices -> allows splicing of freq data
+def hz_to_indices(freqs, lowcut, highcut):
+    i = 0
+    while freqs[i] < lowcut:
+        i += 1
+    low = i
+    while freqs[i] < highcut:
+        i += 1
+    return low, i
+
+
+# compresses our feature vectors
+# After extracting our features, they could be different lengths depending on
+# the input signal, so we normalize each feature vector to be the same no matter
+# the speaker
+def get_row_compressor(old_dimension, new_dimension):
+    dim_compressor = np.zeros((new_dimension, old_dimension))
+    bin_size = float(old_dimension) / new_dimension
+    next_bin_break = bin_size
+    which_row = 0
+    which_column = 0
+    while which_row < dim_compressor.shape[0] and which_column < dim_compressor.shape[1]:
+        if round(next_bin_break - which_column, 10) >= 1:
+            dim_compressor[which_row, which_column] = 1
+            which_column += 1
+        elif next_bin_break == which_column:
+            which_row += 1
+            next_bin_break += bin_size
+        else:
+            partial_credit = next_bin_break - which_column
+            dim_compressor[which_row, which_column] = partial_credit
+            which_row += 1
+            dim_compressor[which_row, which_column] = 1 - partial_credit
+            which_column += 1
+            next_bin_break += bin_size
+    dim_compressor /= bin_size
+    return dim_compressor
+
+# helper functions for the function above
+def get_column_compressor(old_dimension, new_dimension):
+    return get_row_compressor(old_dimension, new_dimension).transpose()
+
+def compress_and_average(array, new_shape):
+    return np.mat(get_row_compressor(array.shape[0], new_shape[0])) * \
+           np.mat(array) * \
+           np.mat(get_column_compressor(array.shape[1], new_shape[1]))
+##############################################
+
+
+##############################################
+# MAIN FEATURE EXTRACTION FUNCTIONS
+
+
+def get_filelist(dir):
+    Filelist = []
+    for home, dirs, files in os.walk(dir):
+        for filename in files:
+            Filelist.append(os.path.join(home, filename))
+    return Filelist
+
+
+def lpcc(data, n=15):
+    """
+    f_LPC = lpcc(data, n): get the LPCC from the voice data
+    The order n is 15
+    """
+    size_lpc = n  # define the order of LPCC
+    a = lpc(data, order=size_lpc)  # use the built-in function
+    a = -a
+    f_LPC = np.zeros(len(a))
+    f_LPC[0] = np.log(size_lpc)
+    for i in range(1, len(a)):
+        k = np.arange(1, i)  # k from 1 to i-1
+        f_LPC[i] = a[i] + np.sum((1 - k/i) * a[k] * f_LPC[i - k])
+    return f_LPC[1:]
+
+
+# returns long term fft
+def get_ltfd(spec, m=20, start_index=1, end_index=86):
+    # only get the useful part
+    spec = spec[:, start_index: end_index, :(spec.shape[2] - spec.shape[2] % m)]
+
+    # merge the spec along the time axis
+    channels = np.sum(spec, axis=2)
+
+    all_ffts = np.sum(channels, axis=0)
+    all_ffts /= np.max(all_ffts)
+
+    channels_ffts = np.asarray([channels[i, :] / np.max(channels[i, :]) for i in range(channels.shape[0])])
+
+    return all_ffts, channels_ffts
+
+
+# returns long term fft
+def get_ltfp(spec, m=20, start_index_fp=1, end_index_fp=86):
+    # only get the useful part
+    spec = spec[:, start_index_fp:end_index_fp, :(spec.shape[2] - spec.shape[2] % m)]
+
+    # split the data
+    splices = np.asarray(np.split(spec, m, axis=2))
+
+    # merge the data onto a grid
+    mesh = np.zeros((splices.shape[0], splices.shape[1], splices.shape[2]))
+    for i in range(mesh.shape[0]):
+        for j in range(mesh.shape[1]):
+            for k in range(mesh.shape[2]):
+                mesh[i, j, k] = np.sum(splices[i, j, k, :])
+
+    # calculate the standard deviation
+    std_feature = np.zeros((mesh.shape[0], mesh.shape[2]))
+    for i in range(std_feature.shape[0]):
+        for j in range(std_feature.shape[1]):
+            std_feature[i, j] = np.std(mesh[i, :, j]) / np.mean(mesh[i, :, j])
+
+    # define the ltfp
+    LTFP = np.mean(std_feature, axis=0)
+    LTFP = LTFP / np.max(LTFP)
+    return LTFP
+
+
+def feature_distribution(channel_fft):
+    num_feature = 5
+    f_dis = np.zeros(2 * num_feature)
+    co = np.zeros((num_feature, len(channel_fft)))
+    for num in range(len(channel_fft)):
+        a = channel_fft[num]
+        for i in range(1, len(a)):
+            a[i] = a[i-1] + a[i]
+        a = a / np.max(a)
+        dis_index = [0.1, 0.3, 0.5, 0.7, 0.9]
+        for i in range(len(dis_index)):
+            co[i, num] = find_value(a, dis_index[i])
+        co[:, num] /= len(a)
+    for i in range(num_feature):
+        f_dis[i] = np.mean(co[i, :])
+        f_dis[i + num_feature] = np.std(co[i, :])
+    return co, f_dis
+
+
+def find_value(a, dis_index):
+    c = 0
+    for i in range(len(a) - 1):
+        if a[i] <= dis_index <= a[i + 1]:
+            c = i
+            break
+    return c
+
+
+def mald_feature(rate, data):
+    n_fft = 4096
+    # detect the direction
+    if data.shape[1] == 4:
+        closestPair = getAngle_for_four(data, fs=rate)
+    elif data.shape[1] == 6:
+        closestPair = getAngle_for_six(data, fs=rate)
+    elif data.shape[1] == 8:
+        closestPair = getAngle_for_eight(data, fs=rate)
+    pairs = getDirection_Pair(closestPair, data.shape[1])
+
+    # low and high thresholds for field print features -> we want the 1 Hz - 5 kHz range
+    lowcut_fp = 1
+    highcut_fp = 5000
+    if highcut_fp > rate / 2:  # in case the sampling rate is very small
+        highcut_fp = rate / 2 - 100
+    highcut_fd = 1000
+
+    # input rate -> make sure to change this based on device.
+    # All of the devices are 44100 except for the AMLOGIC, which is 16kHz.
+    # If this rate is not changed accordingly, the _ltfp and _ltfft features
+    # will be off
+
+    # just some helper splicing globals
+    freq = fftfreq(n_fft, 1. / rate)  # data = logmmse(data, rate)
+    start_index, end_index = hz_to_indices(freq, lowcut_fp, highcut_fd)
+    start_index_fp, end_index_fp = hz_to_indices(freq, lowcut_fp, highcut_fp)
+
+
+    # empty feature vectors
+    _lpcc = []
+    # extract lfp and lpcc from each channel independently, then sum
+    for i in pairs:
+        a = np.asfortranarray(data[:, i]).astype(dtype=float)
+        _lpcc += list(lpcc(a))
+
+    # calculate the spectrogram
+    spec = [signal.stft(data[:, i], fs=rate, window='hann', nperseg=1024, noverlap=768, nfft=n_fft)[2] for i in range(data.shape[1])]
+    spec = np.asarray(spec)  # convert list to numpy
+    # obtain the absolute value
+    spec = np.abs(spec)
+
+    # get the ltfd feature and compress it to a 20-bin vector
+    _ltfd, channel_fft = get_ltfd(spec=spec, start_index=start_index, end_index=end_index)
+
+    _ltfd = list(compress_and_average(_ltfd.reshape(len(_ltfd), 1), (20, 1)).flat)
+
+    co, _fdis = feature_distribution(channel_fft)
+
+    # get the ltfp feature and compress it to a 20-bin vector
+    _ltfp = get_ltfp(spec=spec, start_index_fp=start_index_fp, end_index_fp=end_index_fp)
+    _ltfp = list(compress_and_average(_ltfp.reshape(len(_ltfp), 1), (20, 1)).flat)
+
+    # feature is the final feature vector; in training, each data point is formed
+    # as a tuple (X, y), where X is the feature vector and y is the label
+    feature = np.concatenate((_lpcc, _ltfd, _fdis, _ltfp))
+    return feature
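Under the defaults above, the 110-dimensional feature decomposes as 4 channels x 15 LPCC coefficients = 60, plus 20 LTFD bins, plus 10 distribution statistics from feature_distribution, plus 20 LTFP bins. A quick shape check (the input path is a placeholder; the wav must have 4, 6, or 8 channels):

    from scipy.io.wavfile import read
    from generate_array_feature import mald_feature

    rate, data = read("sample_6ch.wav")
    print(mald_feature(rate, data).shape)  # expected: (110,)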
src/main_functions.py
ADDED
@@ -0,0 +1,153 @@
+"""
+Annotated version (comments translated from Chinese).
+"""
+#%% Import the required packages
+import numpy as np
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Dense, Dropout
+from tensorflow.keras.losses import binary_crossentropy
+from tensorflow.keras.optimizers import Adam
+from sklearn.metrics import roc_curve
+from scipy.interpolate import interp1d
+from scipy.optimize import brentq
+import matplotlib.pyplot as plt
+from scipy.io.wavfile import read
+from sklearn.preprocessing import normalize
+from generate_array_feature import mald_feature, get_filelist
+import time
+
+
+#%% Define the classifier model
+# This block defines the model:
+# its batch_size, feature length, and so on.
+batch_size = 10
+feature_len = 110
+loss_function = binary_crossentropy
+no_epochs = 150
+optimizer = Adam()
+verbosity = 1
+model = Sequential()
+model.add(Dense(64, input_dim=feature_len, activation='relu'))
+model.add(Dropout(0.2))
+model.add(Dense(32, activation='relu'))
+model.add(Dropout(0.2))
+model.add(Dense(16, activation='relu'))
+model.add(Dense(1, activation='sigmoid'))
+model.compile(loss=loss_function, optimizer=optimizer, metrics=['accuracy'])
+# The basic classifier parameters are now set; next, load a pre-trained model from an hdf5 file.
+model.load_weights(r"/home/fazhong/Github/czx/model.hdf5")
+# The model is loaded from train2.hdf5,
+# which was trained on data2.npy,
+# so there is no overlap with the data1.npy data.
+
+
+#%% Load the audio
+path_wave = r"/home/fazhong/Github/czx/voice"
+print("Loading data ...")
+name_all = get_filelist(path_wave)
+voice = []
+# voice holds the waveforms extracted from a set of wav audio files
+X = []  # X is the feature ~ data[0]
+y = []  # y is the normal (1) or attack (0) ~ data[1]
+
+for file_path in name_all:
+    file_name = file_path.split("\\")[-1]
+    # define the normal or attack in variable cur_y
+    if 'normal' in file_name:
+        cur_y = 1  # normal case
+    elif 'attack' in file_name:
+        cur_y = 0
+    # split the file name
+    # read the data
+    rate, data = read(file_path)
+    voice += [list(data)]
+
+    X += [list(mald_feature(rate, data))]
+    print(list(mald_feature(rate, data)))
+    # the feature extraction from wav files is done by generate_array_feature.py
+    # X is the feature vector; its dimension is 110
+    y += [cur_y]
+    # y is the label: 1 means a normal sample, 0 an attack sample
+
+# normalization
+norm_X = normalize(X, axis=0, norm='max')
+# X_y = [(norm_X[i], y[i]) for i in range(len(norm_X))]
+# # print(len(X_y))
+# # for i in X_y: print(i[1])
+# X_y = np.asarray(X_y)
+
+X = np.asarray(norm_X)
+y = np.asarray(y)
+
+# X = np.asarray([x[0] for x in X_y])
+# y = np.asarray([x[1] for x in X_y])
+
+#%% Plot the features
+index1 = [5]  # pick sample 5
+x1 = X[index1]
+y1 = y[index1]  # 1, meaning normal
+plt.plot(x1.T, label='normal')
+index2 = [1]  # pick sample 1
+x2 = X[index2]
+y2 = y[index2]  # 0, meaning attack
+plt.plot(x2.T, label='attack')
+plt.legend()
+plt.show()
+# The difference between normal and attack is clearly visible; this is the basis of our classification.
+
+#%% Run the prediction
+scores = model.evaluate(X, y)  # overall evaluation
+y_pred = np.round(model.predict(X))  # per-sample predictions
+index1 = 8  # sample 8 is a normal sample
+index3 = [1, 3, 5, 7, 9]  # pick some samples; once the wav files arrive, the input will be wav directly
+for i in index3:
+    print('Starting detection:')
+    plt.plot(voice[i], label='Voice Signal')
+    plt.show()
+    time.sleep(2)
+    if y[i] == 1:  # normal case
+        print('the ' + str(i) + ' sample is normal')
+        title = 'the ' + str(i) + ' sample is normal'
+        plt.subplot(1, 2, 1)
+        plt.plot(X[index1])
+        plt.subplot(1, 2, 2)
+        plt.plot(X[i], label='New')
+        plt.title(title)
+        plt.show()
+        time.sleep(1)
+        if y_pred[i] == y[i]:
+            print("Successfully Detect")  # correct prediction
+            print("Run the car")
+            title = "Successfully Detect, " + "Run the car"
+            plt.title(title)
+            plt.show()
+        else:
+            print("Detection is false.")  # wrong prediction
+            print("Don't run the car")
+            title = "Detection is false, " + "Don't run the car"
+            plt.title(title)
+            plt.show()
+    else:  # attack case: the decision is reversed
+        print('the ' + str(i) + ' sample is attack')
+        title = 'the ' + str(i) + ' sample is attack'
+        plt.subplot(1, 2, 1)
+        plt.plot(X[index1], label='Normal')
+        plt.subplot(1, 2, 2)
+        plt.plot(X[i], label='New')
+        plt.title(title)
+        plt.show()
+        time.sleep(1)
+        if y_pred[i] == y[i]:
+            print("Successfully Detect")  # correct prediction
+            print("Don't run the car")
+            title = "Successfully Detect, " + "Don't run the car"
+            plt.title(title)
+            plt.show()
+        else:
+            print("Detection is false.")  # wrong prediction
+            print("Run the car")
+            title = "Detection is false, " + "Run the car"
+            plt.title(title)
+            plt.show()
+
+    print("-------------------------")
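The script above only loads model.hdf5; a sketch of how such a checkpoint could be produced with the same architecture, reusing the X, y, and hyperparameters already defined in the script (the training step itself is an assumption, since no training script is part of this commit):

    model.fit(X, y, batch_size=batch_size, epochs=no_epochs, verbose=verbosity)
    model.save_weights("model.hdf5")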
src/main_gui.py
ADDED
@@ -0,0 +1,218 @@
+#code
+#coding=UTF-8
+# ! -*- coding: utf-8 -*-
+from __future__ import print_function
+
+#%% Import the required packages
+import numpy as np
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Dense, Dropout
+from tensorflow.keras.losses import binary_crossentropy
+from tensorflow.keras.optimizers import Adam
+from sklearn.metrics import roc_curve
+from scipy.interpolate import interp1d
+from scipy.optimize import brentq
+import matplotlib.pyplot as plt
+from scipy.io.wavfile import read
+from sklearn.preprocessing import normalize
+from generate_array_feature import mald_feature, get_filelist
+
+import os
+import threading
+import tkinter as tk
+from tkinter import filedialog
+from PIL import Image, ImageTk
+from tkinter.messagebox import *
+from tkinter import scrolledtext
+top1 = None
+top2 = None
+top4 = None
+top5 = None
+top6 = None
+top7 = None
+img_open = None
+img = None
+v1 = None
+v2 = None
+ll = 0
+s2 = ''
+s1 = ''
+top3 = None
+t2 = None
+s = ''
+f1 = "fg.txt"
+f2 = "fg.txt"
+v = None
+top = None
+v = {}
+d1 = {}
+d2 = {}
+message = ""
+ermsg = ""
+picn = 0
+arg = []
+class MyThread(threading.Thread):
+    def __init__(self, func, *args):  # run in a thread so the GUI does not freeze
+        super().__init__()
+
+        self.func = func
+        self.args = args
+
+        self.setDaemon(True)
+        self.start()
+
+    def run(self):
+        self.func(*self.args)
+
+def chf(tt1):  # choose an audio file
+    global f1
+    f1 = filedialog.askopenfilename()
+    showinfo("Open File", "Open a new File.")
+    tt1.delete(0.0, tk.END)
+    tt1.insert(0.0, f1)
+
+
+def info():
+    pp = 'Voice interface security'
+    showinfo('Information', pp)
+
+def build_model():
+    # %% Define the classifier model
+    # This block defines the model:
+    # its batch_size, feature length, and so on.
+    batch_size = 10
+    feature_len = 110
+    loss_function = binary_crossentropy
+    no_epochs = 150
+    optimizer = Adam()
+    verbosity = 1
+    model = Sequential()
+    model.add(Dense(64, input_dim=feature_len, activation='relu'))
+    model.add(Dropout(0.2))
+    model.add(Dense(32, activation='relu'))
+    model.add(Dropout(0.2))
+    model.add(Dense(16, activation='relu'))
+    model.add(Dense(1, activation='sigmoid'))
+    model.compile(loss=loss_function, optimizer=optimizer, metrics=['accuracy'])
+    # The basic classifier parameters are now set; next, load a pre-trained model from an hdf5 file.
+    model.load_weights("model.hdf5")
+    # The model is loaded from train2.hdf5,
+    # which was trained on data2.npy,
+    # so there is no overlap with the data1.npy data.
+    return model
+
+def show_data(f1):
+    file_path = f1
+    print(f1)
+    rate, data = read(file_path)
+    plt.plot(data, label='Voice Signal')
+    plt.show()
+
+
+def show_feature(f1):
+    file_path = f1
+    file_name = file_path.split("\\")[-1]
+    # define the normal or attack in variable cur_y
+    if 'normal' in file_name:
+        cur_y = 1  # normal case
+    elif 'attack' in file_name:
+        cur_y = 0
+    # split the file name
+    # read the data
+    rate, data = read(file_path)
+    X = mald_feature(rate, data)
+    # the feature extraction from the wav file is done by generate_array_feature.py
+    # X is the feature vector; its dimension is 110
+    y = cur_y
+    # y is the label: 1 means a normal sample, 0 an attack sample
+    if y == 1:  # normal case
+        title = 'the sample is normal'
+    else:
+        title = 'the sample is attack'
+    plt.plot(X)
+    plt.title(title)
+    plt.show()
+
+
+def detect(f1, model):
+    file_path = f1
+    file_name = file_path.split("\\")[-1]
+    # define the normal or attack in variable cur_y
+    if 'normal' in file_name:
+        cur_y = 1  # normal case
+    elif 'attack' in file_name:
+        cur_y = 0
+    # split the file name
+    # read the data
+    rate, data = read(file_path)
+    X = []
+    X += [list(mald_feature(rate, data))]
+    X += [list(mald_feature(rate, data))]
+    # append twice, because the model expects a 2-D input
+    X = np.asarray(X)
+
+    # the feature extraction from the wav file is done by generate_array_feature.py
+    # X is the feature vector; its dimension is 110
+    y = cur_y
+    # y is the label: 1 means a normal sample, 0 an attack sample
+    y_pred = np.round(model.predict(X))
+    # run the prediction
+    y_pred = y_pred[0]
+
+    if y == 1:  # normal case
+        if y_pred == y:
+            print("Correct prediction")
+            print("The car runs")
+            title = "Command normal, prediction correct, the car runs"
+            print('--------------')
+            print(title)
+        else:
+            print("Wrong prediction")
+            print("The car stays still")
+            title = "Command normal, prediction wrong, the car stays still"
+            print('--------------')
+            print(title)
+    else:  # attack case: the decision is reversed
+        if y_pred == y:
+            print("Correct prediction")
+            print("The car stays still")
+            title = "Command abnormal, prediction correct, the car stays still"
+            print('--------------')
+            print(title)
+        else:
+            print("Wrong prediction")
+            print("The car runs")
+            title = "Command abnormal, prediction wrong, the car runs"
+            print('--------------')
+            print(title)
+
+
+ans = ""
+
+
+root = tk.Tk(className='Voice interface authentication system')
+#root.iconbitmap('bf.ico')
+root.attributes("-alpha", 0.9)
+tk.Label(root, height=10, width=5).grid(row=0, column=0)
+fra = tk.Frame(root, width=55, height=100)
+fra.grid(row=0, column=1)
+tk.Label(root, height=10, width=5).grid(row=0, column=2)
+tk.Label(fra, text='', height=1, width=10).grid(row=0, column=0)
+
+tt1 = tk.Text(fra, height=2, width=30)
+tt1.grid(row=1, column=0)
+tk.Button(fra, text='First select the audio data', command=lambda: chf(tt1)).grid(row=1, column=1)
+model = build_model()
+
+
+train = tk.Button(fra, text='Show the audio content', font=('KaiTi,bold'), borderwidth=3, command=lambda: MyThread(show_data, f1))  # done
+train.grid(row=3, column=0)
+
+train = tk.Button(fra, text='Show the audio features', font=('KaiTi,bold'), borderwidth=3, command=lambda: MyThread(show_feature, f1))  # done
+train.grid(row=5, column=0)
+
+train = tk.Button(fra, text='Show the detection result', font=('KaiTi,bold'), borderwidth=3, command=lambda: MyThread(detect, f1, model))  # done
+train.grid(row=7, column=0)
+
+
+tk.mainloop()
translate/cmd_judge.py
ADDED
@@ -0,0 +1,227 @@
+import numpy as np
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Dense, Dropout
+from tensorflow.keras.losses import binary_crossentropy
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras.models import load_model
+from sklearn.metrics import roc_curve
+from scipy.interpolate import interp1d
+from scipy.optimize import brentq
+import matplotlib.pyplot as plt
+from scipy.io.wavfile import read
+from sklearn.preprocessing import normalize
+from generate_array_feature import mald_feature, get_filelist
+import time
+from pydub import AudioSegment
+import whisper
+import os
+import spacy
+
+# To deal with one wav file.
+
+def is_command_reasonable(command, time, location):
+
+    commands = [
+        "OK Google.",
+        "Turn on Bluetooth.",
+        "Record a video.",
+        "Take a photo.",
+        "Open music player.",
+        "Set an alarm for 6:30 am.",
+        "Remind me to buy coffee at 7 am.",
+        "What is my schedule for tomorrow?",
+        "Square root of 2105?",
+        "Open browser.",
+        "Decrease volume.",
+        "Turn on flashlight.",
+        "Set the volume to full.",
+        "Mute the volume.",
+        "What's the definition of transmit?",
+        "Call Pizza Hut.",
+        "Call the nearest computer shop.",
+        "Show me my messages.",
+        "Translate please give me directions to Chinese.",
+        "How do you say good night in Japanese?"
+    ]
+
+
+    # Time : Work-0 / Rest-1 / Sleep-2
+    # Location : Work-0 / Home-1
+
+    commands_daily = [
+        "Call Pizza Hut.",
+        "Remind me to buy coffee at 7 am.",
+        "Open music player.",
+        "Record a video.",
+        "Take a photo.",
+    ]
+    commands_work = [
+        "Open browser.",
+        "What is my schedule for tomorrow?",
+        "Square root of 2105?",
+        "Call the nearest computer shop.",
+        "Show me my messages.",
+        "Translate please give me directions to Chinese.",
+        "How do you say good night in Japanese?",
+        "What's the definition of transmit?",
+    ]
+    commands_basic = [
+        "OK Google.",
+        "Decrease volume.",
+        "Turn on Bluetooth.",
+        "Turn on flashlight.",
+        "Set the volume to full.",
+        "Mute the volume.",
+        "Set an alarm for 6:30 am."]
+
+
+    if time == 0 and location == 0:
+        if command in commands_daily:
+            return False
+        else:
+            return True
+    elif time == 2:
+        if command in commands_basic:
+            return True
+        else:
+            return False
+    else:
+        if command in commands_work:
+            return False
+        else:
+            return True
+
+def convert_6ch_wav_to_stereo(input_file_path, output_file_path):
+    sound = AudioSegment.from_file(input_file_path, format="wav")
+    if sound.channels != 6:
+        raise ValueError("The input file does not have 6 channels.")
+    front_left = sound.split_to_mono()[0]
+    front_right = sound.split_to_mono()[1]
+    center = sound.split_to_mono()[2]
+    back_left = sound.split_to_mono()[4]
+    back_right = sound.split_to_mono()[5]
+    center = center - 6
+    back_left = back_left - 6
+    back_right = back_right - 6
+    stereo_left = front_left.overlay(center).overlay(back_left)
+    stereo_right = front_right.overlay(center).overlay(back_right)
+    stereo_sound = AudioSegment.from_mono_audiosegments(stereo_left, stereo_right)
+    stereo_sound.export(output_file_path, format="wav")
+
+def judge_human(rate, data):
+    model = load_model('/home/fazhong/Github/czx/data-task0_1/train1.keras')
+    feature = list(mald_feature(rate, data))
+    features = np.array([feature])
+    y_pred = model.predict(features)
+    return y_pred[0]
+
+def judge_name(rate, data):
+    model = load_model('/home/fazhong/Github/czx/data-task0/train1.keras')
+    feature = list(mald_feature(rate, data))
+    features = np.array([feature])
+    y_pred = model.predict(features)
+    y_pred_classes = np.argmax(y_pred, axis=1)
+    return y_pred_classes[0]
+
+def judge_command(file_path):
+    whisper_model = whisper.load_model("large")
+    out_path = '/home/fazhong/Github/czx/temp/temp.wav'
+    convert_6ch_wav_to_stereo(file_path, out_path)
+    # print(out_path)
+    result = whisper_model.transcribe(out_path, language="en")
+    text_result = result['text']
+    print(text_result)
+    return text_result
+
+def judge_classifier(command):
+    nlp = spacy.load('en_core_web_md')
+    commands = [
+        "OK Google.",
+        "Turn on Bluetooth.",
+        "Record a video.",
+        "Take a photo.",
+        "Open music player.",
+        "Set an alarm for 6:30 am.",
+        "Remind me to buy coffee at 7 am.",
+        "What is my schedule for tomorrow?",
+        "Square root of 2105?",
+        "Open browser.",
+        "Decrease volume.",
+        "Turn on flashlight.",
+        "Set the volume to full.",
+        "Mute the volume.",
+        "What's the definition of transmit?",
+        "Call Pizza Hut.",
+        "Call the nearest computer shop.",
+        "Show me my messages.",
+        "Translate please give me directions to Chinese.",
+        "How do you say good night in Japanese?"
+    ]
+    def classify_key(command):
+        if 'ok google' in command:
+            return 1
+        elif 'bluetooth' in command and 'on' in command:
+            return 2
+        elif 'record' in command and 'video' in command:
+            return 3
+        elif 'take' in command and 'photo' in command:
+            return 4
+        elif 'music player' in command and 'open' in command:
+            return 5
+        elif 'set' in command and 'alarm' in command:
+            return 6
+        elif 'remind' in command and 'coffee' in command:
+            return 7
+        elif 'schedule' in command or 'tomorrow' in command:
+            return 8
+        elif 'square root' in command:
+            return 9
+        elif 'open browser' in command:
+            return 10
+        elif 'decrease volume' in command:
+            return 11
+        elif 'flashlight' in command and 'on' in command:
+            return 12
+        elif 'volume' in command and 'full' in command:
+            return 13
+        elif 'mute' in command and 'volume' in command:
+            return 14
+        elif 'definition of' in command:
+            return 15
+        elif 'call' in command and 'pizza hut' in command.lower():
+            return 16
+        elif 'call' in command and 'computer shop' in command.lower():
+            return 17
+        elif 'messages' in command and 'show' in command:
+            return 18
+        elif 'translate' in command:
+            return 19
+        elif 'good night' in command and 'in japanese' in command:
+            return 20
+        else:
+            return None  # or some default value if command is not recognized
+
+    file_content = command
+    result_pre = classify_key(file_content.replace('.', '').replace(',', '').lower().strip())
+    if result_pre is not None:
+        # fixed: map the 1-based keyword index back to the command string,
+        # so the return type matches the similarity fallback below
+        return commands[result_pre - 1]
+    input_doc = nlp(file_content.replace('.', '').replace(',', '').lower().strip())
+    similarities = [(command, input_doc.similarity(nlp(command))) for command in commands]
+    best_match = max(similarities, key=lambda item: item[1])
+    return best_match[0]
+
+def judge(file_path, time, location):
+
+    rate, data = read(file_path)
+    # Maybe change to paths?
+    temp = judge_human(rate, data)
+    temp2 = judge_name(rate, data)
+    command = judge_command(file_path)
+    text = judge_classifier(command)
+    if is_command_reasonable(text, time, location):
+        return True
+    else:
+        return False
+
+if __name__ == "__main__":
+    judge('/home/fazhong/Github/czx2/example/data/fengattack60/feng_attack_echo_60_01_3.150-4.000.wav', 0, 0)
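Context codes used by judge, from the comments above: time 0 = work, 1 = rest, 2 = sleep; location 0 = work, 1 = home. A hedged invocation sketch (the recording path is a placeholder; the wav must have 6 channels for the Whisper step):

    from cmd_judge import judge

    ok = judge("recording_6ch.wav", 1, 1)  # True when the transcribed command fits the context
    print("accept" if ok else "reject")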
translate/test.py
ADDED
@@ -0,0 +1,414 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from tensorflow.keras.models import Sequential
|
3 |
+
from tensorflow.keras.layers import Dense, Dropout
|
4 |
+
from tensorflow.keras.losses import binary_crossentropy
|
5 |
+
from tensorflow.keras.optimizers import Adam
|
6 |
+
from tensorflow.keras.models import load_model
|
7 |
+
from tensorflow.keras.callbacks import ModelCheckpoint
|
8 |
+
from tensorflow.keras.utils import to_categorical
|
9 |
+
import tensorflow as tf
|
10 |
+
from sklearn.metrics import roc_curve
|
11 |
+
from scipy.interpolate import interp1d
|
12 |
+
from scipy.optimize import brentq
|
13 |
+
import os
|
14 |
+
import random
|
15 |
+
import spacy
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
# == Part 1 - Read data ==
|
18 |
+
|
data = np.load("/home/fazhong/Github/czx/data.npy", allow_pickle=True)
labels = np.load("/home/fazhong/Github/czx/labels.npy", allow_pickle=True)
texts = np.load("/home/fazhong/Github/czx/texts.npy", allow_pickle=True)

commands = [
    "OK Google.",
    "Turn on Bluetooth.",
    "Record a video.",
    "Take a photo.",
    "Open music player.",
    "Set an alarm for 6:30 am.",
    "Remind me to buy coffee at 7 am.",
    "What is my schedule for tomorrow?",
    "Square root of 2105?",
    "Open browser.",
    "Decrease volume.",
    "Turn on flashlight.",
    "Set the volume to full.",
    "Mute the volume.",
    "What's the definition of transmit?",
    "Call Pizza Hut.",
    "Call the nearest computer shop.",
    "Show me my messages.",
    "Translate please give me directions to Chinese.",
    "How do you say good night in Japanese?"
]
commands_basic = [
    0,   # "OK Google."
    1,   # "Turn on Bluetooth."
    5,   # "Set an alarm for 6:30 am."
    10,  # "Decrease volume."
    11,  # "Turn on flashlight."
    12,  # "Set the volume to full."
    13,  # "Mute the volume."
]
commands_daily = [
    2,   # "Record a video."
    3,   # "Take a photo."
    4,   # "Open music player."
    6,   # "Remind me to buy coffee at 7 am."
    15,  # "Call Pizza Hut."
]
commands_work = [
    7,   # "What is my schedule for tomorrow?"
    8,   # "Square root of 2105?"
    9,   # "Open browser."
    14,  # "What's the definition of transmit?"
    16,  # "Call the nearest computer shop."
    17,  # "Show me my messages."
    18,  # "Translate please give me directions to Chinese."
    19,  # "How do you say good night in Japanese?"
]

def rule_judge(type, time, location):
    # Decide whether a command category is allowed in the current context.
    if type in commands_basic:
        # Basic commands are blocked only during sleep time.
        return time != 0
    elif type in commands_daily:
        # Daily commands are allowed only during daily time.
        return time == 2
    elif type in commands_work:
        # Work commands require work time at the factory.
        return time == 1 and location == 1
    return False  # unknown category: reject by default

# 0 - sleep time / 1 - work time / 2 - daily time
times_label = [0, 1, 2]
# 0 - home / 1 - factory
location_label = [0, 1]

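# A quick sanity check of the rule table (hypothetical inputs, for illustration):
# rule_judge(0, time=1, location=0)  -> True   (basic command outside sleep time)
# rule_judge(2, time=0, location=0)  -> False  (daily command during sleep time)
# rule_judge(7, time=1, location=1)  -> True   (work command at the factory)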
data_all = []
data = data.tolist()
labels = labels.tolist()
texts = texts.tolist()

acc_num = 0
all_num = len(data)
atk_list = []   # indices predicted as attack by the human/attack model
atk_err = []    # indices the human/attack model gets wrong
name_err = []   # indices the speaker-identification model gets wrong
type_err = []   # indices the command classifier gets wrong

gt_label = []
pre_label = []

name_err_num = [0, 0, 0, 0]
name_acc_num = [0, 0, 0, 0]
command_err_num = [0] * 20
command_acc_num = [0] * 20

for i in range(len(data)):
    tmp = []
    tmp.append(np.array(data[i][0]))
    tmp.extend([labels[i][0]])
    tmp.extend([labels[i][1]])
    tmp.extend([labels[i][2]])
    data_all.append(tmp)
data = data_all

# Context labels are simulated at random for this test run.
time_labels = []
location_labels = []
for i in range(len(data)):
    time_labels.append(random.randint(0, 2))
    location_labels.append(random.randint(0, 1))

rule_err = []

for i in range(len(data)):
    if not rule_judge(data[i][2], time_labels[i], location_labels[i]):
        rule_err.append(i)

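# After the repack above, each sample reads:
#   data[i] = [feature_vector, attack_label, name_index, category_number]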
# == Part 2 - Judge of Human ==
model = load_model('/home/fazhong/Github/czx/data-task0_1/train1.keras')
X = np.asarray([x[0] for x in data])
y = np.asarray([x[1] for x in data])
type = np.asarray([x[3] for x in data])

y_pred = model.predict(X)
y_pred = y_pred.reshape((len(y_pred), 1))
y = y.reshape((len(y), 1))
for i in range(len(y)):
    if y_pred[i] > 0.5:
        y_pred[i] = 1
    else:
        y_pred[i] = 0
        atk_list.append(i)  # predicted as attack
    if y_pred[i] != y[i]:
        atk_err.append(i)
ACCU = np.sum(y_pred == y) / len(y)
print(len(y))
print("ACCU is " + str(100 * ACCU))

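# An equivalent vectorized form of the accuracy computation above (the error
# bookkeeping still needs the loop); a sketch, not used below:
# y_hat = (model.predict(X) > 0.5).astype(int).reshape(-1, 1)
# print("ACCU is " + str(100 * np.mean(y_hat == y)))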
# == Part 3 - Judge of Name ==

model = load_model('/home/fazhong/Github/czx/data-task0/train1.keras')
y_name = np.asarray([x[2] for x in data])
y_pred = model.predict(X)
y_pred_classes = np.argmax(y_pred, axis=1)
ACCU = np.sum(y_pred_classes == y_name) / len(y_name)
for i in range(len(y_name)):
    if y_pred_classes[i] != y_name[i]:
        name_err.append(i)
print("ACCU is " + str(100 * ACCU))


# Part 4 - Transcribe and Judge of Reason

# NB: attack transcripts do not need to go through the classifier.
nlp = spacy.load('en_core_web_md')


def classify_key(command):
    # Keyword shortcut classifier over the transcript; returns a 1-based
    # command category, or None to fall back to spaCy similarity matching.
    # The looser keywords (e.g. 'move', 'milk', 'peace', 'hard') absorb
    # common mis-transcriptions of "Mute the volume." and "Call Pizza Hut.".
    if 'ok google' in command:
        return 1
    elif 'okay' in command:
        return 1
    elif 'bluetooth' in command:
        return 2
    elif 'record' in command and 'video' in command:
        return 3
    elif 'take' in command and 'photo' in command:
        return 4
    elif 'music' in command:
        return 5
    elif 'alarm' in command:
        return 6
    elif 'remind' in command and 'coffee' in command:
        return 7
    elif 'am' in command:
        return 7
    elif 'schedule' in command or 'tomorrow' in command:
        return 8
    elif 'square root' in command:
        return 9
    elif 'open browser' in command:
        return 10
    elif 'decrease volume' in command:
        return 11
    elif 'flashlight' in command and 'on' in command:
        return 12
    elif 'hello freshlight' in command.lower():
        return 12
    elif 'turn on' in command:
        return 12
    elif 'volume' in command and 'full' in command:
        return 13
    elif 'mute' in command:
        return 14
    elif 'move' in command:
        return 14
    elif 'more' in command:
        return 14
    elif 'motor' in command:
        return 14
    elif 'mood' in command:
        return 14
    elif 'most' in command:
        return 14
    # NOTE: the 'what'/'with' rules below shadow "What's the definition of
    # transmit?", so definition queries only reach category 15 when the
    # transcript avoids those words.
    elif 'what' in command:
        return 14
    elif 'with' in command:
        return 14
    elif 'milk' in command:
        return 14
    elif 'use' in command:
        return 14
    elif 'definition of' in command:
        return 15
    elif 'call' in command and 'pizza hut' in command.lower():
        return 16
    elif 'copies are' in command.lower() or 'call a piece of heart' in command.lower() or 'copies of' in command.lower():
        return 16
    elif 'peace' in command.lower():
        return 16
    elif 'heart' in command.lower():
        return 16
    elif 'pisa' in command.lower():
        return 16
    elif 'piece' in command.lower():
        return 16
    elif 'hard' in command.lower():
        return 16
    elif 'call' in command and 'computer shop' in command.lower():
        return 17
    elif 'message' in command:
        return 18
    elif 'translate' in command:
        return 19
    elif 'good night' in command and 'in japanese' in command:
        return 20
    else:
        return None  # command not recognized by any keyword rule
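# Expected behaviour on a few hypothetical transcripts, for illustration:
# classify_key("turn on bluetooth")        -> 2  ('bluetooth' wins before 'turn on')
# classify_key("set an alarm for 6:30 am") -> 6  ('alarm' wins before 'am')
# classify_key("hello world")              -> None (falls through to spaCy matching)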

correct_count = 0
total_count = 0
category_number = 0
total_normal = 0

normal_texts = []
normal_labels = []

All_Normal_names = []

# Test of rule module
test_flag = True
atk_org_list = []
for i in range(len(texts)):
    if test_flag:
        # Keep every sample and remember which ones are ground-truth attacks.
        normal_texts.append(texts[i])
        All_Normal_names.append(y_name[i])
        normal_labels.append(type[i])
        if y[i] == 0:
            atk_org_list.append(i)
    else:
        # Otherwise keep only the ground-truth normal samples.
        if y[i] == 1:
            normal_texts.append(texts[i])
            All_Normal_names.append(y_name[i])
            normal_labels.append(type[i])

print(len(atk_org_list))
# for text in texts:
#     if texts.index(text) in atk_list:
#         print(texts.index(text))
#         continue
#     else:
#         normal_texts.append(text)

weird_name = []
weird_command = []


# for i in range(len(data)):
#     if not rule_judge(data[i][2], time_labels[i], location_labels[i]):
#         rule_err.append(i)

for i in range(len(normal_texts)):
    text = normal_texts[i]
    category_number = normal_labels[i]
    # print(text)
    # print(category_number)

    result_pre = classify_key(text.replace('.', '').replace(',', '').lower().strip())

    # IF rule - judge (context check, disabled in this run):
    # if not rule_judge(category_number - 1, time_labels[i], location_labels[i]):
    #     command_err_num[category_number - 1] += 1
    #     name_err_num[All_Normal_names[i]] += 1
    #     continue

    # Ground-truth attacks always count as errors for their category.
    if i in atk_org_list:
        command_err_num[category_number - 1] += 1
        name_err_num[All_Normal_names[i]] += 1
        continue
    if result_pre is not None:
        if result_pre == category_number:
            correct_count += 1
            command_acc_num[category_number - 1] += 1
            name_acc_num[All_Normal_names[i]] += 1
            continue
    # Fallback: pick the command whose spaCy embedding is most similar.
    input_doc = nlp(text.replace('.', '').replace(',', '').lower().strip())
    similarities = [(command, input_doc.similarity(nlp(command))) for command in commands]
    best_match = max(similarities, key=lambda item: item[1])
    best_match_index = commands.index(best_match[0]) + 1
    if best_match_index == category_number:
        correct_count += 1
        command_acc_num[category_number - 1] += 1
        name_acc_num[All_Normal_names[i]] += 1
    else:
        # print(text.replace('.', '').replace(',', '').lower().strip())
        # if category_number == 16:
        #     print(input_doc, commands[category_number - 1], commands[best_match_index - 1])
        command_err_num[category_number - 1] += 1
        name_err_num[All_Normal_names[i]] += 1

        # if 'thank' in str(input_doc):
        #     pass
        #     weird_name.append(y_name[texts.index(text)])
        #     weird_command.append(type[texts.index(text)])
        type_err.append(texts.index(text))

# Compute the accuracy over all retained transcripts.
accuracy = correct_count / len(normal_texts)
print(f"Accuracy: {accuracy:.2f}")


# Part 5 - Results
atk_set = set(atk_err)
name_set = set(name_err)
type_set = set(type_err)
# rule_set = set(rule_err)
err_list = list(atk_set | name_set | type_set)


print(len(err_list))
# print(weird_name)

print(name_err_num)
print(name_acc_num)
print(command_err_num)
print(command_acc_num)

# print(weird_command)
# print(atk_list)
# print(len(atk_list))
# print(all_num)
# print(atk_err)
# print(name_err)
# print(type_err)
# print(type_set)
# print(err_list)

# # Bar positions
# x = np.arange(len(name_err_num))

# # Draw the bar chart
# plt.bar(x - 0.2, name_acc_num, width=0.4, label='Correct', color='green')
# plt.bar(x + 0.2, name_err_num, width=0.4, label='Error', color='red')

# # Add title and labels
# plt.xlabel('Names')
# plt.ylabel('Counts')
# plt.title('Accuracy and Errors by Name')
# plt.xticks(x, ['User1', 'User2', 'User3', 'User4'])  # four users
# plt.legend()
# # plt.savefig('/home/fazhong/Github/czx/user.png')
# # Close the figure
# plt.close()


# # Bar positions
# x = np.arange(len(command_err_num))

# # Draw the bar chart
# plt.bar(x - 0.2, command_acc_num, width=0.4, label='Correct', color='blue')
# plt.bar(x + 0.2, command_err_num, width=0.4, label='Error', color='orange')

# # Add title and labels
# plt.xlabel('Commands')
# plt.ylabel('Counts')
# plt.title('Accuracy and Errors by Command')
# plt.xticks(x, [i for i in range(20)])  # one tick per command category
# plt.legend()

# # Save the figure
# # plt.savefig('/home/fazhong/Github/czx/com.png')
translate/train_man.py
ADDED
@@ -0,0 +1,160 @@
"""
This script runs a cross-validation.
We start from a two-fold validation.
"""
#%% Import necessary packages and the EER function
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
import tensorflow as tf
from sklearn.metrics import roc_curve
from scipy.interpolate import interp1d
from scipy.optimize import brentq
import os
import random

def eer(x_test, y_test, model):
    # Equal error rate: the point on the ROC curve where the false-positive
    # rate equals the false-negative rate.
    preds = model.predict(x_test)
    fpr, tpr, thresholds = roc_curve(y_test, preds)
    return brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

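# eer() is not called below (the EER is recomputed inline at the end), but a
# hypothetical call on a held-out fold would be:
# print("EER:", eer(X2, y2, model))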

#%%
data = np.load("/home/fazhong/Github/czx/data.npy", allow_pickle=True)
labels = np.load("/home/fazhong/Github/czx/labels.npy", allow_pickle=True)

data_all = []
data = data.tolist()
# print(data[0])
labels = labels.tolist()
# Repack each sample as [feature, attack_label, name_index, category_number].
for i in range(len(data)):
    tmp = []
    tmp.append(np.array(data[i][0]))
    tmp.extend([labels[i][0]])
    tmp.extend([labels[i][1]])
    tmp.extend([labels[i][2]])
    data_all.append(tmp)
random.shuffle(data_all)
data = data_all
# np.random.shuffle(data)

batch_size = 10
feature_len = 110
loss_function = binary_crossentropy
no_epochs = 150
optimizer = Adam()
verbosity = 1
model = Sequential()
model.add(Dense(64, input_dim=feature_len, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(16, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss=loss_function, optimizer=optimizer, metrics=['accuracy'])

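# The input is a 110-dim feature vector and the output a single normal/attack
# probability; a hypothetical wiring check:
# assert model.output_shape == (None, 1)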
#%% Training: fit one model per fold and save the best checkpoint
data_train = data[:int(0.5 * len(data))]
print(len(data_train))
X1 = np.asarray([x[0] for x in data_train])
print(X1.shape)
y1 = np.asarray([x[1] for x in data_train])
print(y1.shape)
data_test = data[int(0.5 * len(data)):]
X2 = np.asarray([x[0] for x in data_test])
y2 = np.asarray([x[1] for x in data_test])
checkpointer = ModelCheckpoint(filepath="./data-task0/train1.keras",
                               verbose=verbosity, save_best_only=True)
print('-' * 30)
print('Training on the first half of the data set')
history = model.fit(X1, y1,
                    # validation_data=(x[test], y[test]),
                    validation_split=0.1,
                    batch_size=batch_size,
                    epochs=no_epochs,
                    verbose=verbosity,
                    callbacks=[checkpointer, tf.keras.callbacks.EarlyStopping(monitor='loss', patience=7)]
                    )

## Train the second fold on X2
checkpointer = ModelCheckpoint(filepath="./data-task0/train2.keras",
                               verbose=verbosity, save_best_only=True)
print('-' * 30)
print('Training on the second half of the data set')
history = model.fit(X2, y2,
                    # validation_data=(x[test], y[test]),
                    validation_split=0.1,
                    batch_size=batch_size,
                    epochs=no_epochs,
                    verbose=verbosity,
                    callbacks=[checkpointer, tf.keras.callbacks.EarlyStopping(monitor='loss', patience=7)]
                    )

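# Two-fold scheme: train1.keras is fit on the first half and scored on the
# second, train2.keras the reverse, so every sample receives an out-of-fold
# prediction before the pooled accuracy and EER are computed below.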

#%% Calculate the final result
# data_train = np.load("./main_task/data-task0/data1.npy", allow_pickle=True)
X1 = np.asarray([x[0] for x in data_train])
y1 = np.asarray([x[1] for x in data_train])
# data_test = np.load("./main_task/data-task0/data2.npy", allow_pickle=True)
X2 = np.asarray([x[0] for x in data_test])
y2 = np.asarray([x[1] for x in data_test])

# Score each half with the model trained on the other half.
model.load_weights("./data-task0/train1.keras")
scores = model.evaluate(X2, y2)
y_pred2 = model.predict(X2)
print(y_pred2.shape)

model.load_weights("./data-task0/train2.keras")
scores = model.evaluate(X1, y1)
y_pred1 = model.predict(X1)

y_pred = np.concatenate((y_pred1, y_pred2))
y_pred = y_pred.reshape((len(y_pred), 1))
y_label = np.concatenate((y1, y2))
y_label = y_label.reshape((len(y_label), 1))
for i in range(len(y_label)):
    if y_pred[i] > 0.5:
        y_pred[i] = 1
    else:
        y_pred[i] = 0
ACCU = np.sum(y_pred == y_label) / len(y_label)
print("ACCU is " + str(100 * ACCU))
fpr, tpr, thresholds = roc_curve(y_label, y_pred)
EER = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
print(EER)


# #%% Calculate per-user results
# num_all = np.zeros((20, 1))
# num_success = np.zeros((20, 1))
# for user_num in range(1, 21, 1):
#     # Test the data on train1.keras
#     model.load_weights("./data-task0/train1.keras")
#     print("user number is " + str(user_num))
#     X_test = np.asarray([x[0] for x in data_test if (x[5] == user_num and x[1] == 0)])
#     y_test = np.asarray([x[1] for x in data_test if (x[5] == user_num and x[1] == 0)])
#     scores = model.evaluate(X_test, y_test)
#     num_all[user_num - 1] += len(y_test)
#     num_success[user_num - 1] += np.round(len(y_test) * scores[1])
# for user_num in range(1, 21, 1):
#     # Test the data on train2.keras
#     model.load_weights("./data-task0/train2.keras")
#     print("user number is " + str(user_num))
#     X_test = np.asarray([x[0] for x in data_train if (x[5] == user_num and x[1] == 0)])
#     y_test = np.asarray([x[1] for x in data_train if (x[5] == user_num and x[1] == 0)])
#     scores = model.evaluate(X_test, y_test)
#     num_all[user_num - 1] += len(y_test)
#     num_success[user_num - 1] += np.round(len(y_test) * scores[1])

# #%% Show the results
# for user_num in range(1, 21, 1):
#     print("user number is " + str(user_num))
#     print("[=========] total number is " + str(int(num_all[user_num - 1])) + ", and wrong detect " + str(int(num_all[user_num - 1] - num_success[user_num - 1]))
#           + " samples, rate is " + str(np.round(num_success[user_num - 1] / num_all[user_num - 1], 4)))
# print("total number is " + str(int(np.sum(num_all))) + ", and detect " + str(int(np.sum(num_all) - np.sum(num_success)))
#       + " samples, rate is " + str((np.sum(num_success) / np.sum(num_all))))
translate/train_name.py
ADDED
@@ -0,0 +1,162 @@
"""
This script runs a cross-validation for the name (speaker) labels.
We start from a two-fold validation.
"""
#%% Import necessary packages and the EER function
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
from sklearn.metrics import roc_curve
from scipy.interpolate import interp1d
from scipy.optimize import brentq
import os
import random

def eer(x_test, y_test, model):
    # Equal error rate, as in train_man.py.
    preds = model.predict(x_test)
    fpr, tpr, thresholds = roc_curve(y_test, preds)
    return brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)


#%%
data = np.load("/home/fazhong/Github/czx/data.npy", allow_pickle=True)
labels = np.load("/home/fazhong/Github/czx/labels.npy", allow_pickle=True)


data_all = []
data = data.tolist()
# print(data[0])
labels = labels.tolist()
# Repack each sample as [feature, one-hot name label, attack_label, category_number].
for i in range(len(data)):
    tmp = []
    tmp.append(np.array(data[i][0]))
    tmp.extend(to_categorical([labels[i][1]], num_classes=4).tolist())
    tmp.extend([labels[i][0]])
    tmp.extend([labels[i][2]])
    data_all.append(tmp)
random.shuffle(data_all)
data = data_all
# print(data)
# np.random.shuffle(data)

batch_size = 10
feature_len = 110
loss_function = categorical_crossentropy
no_epochs = 150
optimizer = Adam()
verbosity = 1
model = Sequential()
model.add(Dense(64, input_dim=feature_len, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(32, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(16, activation='relu'))
model.add(Dense(4, activation='softmax'))
model.compile(loss=loss_function, optimizer=optimizer, metrics=['accuracy'])

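# The 4-way softmax head matches the one-hot name labels built above; a
# hypothetical wiring check:
# assert model.output_shape == (None, 4)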
#%% Training: fit one model per fold and save the best checkpoint
data_train = data[:int(0.5 * len(data))]
print(len(data_train))
X1 = np.asarray([x[0] for x in data_train])
print(X1.shape)
temp = [x[1] for x in data_train]
print(len(temp))
print(len(temp[1]))
y1 = np.asarray([x[1] for x in data_train])
print(y1.shape)
data_test = data[int(0.5 * len(data)):]
X2 = np.asarray([x[0] for x in data_test])
y2 = np.asarray([x[1] for x in data_test])
checkpointer = ModelCheckpoint(filepath="./data-task0/train1.keras",
                               verbose=verbosity, save_best_only=True)
print('-' * 30)
print('Training on the first half of the data set')
history = model.fit(X1, y1,
                    # validation_data=(x[test], y[test]),
                    validation_split=0.1,
                    batch_size=batch_size,
                    epochs=no_epochs,
                    verbose=verbosity,
                    callbacks=[checkpointer, tf.keras.callbacks.EarlyStopping(monitor='loss', patience=7)]
                    )

## Train the second fold on X2
checkpointer = ModelCheckpoint(filepath="./data-task0/train2.keras",
                               verbose=verbosity, save_best_only=True)
print('-' * 30)
print('Training on the second half of the data set')
history = model.fit(X2, y2,
                    # validation_data=(x[test], y[test]),
                    validation_split=0.1,
                    batch_size=batch_size,
                    epochs=no_epochs,
                    verbose=verbosity,
                    callbacks=[checkpointer, tf.keras.callbacks.EarlyStopping(monitor='loss', patience=7)]
                    )


#%% Calculate the final result
# data_train = np.load("./main_task/data-task0/data1.npy", allow_pickle=True)
X1 = np.asarray([x[0] for x in data_train])
y1 = np.asarray([x[1] for x in data_train])
# data_test = np.load("./main_task/data-task0/data2.npy", allow_pickle=True)
X2 = np.asarray([x[0] for x in data_test])
y2 = np.asarray([x[1] for x in data_test])


model.load_weights("./data-task0/train1.keras")
scores = model.evaluate(X2, y2)
y_pred2 = model.predict(X2)
print(y_pred2.shape)

model.load_weights("./data-task0/train2.keras")
scores = model.evaluate(X1, y1)
y_pred1 = model.predict(X1)

y_pred = np.concatenate((y_pred1, y_pred2))
y_pred_classes = np.argmax(y_pred, axis=1)
y_label_classes = np.argmax(np.concatenate((y1, y2)), axis=1)
print(y_pred_classes)
ACCU = np.sum(y_pred_classes == y_label_classes) / len(y_label_classes)
print("ACCU is " + str(100 * ACCU))
# roc_curve only supports binary targets, so the 4-class ROC/EER is computed
# micro-averaged over the flattened one-hot labels and class probabilities.
y_true_onehot = np.concatenate((y1, y2))
fpr, tpr, thresholds = roc_curve(y_true_onehot.ravel(), y_pred.ravel())
EER = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
print(EER)


# #%% Calculate per-user results
# num_all = np.zeros((20, 1))
# num_success = np.zeros((20, 1))
# for user_num in range(1, 21, 1):
#     # Test the data on train1.keras
#     model.load_weights("./data-task0/train1.keras")
#     print("user number is " + str(user_num))
#     X_test = np.asarray([x[0] for x in data_test if (x[5] == user_num and x[1] == 0)])
#     y_test = np.asarray([x[1] for x in data_test if (x[5] == user_num and x[1] == 0)])
#     scores = model.evaluate(X_test, y_test)
#     num_all[user_num - 1] += len(y_test)
#     num_success[user_num - 1] += np.round(len(y_test) * scores[1])
# for user_num in range(1, 21, 1):
#     # Test the data on train2.keras
#     model.load_weights("./data-task0/train2.keras")
#     print("user number is " + str(user_num))
#     X_test = np.asarray([x[0] for x in data_train if (x[5] == user_num and x[1] == 0)])
#     y_test = np.asarray([x[1] for x in data_train if (x[5] == user_num and x[1] == 0)])
#     scores = model.evaluate(X_test, y_test)
#     num_all[user_num - 1] += len(y_test)
#     num_success[user_num - 1] += np.round(len(y_test) * scores[1])

# #%% Show the results
# for user_num in range(1, 21, 1):
#     print("user number is " + str(user_num))
#     print("[=========] total number is " + str(int(num_all[user_num - 1])) + ", and wrong detect " + str(int(num_all[user_num - 1] - num_success[user_num - 1]))
#           + " samples, rate is " + str(np.round(num_success[user_num - 1] / num_all[user_num - 1], 4)))
# print("total number is " + str(int(np.sum(num_all))) + ", and detect " + str(int(np.sum(num_all) - np.sum(num_success)))
#       + " samples, rate is " + str((np.sum(num_success) / np.sum(num_all))))
translate/wav2com.py
ADDED
File without changes
translate/wav2npy.py
ADDED
@@ -0,0 +1,118 @@
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import roc_curve
from scipy.interpolate import interp1d
from scipy.optimize import brentq
import matplotlib.pyplot as plt
from scipy.io.wavfile import read
from sklearn.preprocessing import normalize
from generate_array_feature import mald_feature, get_filelist
import time
import os
from pydub import AudioSegment
import whisper

folder_path = '/home/fazhong/Github/czx2/example/data'
names = ['feng', 'jc', 'meng', 'zhan']
types = ['01', '02', '03', '04', '05', '06', '07', '08', '09', '10',
         '11', '12', '13', '14', '15', '16', '17', '18', '19', '20']  # duplicate '09' removed
voice = []

def convert_6ch_wav_to_stereo(input_file_path, output_file_path):
    # Down-mix a 6-channel (5.1) recording to stereo: attenuate the centre and
    # rear channels by 6 dB, then overlay them onto the front left/right pair.
    sound = AudioSegment.from_file(input_file_path, format="wav")
    if sound.channels != 6:
        raise ValueError("The input file does not have 6 channels.")
    front_left = sound.split_to_mono()[0]
    front_right = sound.split_to_mono()[1]
    center = sound.split_to_mono()[2]
    back_left = sound.split_to_mono()[4]
    back_right = sound.split_to_mono()[5]
    center = center - 6
    back_left = back_left - 6
    back_right = back_right - 6
    stereo_left = front_left.overlay(center).overlay(back_left)
    stereo_right = front_right.overlay(center).overlay(back_right)
    stereo_sound = AudioSegment.from_mono_audiosegments(stereo_left, stereo_right)
    stereo_sound.export(output_file_path, format="wav")

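# Hypothetical one-off conversion, for illustration:
# convert_6ch_wav_to_stereo('meeting_6ch.wav', 'meeting_stereo.wav')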
def read_all_files(directory):
    data = []
    labels = []
    texts = []
    whisper_model = whisper.load_model("large")
    out_path = '/home/fazhong/Github/czx/temp/temp.wav'
    i = 0
    for root, dirs, files in os.walk(directory):

        for file in files:
            # if i > 10: return data, labels, texts
            content = []
            content_label = []
            file_path = os.path.join(root, file)
            # Down-mix to stereo so Whisper can transcribe the clip.
            convert_6ch_wav_to_stereo(file_path, out_path)
            result = whisper_model.transcribe(out_path, language="en")
            text_result = result['text']
            texts.append(text_result)
            print(file)
            # File names encode the ground truth: 'normal'/'attack', the
            # speaker name, and the command category number.
            if 'normal' in file:
                label = 1  # normal case
            elif 'attack' in file:
                label = 0
            for name in names:
                if name in file:
                    name_index = names.index(name)
            if label == 0:
                category_number = int(file.split('_')[4])
            elif label == 1:
                category_number = int(file.split('_')[3])

            rate, wavdata = read(file_path)
            content.append(list(mald_feature(rate, wavdata)))
            content_label.append(label)
            content_label.append(name_index)
            content_label.append(category_number)
            data.append(content)
            labels.append(content_label)
            i += 1
    return data, labels, texts

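# read_all_files yields one feature row, one [attack, name, category] label
# triple, and one Whisper transcript per .wav file, in directory-walk order.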
# Call the function and persist the arrays.
data, labels, texts = read_all_files(folder_path)
data_array = np.array(data)
labels_array = np.array(labels)
texts_array = np.array(texts)
filename = 'data.npy'
filename2 = 'labels.npy'
filename3 = 'texts.npy'
np.save(filename, data_array)
np.save(filename2, labels_array)
np.save(filename3, texts_array)
print('fin')
# #%% Load the audio
# path_wave = r"/home/fazhong/Github/czx/voice"
# print("Loading data ...")
# name_all = get_filelist(path_wave)
# voice = []
# # voice holds the raw waveforms extracted from a set of wav files
# X = []  # X is the feature ~ data[0]
# y = []  # y is the normal (1) or attack (0) ~ data[1]

# for file_path in name_all:
#     file_name = file_path.split("\\")[-1]
#     # define the normal or attack in variable cur_y
#     if 'normal' in file_name:
#         cur_y = 1  # normal case
#     elif 'attack' in file_name:
#         cur_y = 0
#     # split the file name
#     # read the data
#     rate, data = read(file_path)
#     voice += [list(data)]

#     X += [list(mald_feature(rate, data))]
#     y += [cur_y]

# norm_X = normalize(X, axis=0, norm='max')
# X = np.asarray(norm_X)
# y = np.asarray(y)