amanmibra commited on
Commit
f352ab2
1 Parent(s): 3e88903

Add predict endpoints

Browse files
Files changed (1) hide show
  1. server/main.py +58 -1
server/main.py CHANGED
@@ -1,7 +1,64 @@
 
 
 
 
 
1
  from fastapi import FastAPI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  app = FastAPI()
4
 
5
  @app.get("/")
6
  async def root():
7
- return { "message": "Hello World" }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #
2
+ import sys
3
+ sys.path.append('..')
4
+
5
+ import os
6
  from fastapi import FastAPI
7
+ from pydantic import BaseModel
8
+ import wget
9
+
10
+ # torch
11
+ import torch
12
+
13
+ # utils
14
+ from preprocess import process_from_filename, process_raw_wav
15
+ from cnn import CNNetwork
16
+
17
+ # load model
18
+ model = CNNetwork()
19
+ state_dict = torch.load("../models/void_demo.pth")
20
+ model.load_state_dict(state_dict)
21
+
22
+ print(f"Model loaded! \n {model}")
23
+
24
+ # /predict input
25
+ # class Data(BaseModel):
26
+ # wav:
27
+
28
 
29
  app = FastAPI()
30
 
31
  @app.get("/")
32
  async def root():
33
+ return { "message": "Hello World" }
34
+
35
+ @app.get("/urlpredict")
36
+ def url_predict(url: str):
37
+ filename = wget.download(url)
38
+ wav = process_from_filename(filename)
39
+ print(f"\ntest {wav.shape}\n")
40
+
41
+ model_prediction = model_predict(wav)
42
+ return model_prediction["predicition_index"]
43
+
44
+ @app.put("/predict")
45
+ def predict(wav):
46
+ print(f"wav {wav}")
47
+ # return wav
48
+ wav = process_raw_wav(wav)
49
+ model_prediction = model_predict(wav)
50
+
51
+ return {
52
+ "message": "Voiced Identified!",
53
+ "data": model_prediction,
54
+ }
55
+
56
+ def model_predict(wav):
57
+ model_input = wav.unsqueeze(0)
58
+ output = model(model_input)
59
+ prediction = torch.argmax(output, 1).item()
60
+
61
+ return {
62
+ "output": output,
63
+ "prediction_index": prediction,
64
+ }