"""FastAPI service that runs a pretrained ``CNNetwork`` on audio input.

Routes:
    GET /            -- returns a static "Hello World" message.
    GET /urlpredict  -- preprocesses audio via ``process_from_url`` and
                        returns the model's predicted class index.
    PUT /predict     -- preprocesses input via ``process_raw_wav`` and
                        returns the model's predicted class index.

Both the ``sys.path`` entry and ``MODEL_PATH`` below are relative to the
current working directory, not to this file.
"""
import os
import sys

from fastapi import FastAPI

# torch
import torch

# Make modules in the parent directory importable before the project imports.
sys.path.append('..')

# utils
from preprocess import process_from_filename, process_from_url, process_raw_wav
from cnn import CNNetwork

# Checkpoint holding the state_dict for CNNetwork (cwd-relative path).
MODEL_PATH = "../models/void_demo.pth"

# load model
model = CNNetwork()
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a CPU-only host; the CPU-resident model then receives the weights as usual.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (consider weights_only=True on torch versions that support it).
state_dict = torch.load(MODEL_PATH, map_location="cpu")
model.load_state_dict(state_dict)
# Put the network in inference mode. Without this, layers such as dropout or
# batch-norm (if CNNetwork contains any) keep their training-time behaviour
# and predictions become noisy or wrong.
model.eval()
print(f"Model loaded! \n {model}")

app = FastAPI()


@app.get("/")
async def root():
    """Return a static greeting; useful as a liveness check."""
    return {
        "message": "Hello World"
    }


@app.get("/urlpredict")
def url_predict(url: str):
    """Preprocess the audio referenced by ``url`` and classify it.

    Args:
        url: query parameter passed straight to ``process_from_url``.

    Returns:
        dict with a "message" string and "data" holding the result of
        ``model_predict``.
    """
    wav = process_from_url(url)
    model_prediction = model_predict(wav)
    return {
        "message": "Voice Identified!",
        "data": model_prediction,
    }


@app.put("/predict")
def predict(wav):
    """Preprocess the supplied ``wav`` input with ``process_raw_wav`` and classify it.

    NOTE(review): ``wav`` has no type annotation, so FastAPI treats it as a
    (string) query parameter rather than a request body or uploaded file —
    confirm this matches what ``process_raw_wav`` expects.

    Returns:
        dict with a "message" string and "data" holding the result of
        ``model_predict``.
    """
    wav = process_raw_wav(wav)
    model_prediction = model_predict(wav)
    return {
        "message": "Voice Identified!",
        "data": model_prediction,
    }


def model_predict(wav):
    """Run the global model on one preprocessed example.

    Args:
        wav: preprocessed model input for a single example (a torch tensor).
            A leading batch dimension of size 1 is added here; the exact
            expected shape is defined by CNNetwork — TODO confirm.

    Returns:
        dict with key "prediction_index": the argmax of the model output
        along dimension 1, as a Python int.
    """
    # Add a batch dimension so the model sees a batch of one.
    model_input = wav.unsqueeze(0)
    # Gradients are never needed for inference; skip autograd bookkeeping.
    with torch.no_grad():
        output = model(model_input)
    # .item() requires a single-element result, i.e. exactly one prediction.
    prediction_index = torch.argmax(output, 1).item()
    return {
        "prediction_index": prediction_index,
    }