# app.py
"""FastAPI service exposing one image-classification endpoint.

At import time the module loads a serialized PyTorch model from
``MODEL_PATH``.  ``POST /predict`` accepts an uploaded image, resizes and
normalizes it, runs the model, and returns the index of the highest-scoring
class.
"""
from io import BytesIO
from typing import Any, Type

import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException
from PIL import Image
from torchvision import transforms

# Path of the serialized model loaded at import time.
MODEL_PATH = "best_model-epoch=01-val_loss=3.00.ckpt"

# Spatial size (width, height) every input image is resized to.
INPUT_SIZE = (224, 224)

# Per-channel statistics applied by ``transforms.Normalize``.  They are meant
# for RGB values already scaled to [0, 1], which ``ToTensor`` provides.
CHANNEL_MEAN = [0.485, 0.456, 0.406]
CHANNEL_STD = [0.229, 0.224, 0.225]

# Built once so the transform objects are not re-created on every request.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(CHANNEL_MEAN, CHANNEL_STD)


class TorchTensor(torch.Tensor):
    """Named ``torch.Tensor`` subclass used as an annotation in this module."""


class Prediction:
    """Container type declaring a single ``prediction`` tensor attribute."""

    prediction: TorchTensor


app = FastAPI()

# Load the PyTorch model.
# ``map_location="cpu"`` keeps the weights on the CPU: request tensors are
# created on the CPU, and a checkpoint saved on a GPU would otherwise fail to
# load on a host without one.
# NOTE(review): ``torch.load`` unpickles the file, so only load checkpoints
# from a trusted source.
# NOTE(review): the ``.ckpt`` name suggests a training checkpoint, which may be
# a dict holding a ``state_dict`` rather than a callable module.  If so, the
# model class must be instantiated and the weights loaded into it -- confirm.
model = torch.load(MODEL_PATH, map_location="cpu")
if isinstance(model, nn.Module):
    # Inference mode for layers such as dropout and batch norm.
    model.eval()


def preprocess_input(input: Image.Image) -> torch.Tensor:
    """Convert a PIL image into a normalized ``(3, 224, 224)`` float tensor.

    Args:
        input: Image as returned by ``PIL.Image.open``.  Any mode is accepted;
            it is converted to RGB so the three-channel statistics apply.

    Returns:
        A channels-first float tensor, scaled to [0, 1] and then normalized
        with ``CHANNEL_MEAN`` / ``CHANNEL_STD``.
    """
    # Convert before resizing so palette/greyscale images are interpolated
    # in RGB space and always yield exactly three channels.
    image = input.convert("RGB").resize(INPUT_SIZE)
    # ``ToTensor`` turns the HWC uint8 image into a CHW float tensor in
    # [0, 1], which is the layout and range ``Normalize`` expects.
    return _normalize(_to_tensor(image))


@app.post("/predict")
async def predict_endpoint(input: bytes = File(...)):
    """Classify an uploaded image.

    Args:
        input: Raw bytes of the image, sent as the multipart form field
            ``input``.  FastAPI needs the ``python-multipart`` package
            installed to parse file uploads.

    Returns:
        ``{"prediction": <int>}`` with the index of the highest-scoring class.

    Raises:
        HTTPException: 400 if the payload cannot be decoded as an image.
    """
    try:
        image = Image.open(BytesIO(input))
        # Decoding is lazy, so preprocessing stays inside the ``try`` to
        # catch truncated or corrupt files as well.
        batch = preprocess_input(image).unsqueeze(0)
    except OSError as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from err

    # No gradients are needed for inference; skipping them saves memory.
    with torch.no_grad():
        prediction = model(batch)
    predicted_class = prediction.argmax(1)
    return {"prediction": predicted_class.item()}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)