Spaces:
Runtime error
Runtime error
# | |
import sys | |
sys.path.append('..') | |
import os | |
from fastapi import FastAPI | |
from pydantic import BaseModel | |
import wget | |
# torch | |
import torch | |
# utils | |
from preprocess import process_from_filename, process_raw_wav | |
from cnn import CNNetwork | |
# load model
model = CNNetwork()
# Nothing in this file moves the model or its inputs to a GPU, so inference runs
# on CPU. map_location="cpu" lets a checkpoint that was saved from a CUDA device
# load on CPU-only hosts instead of failing while restoring tensors to CUDA.
state_dict = torch.load("../models/void_demo.pth", map_location="cpu")
model.load_state_dict(state_dict)
# load_state_dict does not change the train/eval mode. Switch to eval so that any
# dropout / batch-norm layers CNNetwork may contain behave deterministically at
# inference time.
model.eval()
print(f"Model loaded! \n {model}")

# TODO: define a request-body model for the /predict input, e.g.
#   class Data(BaseModel):
#       wav: ...
app = FastAPI()
async def root():
    """Return a fixed ``{"message": "Hello World"}`` payload.

    NOTE(review): no route decorator is visible on this coroutine, so as shown it
    is not registered on ``app``. Confirm whether ``@app.get("/")`` was intended.
    """
    greeting = {"message": "Hello World"}
    return greeting
def url_predict(url: str) -> int:
    """Download an audio file from *url* and return the predicted class index.

    The file is fetched with ``wget.download`` and preprocessed by
    ``process_from_filename``. It is then scored by ``model_predict``. The
    downloaded file is deleted afterwards, whether or not processing succeeds.

    Args:
        url: Location of the audio file to classify.

    Returns:
        The integer class index (``"prediction_index"``) from ``model_predict``.
    """
    filename = wget.download(url)
    try:
        wav = process_from_filename(filename)
        print(f"\ntest {wav.shape}\n")
        model_prediction = model_predict(wav)
    finally:
        # wget.download writes to local disk (the working directory by default).
        # Remove the file so repeated requests do not accumulate downloads.
        os.remove(filename)
    # Must match the key model_predict() actually returns ("prediction_index").
    return model_prediction["prediction_index"]
def predict(wav) -> dict:
    """Classify a raw waveform and return a JSON-serializable result.

    Args:
        wav: Raw audio in whatever form ``process_raw_wav`` accepts. The exact
            type is not visible here; presumably it is the /predict request
            payload. Confirm against the preprocess module.

    Returns:
        A dict with a fixed ``"message"`` and a ``"data"`` dict holding:
        ``"output"``, the model's raw output converted to nested Python lists;
        and ``"prediction_index"``, the predicted class index (int).
    """
    features = process_raw_wav(wav)
    model_prediction = model_predict(features)
    return {
        "message": "Voiced Identified!",
        "data": {
            # torch tensors are not JSON-serializable. Convert so this dict can
            # be returned directly as an HTTP response body.
            "output": model_prediction["output"].tolist(),
            "prediction_index": model_prediction["prediction_index"],
        },
    }
def model_predict(wav):
    """Run the module-level ``model`` on a single preprocessed example.

    Args:
        wav: Preprocessed input tensor without a leading batch dimension.
            ``unsqueeze(0)`` below adds one of size 1. The exact shape depends
            on the preprocess module; confirm there.

    Returns:
        dict with:

        - ``"output"``: the model's raw output tensor.
        - ``"prediction_index"``: int, the argmax of ``output`` along dim 1.
    """
    # Add a leading dimension of size 1 (presumably the batch axis).
    model_input = wav.unsqueeze(0)
    # Inference only: disable autograd so no computation graph is built or kept
    # alive by the returned tensor.
    with torch.no_grad():
        output = model(model_input)
    prediction = torch.argmax(output, 1).item()
    return {
        "output": output,
        "prediction_index": prediction,
    }