Spaces: Runtime error
from io import BytesIO
from typing import Any, Type

import fastapi
import numpy as np
import torch
from PIL import Image
class TorchTensor(torch.Tensor):
    """Marker subclass of ``torch.Tensor`` used only as a type annotation.

    It adds no attributes or behaviour beyond ``torch.Tensor``; it serves as
    the annotated type of ``Prediction.prediction``.
    """
class Prediction:
    """Type-only description of a model prediction.

    The class declares a single annotated attribute and defines neither an
    ``__init__`` nor a default value, so instances carry no ``prediction``
    attribute unless one is assigned.  The endpoint below does not use it.
    """

    # Annotation only: the model output tensor; nothing is assigned here.
    prediction: TorchTensor
app = fastapi.FastAPI()

# Load the model once at import time so every request reuses the same object.
# NOTE(review): the endpoint calls ``model(input)``, so model67.bin is assumed
# to hold a whole pickled ``torch.nn.Module`` (``torch.save(model, ...)``), not
# a bare state_dict -- confirm against how the file was produced.
# ``weights_only=False`` is needed to unpickle a full module on PyTorch >= 2.6,
# where ``torch.load`` defaults to ``weights_only=True`` (the keyword exists
# since PyTorch 1.13).  Unpickling can execute arbitrary code: only ever point
# this at a trusted file shipped with the app, never at user-supplied data.
model = torch.load("model67.bin", map_location="cpu", weights_only=False)
# Put the module in evaluation mode (affects layers such as dropout and
# batch norm) since it is only ever used for inference here.
model.eval()
# Define a function to preprocess the input image
def preprocess_input(input: Any, size=(224, 224)) -> torch.Tensor:
    """Decode encoded image bytes into a batched float tensor for the model.

    Args:
        input: Encoded image file contents (anything ``BytesIO`` accepts,
            e.g. ``bytes``) in a format Pillow can open.
        size: Target ``(width, height)`` passed to ``Image.resize``; defaults
            to the previously hard-coded 224x224.

    Returns:
        A float32 tensor of shape ``(1, 3, height, width)`` holding raw 0-255
        pixel values (no normalization is applied here).

    Raises:
        PIL.UnidentifiedImageError: if the bytes are not a recognizable image.
    """
    # The context manager closes the decoder and buffer deterministically.
    with Image.open(BytesIO(input)) as image:
        # Force 3-channel RGB. Grayscale ("L") images would otherwise decode to
        # a 2-D array and break the channel permute below, and RGBA/palette
        # images would feed the model the wrong number of channels.
        rgb_image = image.convert("RGB").resize(size)
    pixels = np.array(rgb_image)  # (height, width, 3), uint8, writable copy
    tensor = torch.from_numpy(pixels).float()
    # HWC -> CHW, then add a leading batch dimension of 1.
    return tensor.permute(2, 0, 1).unsqueeze(0)
# Define an endpoint to make predictions
# NOTE(review): "/predict" is chosen here because the original registered no
# route at all; adjust it to whatever path clients expect.  ``fastapi.File``
# requires the ``python-multipart`` package to be installed.
@app.post("/predict")
async def predict_endpoint(input: fastapi.UploadFile = fastapi.File(...)):
    """Classify an uploaded image and return the predicted class index.

    Over HTTP, expects a ``multipart/form-data`` request whose ``input`` field
    holds the image file.  For backward compatibility, direct callers may also
    pass the encoded image as ``bytes``/``bytearray``.

    Returns:
        ``{"prediction": <int>}`` -- the argmax over dimension 1 of the model
        output for the single-image batch.

    Raises:
        fastapi.HTTPException: 400 if the upload cannot be decoded as an image.
    """
    if isinstance(input, (bytes, bytearray)):
        data = input
    else:
        data = await input.read()
    try:
        batch = preprocess_input(data)
    except OSError as err:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # unrecognized data and OSError for truncated files: a client error,
        # so report 400 instead of an opaque 500.
        raise fastapi.HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err
    # Inference only: skip autograd graph construction to save time and memory.
    with torch.no_grad():
        logits = model(batch)
    # Get the predicted class
    predicted_class = logits.argmax(1).item()
    # Return the predicted class in JSON format
    return {"prediction": predicted_class}
if __name__ == "__main__":
    # Imported lazily: the ASGI server is only needed when run as a script.
    from uvicorn import run as serve

    # Listen on all interfaces, port 8000.
    serve(app, host="0.0.0.0", port=8000)