"""HTTP service that runs an uploaded image through a PyTorch model.

POST the raw bytes of an image file as the request body of ``/predict``; the
response is ``{"prediction": <int>}``, the argmax over the model's output.
"""

import os
from io import BytesIO
from typing import Any, Type

import fastapi
import numpy as np
import torch
from fastapi.concurrency import run_in_threadpool
from PIL import Image


class TorchTensor(torch.Tensor):
    """``torch.Tensor`` subclass; in this module it is used only as an annotation."""


class Prediction:
    """Annotation-only record; nothing in this module instantiates it."""

    prediction: TorchTensor


# (width, height) that every uploaded image is resized to before inference.
INPUT_SIZE = (224, 224)

app = fastapi.FastAPI()

# Path of the serialized model; override with the MODEL_PATH environment variable.
MODEL_PATH = os.environ.get("MODEL_PATH", "model67.bin")

# SECURITY NOTE(review): torch.load unpickles the file and can execute arbitrary
# code embedded in it. Only ever point MODEL_PATH at a trusted artifact.
# NOTE(review): on torch >= 2.6 torch.load defaults to weights_only=True, which
# refuses to unpickle a whole nn.Module; if this file is a full pickled model,
# those versions will need weights_only=False here — confirm.
model = torch.load(MODEL_PATH, map_location="cpu")
# Switch dropout / batch-norm layers to inference behavior; a module pickled in
# training mode would otherwise give stochastic, batch-dependent outputs.
if isinstance(model, torch.nn.Module):
    model.eval()


def preprocess_input(input: bytes) -> torch.Tensor:
    """Decode raw image bytes into a float tensor of shape ``(1, 3, 224, 224)``.

    Args:
        input: Encoded image file contents (any format Pillow can open).

    Returns:
        A float32 tensor in NCHW layout holding the unscaled 0-255 pixel values.

    Raises:
        OSError: if the bytes are not a recognizable or complete image
            (``PIL.UnidentifiedImageError`` is a subclass of ``OSError``).
    """
    # Force 3 channels: grayscale and palette images would otherwise give a 2-D
    # array that permute(2, 0, 1) rejects, and RGBA images a 4-channel one.
    image = Image.open(BytesIO(input)).convert("RGB")
    image = image.resize(INPUT_SIZE)
    # np.array copies into a writable buffer, as torch.from_numpy prefers.
    pixels = np.array(image, dtype=np.float32)  # (H, W, C)
    # NOTE(review): pixels are not scaled to [0, 1] or mean/std-normalized;
    # confirm this matches how the model was trained.
    return torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)


def _run_model(batch: torch.Tensor) -> int:
    """Run *batch* through the model and return the index of the highest score."""
    # Inference needs no autograd graph; skipping it saves memory and time.
    with torch.no_grad():
        scores = model(batch)
    # assumes the model returns (batch, num_classes) scores — TODO confirm
    return scores.argmax(1).item()


@app.post("/predict")
async def predict_endpoint(request: fastapi.Request):
    """Make a prediction on an image uploaded by the user.

    The request body must be the raw bytes of an image file. Responds with
    HTTP 400 when the body is empty or cannot be decoded as an image.
    """
    data = await request.body()
    if not data:
        raise fastapi.HTTPException(
            status_code=400, detail="Request body is empty; send image bytes."
        )
    # Decoding and inference are CPU-bound; run them in a worker thread so the
    # event loop keeps serving other requests meanwhile.
    try:
        batch = await run_in_threadpool(preprocess_input, data)
    except OSError as err:
        raise fastapi.HTTPException(
            status_code=400, detail="Could not decode the uploaded image."
        ) from err
    predicted_class = await run_in_threadpool(_run_model, batch)
    # Return the predicted class in JSON format
    return {"prediction": predicted_class}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)