amanmibra committed on
Commit
77d5702
1 Parent(s): 89ab250

Update model and API

Browse files
cnn.py CHANGED
@@ -52,7 +52,7 @@ class CNNetwork(nn.Module):
52
  nn.MaxPool2d(kernel_size=2)
53
  )
54
  self.flatten = nn.Flatten()
55
- self.linear = nn.Linear(128 * 9 * 31, 3)
56
  self.softmax = nn.Softmax(dim=1)
57
 
58
  def forward(self, input_data):
 
52
  nn.MaxPool2d(kernel_size=2)
53
  )
54
  self.flatten = nn.Flatten()
55
+ self.linear = nn.Linear(128 * 9 * 19, 3)
56
  self.softmax = nn.Softmax(dim=1)
57
 
58
  def forward(self, input_data):
models/aisf/void_20230517_102846.pth ADDED
Binary file (655 kB). View file
 
server/main.py CHANGED
@@ -14,7 +14,7 @@ from cnn import CNNetwork
14
 
15
  # load model
16
  model = CNNetwork()
17
- state_dict = torch.load("../models/aisf/void_20230516_193200.pth")
18
  model.load_state_dict(state_dict)
19
 
20
  # TODO: update to grabbing labels stored on model
@@ -54,8 +54,11 @@ def model_predict(wav):
54
  model_input = wav.unsqueeze(0)
55
  output = model(model_input)
56
  prediction_index = torch.argmax(output, 1).item()
 
57
 
58
  return {
59
  "prediction_index": prediction_index,
60
- "prediciton": LABELS[prediction_index],
 
 
61
  }
 
14
 
15
  # load model
16
  model = CNNetwork()
17
+ state_dict = torch.load("../models/aisf/void_20230517_102846.pth")
18
  model.load_state_dict(state_dict)
19
 
20
  # TODO: update to grabbing labels stored on model
 
54
  model_input = wav.unsqueeze(0)
55
  output = model(model_input)
56
  prediction_index = torch.argmax(output, 1).item()
57
+ output = output.detach().cpu().numpy()[0]
58
 
59
  return {
60
  "prediction_index": prediction_index,
61
+ "labels": LABELS,
62
+ "prediction_label": LABELS[prediction_index],
63
+ "prediction_output": output.tolist(),
64
  }
server/preprocess.py CHANGED
@@ -9,6 +9,7 @@ from scipy.io import wavfile
9
  import wget
10
 
11
  DEFAULT_SAMPLE_RATE=48000
 
12
 
13
  def process_from_url(url):
14
  # download UI audio
@@ -26,7 +27,7 @@ def process_from_url(url):
26
  return spec
27
 
28
 
29
- def process_from_filename(filename, target_sample_rate=DEFAULT_SAMPLE_RATE, wav_length=5):
30
  wav, sample_rate = torchaudio.load(filename)
31
 
32
  wav = process_raw_wav(wav, sample_rate, target_sample_rate, wav_length)
@@ -35,7 +36,7 @@ def process_from_filename(filename, target_sample_rate=DEFAULT_SAMPLE_RATE, wav_
35
 
36
  return spec
37
 
38
- def process_raw_wav(wav, sample_rate=DEFAULT_SAMPLE_RATE, target_sample_rate=DEFAULT_SAMPLE_RATE, wav_length=5):
39
  num_samples = wav_length * target_sample_rate
40
 
41
  wav = _resample(wav, sample_rate, target_sample_rate)
 
9
  import wget
10
 
11
  DEFAULT_SAMPLE_RATE=48000
12
+ DEFAULT_WAVE_LENGTH=3
13
 
14
  def process_from_url(url):
15
  # download UI audio
 
27
  return spec
28
 
29
 
30
+ def process_from_filename(filename, target_sample_rate=DEFAULT_SAMPLE_RATE, wav_length=DEFAULT_WAVE_LENGTH):
31
  wav, sample_rate = torchaudio.load(filename)
32
 
33
  wav = process_raw_wav(wav, sample_rate, target_sample_rate, wav_length)
 
36
 
37
  return spec
38
 
39
+ def process_raw_wav(wav, sample_rate=DEFAULT_SAMPLE_RATE, target_sample_rate=DEFAULT_SAMPLE_RATE, wav_length=DEFAULT_WAVE_LENGTH):
40
  num_samples = wav_length * target_sample_rate
41
 
42
  wav = _resample(wav, sample_rate, target_sample_rate)