Spaces:
Runtime error
Runtime error
Update model and API
Browse files
- cnn.py +1 -1
- models/aisf/void_20230517_102846.pth +0 -0
- server/main.py +5 -2
- server/preprocess.py +3 -2
cnn.py
CHANGED
@@ -52,7 +52,7 @@ class CNNetwork(nn.Module):
|
|
52 |
nn.MaxPool2d(kernel_size=2)
|
53 |
)
|
54 |
self.flatten = nn.Flatten()
|
55 |
-
self.linear = nn.Linear(128 * 9 *
|
56 |
self.softmax = nn.Softmax(dim=1)
|
57 |
|
58 |
def forward(self, input_data):
|
|
|
52 |
nn.MaxPool2d(kernel_size=2)
|
53 |
)
|
54 |
self.flatten = nn.Flatten()
|
55 |
+
self.linear = nn.Linear(128 * 9 * 19, 3)
|
56 |
self.softmax = nn.Softmax(dim=1)
|
57 |
|
58 |
def forward(self, input_data):
|
models/aisf/void_20230517_102846.pth
ADDED
Binary file (655 kB). View file
|
|
server/main.py
CHANGED
@@ -14,7 +14,7 @@ from cnn import CNNetwork
|
|
14 |
|
15 |
# load model
|
16 |
model = CNNetwork()
|
17 |
-
state_dict = torch.load("../models/aisf/
|
18 |
model.load_state_dict(state_dict)
|
19 |
|
20 |
# TODO: update to grabbing labels stored on model
|
@@ -54,8 +54,11 @@ def model_predict(wav):
|
|
54 |
model_input = wav.unsqueeze(0)
|
55 |
output = model(model_input)
|
56 |
prediction_index = torch.argmax(output, 1).item()
|
|
|
57 |
|
58 |
return {
|
59 |
"prediction_index": prediction_index,
|
60 |
-
"
|
|
|
|
|
61 |
}
|
|
|
14 |
|
15 |
# load model
|
16 |
model = CNNetwork()
|
17 |
+
state_dict = torch.load("../models/aisf/void_20230517_102846.pth")
|
18 |
model.load_state_dict(state_dict)
|
19 |
|
20 |
# TODO: update to grabbing labels stored on model
|
|
|
54 |
model_input = wav.unsqueeze(0)
|
55 |
output = model(model_input)
|
56 |
prediction_index = torch.argmax(output, 1).item()
|
57 |
+
output = output.detach().cpu().numpy()[0]
|
58 |
|
59 |
return {
|
60 |
"prediction_index": prediction_index,
|
61 |
+
"labels": LABELS,
|
62 |
+
"prediction_label": LABELS[prediction_index],
|
63 |
+
"prediction_output": output.tolist(),
|
64 |
}
|
server/preprocess.py
CHANGED
@@ -9,6 +9,7 @@ from scipy.io import wavfile
|
|
9 |
import wget
|
10 |
|
11 |
DEFAULT_SAMPLE_RATE=48000
|
|
|
12 |
|
13 |
def process_from_url(url):
|
14 |
# download UI audio
|
@@ -26,7 +27,7 @@ def process_from_url(url):
|
|
26 |
return spec
|
27 |
|
28 |
|
29 |
-
def process_from_filename(filename, target_sample_rate=DEFAULT_SAMPLE_RATE, wav_length=
|
30 |
wav, sample_rate = torchaudio.load(filename)
|
31 |
|
32 |
wav = process_raw_wav(wav, sample_rate, target_sample_rate, wav_length)
|
@@ -35,7 +36,7 @@ def process_from_filename(filename, target_sample_rate=DEFAULT_SAMPLE_RATE, wav_
|
|
35 |
|
36 |
return spec
|
37 |
|
38 |
-
def process_raw_wav(wav, sample_rate=DEFAULT_SAMPLE_RATE, target_sample_rate=DEFAULT_SAMPLE_RATE, wav_length=
|
39 |
num_samples = wav_length * target_sample_rate
|
40 |
|
41 |
wav = _resample(wav, sample_rate, target_sample_rate)
|
|
|
9 |
import wget
|
10 |
|
11 |
DEFAULT_SAMPLE_RATE=48000
|
12 |
+
DEFAULT_WAVE_LENGTH=3
|
13 |
|
14 |
def process_from_url(url):
|
15 |
# download UI audio
|
|
|
27 |
return spec
|
28 |
|
29 |
|
30 |
+
def process_from_filename(filename, target_sample_rate=DEFAULT_SAMPLE_RATE, wav_length=DEFAULT_WAVE_LENGTH):
|
31 |
wav, sample_rate = torchaudio.load(filename)
|
32 |
|
33 |
wav = process_raw_wav(wav, sample_rate, target_sample_rate, wav_length)
|
|
|
36 |
|
37 |
return spec
|
38 |
|
39 |
+
def process_raw_wav(wav, sample_rate=DEFAULT_SAMPLE_RATE, target_sample_rate=DEFAULT_SAMPLE_RATE, wav_length=DEFAULT_WAVE_LENGTH):
|
40 |
num_samples = wav_length * target_sample_rate
|
41 |
|
42 |
wav = _resample(wav, sample_rate, target_sample_rate)
|