"""FastAPI service exposing a CNN voice classifier.

On import this module loads ``CNNetwork`` weights from
``../models/void_demo.pth`` (resolved against the current working directory)
and registers the HTTP routes below on ``app``.
"""
# import
import sys

# Make sibling project modules (``preprocess``, ``cnn``) importable; '..' is
# resolved against the current working directory, not this file's location.
sys.path.append('..')

import os
import tempfile

from fastapi import FastAPI
from pydantic import BaseModel
import wget

# torch
import torch

# utils
from preprocess import process_from_filename, process_raw_wav
from cnn import CNNetwork

# load model
model = CNNetwork()
# The freshly constructed model lives on the CPU, so load the checkpoint's
# tensors there too; this also works for checkpoints saved from a GPU.
state_dict = torch.load("../models/void_demo.pth", map_location="cpu")
model.load_state_dict(state_dict)
# The model is only used for inference: switch layers such as dropout /
# batch-norm (if CNNetwork has any) to evaluation behaviour.
model.eval()
print(f"Model loaded! \n {model}")

# TODO: describe the /predict input with a request-body model, e.g.
# class Data(BaseModel):
#     wav: ...

app = FastAPI()


def model_predict(wav):
    """Run the classifier on a single preprocessed example.

    Args:
        wav: tensor produced by ``process_from_filename`` / ``process_raw_wav``.
            A batch dimension is added here with ``unsqueeze(0)``, so this
            presumably has no batch axis yet — TODO confirm against preprocess.

    Returns:
        dict with ``"output"`` (the raw model output converted to nested
        Python lists so FastAPI can JSON-encode it) and ``"prediction_index"``
        (argmax over dim 1, as a Python ``int``).
    """
    model_input = wav.unsqueeze(0)
    # Inference does not need an autograd graph.
    with torch.no_grad():
        output = model(model_input)
    prediction = torch.argmax(output, 1).item()
    return {
        # torch tensors are not JSON-serializable; expose plain lists instead.
        "output": output.tolist(),
        "prediction_index": prediction,
    }


@app.get("/")
async def root():
    """Health-check style endpoint returning a fixed greeting."""
    return {
        "message": "Hello World"
    }


@app.get("/urlpredict")
def url_predict(url: str):
    """Download the audio file at ``url``, classify it, and return the class index.

    The download goes into a temporary directory that is deleted as soon as
    preprocessing finishes, so repeated requests do not leave files behind in
    the working directory.
    """
    # NOTE(review): ``url`` is caller-supplied and fetched server-side as-is;
    # consider restricting schemes/hosts before exposing this publicly (SSRF).
    with tempfile.TemporaryDirectory() as tmp_dir:
        filename = wget.download(url, out=tmp_dir)
        wav = process_from_filename(filename)
    print(f"\ntest {wav.shape}\n")
    model_prediction = model_predict(wav)
    # Must match the key produced by model_predict.
    return model_prediction["prediction_index"]


@app.put("/predict")
def predict(wav):
    """Classify raw audio passed in ``wav`` and return the full prediction.

    ``wav`` is un-annotated, so FastAPI presumably receives it as a query
    parameter — TODO confirm the intended transport for raw audio.
    """
    print(f"wav {wav}")
    wav = process_raw_wav(wav)
    model_prediction = model_predict(wav)
    return {
        "message": "Voiced Identified!",
        "data": model_prediction,
    }