Spaces:
Runtime error
Runtime error
# app.py
from io import BytesIO
from typing import Any, Type

import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException, UploadFile
from PIL import Image
from torchvision import transforms
class TorchTensor(torch.Tensor):
    """Thin ``torch.Tensor`` subclass that adds no behavior of its own.

    In the visible code it is referenced only as a type annotation.
    """
class Prediction:
    """Declares a single ``prediction`` attribute annotated as ``TorchTensor``.

    NOTE(review): this is a plain class with an annotation only -- no
    ``__init__`` is generated and nothing in the visible code instantiates it.
    The endpoint returns an ``int`` under the ``"prediction"`` key, not a
    tensor; confirm whether this class is still needed.
    """

    prediction: TorchTensor
app = FastAPI()

# Serialized model file, resolved relative to the process working directory.
MODEL_PATH = "best_model-epoch=01-val_loss=3.00.ckpt"

# Load the PyTorch model once, at import time.
# map_location="cpu" lets a checkpoint saved on a GPU machine load on a CPU-only host.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
# On PyTorch >= 2.6 the default weights_only=True rejects a pickled nn.Module;
# confirm what this file contains and pass weights_only explicitly if needed.
model = torch.load(MODEL_PATH, map_location="cpu")
if not isinstance(model, nn.Module):
    # A Lightning-style ".ckpt" is typically a dict holding a "state_dict"; calling it
    # like a model would only fail later, per request, so fail fast with a clear message.
    raise TypeError(
        f"{MODEL_PATH!r} did not deserialize to a torch.nn.Module "
        f"(got {type(model).__name__}); load its state_dict into the model class instead."
    )
# Inference mode for dropout / batch-norm layers.
model.eval()
# ImageNet per-channel statistics (the values the original Normalize call used).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def preprocess_input(input, size=(224, 224), mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Preprocess an image for the PyTorch image classification model.

    The image is converted to 3-channel RGB, resized, scaled from ``[0, 255]``
    to ``[0, 1]``, laid out channel-first and normalized per channel.

    Args:
        input: A PIL Image object (any object with ``convert``/``resize`` whose
            ``numpy.array`` form is ``(H, W, 3)`` also works).
        size: Target ``(width, height)`` passed to ``resize``.
        mean: Per-channel mean subtracted after scaling to ``[0, 1]``.
        std: Per-channel standard deviation divided out after centering.

    Returns:
        A float32 tensor of shape ``(3, height, width)``.
    """
    # Force RGB so grayscale, RGBA and palette images all yield exactly 3 channels.
    image = input.convert("RGB").resize(size)
    # np.array gives (H, W, 3); per-channel normalization (as in
    # torchvision's Normalize) needs channel-first (C, H, W), and the ImageNet
    # statistics assume pixel values in [0, 1], not [0, 255].
    pixels = torch.from_numpy(np.array(image, dtype=np.float32))
    tensor = pixels.permute(2, 0, 1) / 255.0
    mean_t = torch.tensor(mean, dtype=tensor.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=tensor.dtype).view(-1, 1, 1)
    return (tensor - mean_t) / std_t
@app.post("/predict")
async def predict_endpoint(input: UploadFile):
    """Predict the class of an uploaded image with the PyTorch model.

    Args:
        input: The uploaded image file (multipart form field ``input``).
            Raw ``bytes`` are also accepted when the function is called directly.

    Returns:
        A JSON object ``{"prediction": <int class index>}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    # NOTE: FastAPI requires the python-multipart package to accept file uploads.
    data = input if isinstance(input, (bytes, bytearray)) else await input.read()
    try:
        image = Image.open(BytesIO(data))
        # Image.open is lazy; force decoding so corrupt/truncated files fail here.
        image.load()
    except OSError as err:  # PIL's UnidentifiedImageError is an OSError subclass.
        raise HTTPException(
            status_code=400,
            detail="Could not decode the uploaded file as an image.",
        ) from err
    batch = preprocess_input(image).unsqueeze(0)
    # Inference only: skip autograd bookkeeping.
    # NOTE(review): this runs on the event-loop thread; for heavy models consider
    # offloading to a threadpool.
    with torch.no_grad():
        logits = model(batch)
    predicted_class = logits.argmax(1)
    return {"prediction": predicted_class.item()}
if __name__ == "__main__":
    import os

    import uvicorn

    # Port defaults to 8000 as before; the PORT environment variable overrides it.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "8000")))