Spaces:
Runtime error
Runtime error
# | |
import sys | |
sys.path.append('..') | |
import os | |
from fastapi import FastAPI | |
# torch | |
import torch | |
# utils | |
from preprocess import process_from_filename, process_from_url, process_raw_wav | |
from cnn import CNNetwork | |
# Load the trained model once at import time so every request reuses the same
# weights.
model = CNNetwork()
# The model above is constructed on CPU, so deserialize the checkpoint onto the
# CPU as well: without map_location, a checkpoint saved from a GPU process
# cannot be loaded on a CPU-only host.
state_dict = torch.load(
    "../models/aisf/void_20230517_113634.pth", map_location="cpu"
)
model.load_state_dict(state_dict)
# This module only uses the model for prediction; put any dropout/batch-norm
# layers CNNetwork may contain into inference mode.
model.eval()
# Position i in the model output is reported as LABELS[i] by model_predict.
# TODO: update to grabbing labels stored on model
LABELS = ["shafqat", "aman", "jake"]
print(f"Model loaded! \n {model}")
app = FastAPI()
async def root():
    """Return a fixed greeting payload.

    NOTE(review): no route decorator is visible for this coroutine; presumably
    it is registered on ``app`` (e.g. ``@app.get(...)``) — confirm.
    """
    greeting = "Hello World"
    payload = {"message": greeting}
    return payload
def url_predict(url: str):
    """Classify the audio referenced by *url* and wrap it in the API response.

    Args:
        url: location handed to ``process_from_url``, which turns it into the
            model-ready input.

    Returns:
        dict with a fixed ``"message"`` and the ``model_predict`` result under
        ``"data"``.
    """
    prepared = process_from_url(url)
    response = {"message": "Voice Identified!"}
    response["data"] = model_predict(prepared)
    return response
def predict(wav):
    """Classify a raw audio payload and wrap it in the API response.

    Args:
        wav: raw audio passed straight to ``process_raw_wav``; its exact type
            and format are not visible here (TODO confirm what callers send).

    Returns:
        dict with a fixed ``"message"`` and the ``model_predict`` result under
        ``"data"``.
    """
    # Deliberately not echoing ``wav``: it is the full raw upload and would
    # flood stdout/logs on every request.
    processed = process_raw_wav(wav)
    model_prediction = model_predict(processed)
    return {
        "message": "Voice Identified!",
        "data": model_prediction,
    }
def model_predict(wav):
    """Run the module-level ``model`` on one preprocessed clip.

    Args:
        wav: a single preprocessed torch tensor (as produced by the
            ``process_*`` helpers); a batch dimension is added here with
            ``unsqueeze(0)``.

    Returns:
        dict with:
            ``prediction_index``: argmax over dim 1 of the model output.
            ``labels``: the full ``LABELS`` list.
            ``prediction_label``: ``LABELS[prediction_index]``.
            ``prediction_output``: the first output row as a Python list.
    """
    model_input = wav.unsqueeze(0)
    # Inference only: skip autograd so no computation graph is built and kept
    # alive for each request.
    with torch.no_grad():
        output = model(model_input)
    # output is reduced over dim 1 and indexed at [0], so it is presumably
    # shaped (1, num_classes) — TODO confirm against CNNetwork.
    prediction_index = torch.argmax(output, 1).item()
    scores = output.cpu().numpy()[0]
    return {
        "prediction_index": prediction_index,
        "labels": LABELS,
        "prediction_label": LABELS[prediction_index],
        "prediction_output": scores.tolist(),
    }