"""FastAPI service exposing a PyTorch image-classification model at POST /predict."""

from io import BytesIO

import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
from torchvision import transforms

MODEL_PATH = "best_model-epoch=01-val_loss=3.00.ckpt"


class Prediction(BaseModel):
    """Response body returned by ``POST /predict``."""

    # Index of the highest-scoring class. A plain int because pydantic cannot
    # validate or serialize a torch.Tensor field.
    prediction: int


app = FastAPI()


def load_model(path: str = MODEL_PATH) -> nn.Module:
    """Load the serialized model from *path* and put it in inference mode.

    Args:
        path: Location of the saved model file.

    Returns:
        The loaded ``nn.Module`` in eval mode.

    Raises:
        TypeError: If the file does not contain a full ``nn.Module``.
    """
    # map_location="cpu": request tensors are created on the CPU below, so the
    # weights must live there too. This also lets a checkpoint saved on a GPU
    # machine load on a CPU-only host.
    # NOTE(review): torch >= 2.6 defaults to weights_only=True, which refuses
    # fully pickled modules; pass weights_only=False only for a trusted file.
    loaded = torch.load(path, map_location="cpu")
    if not isinstance(loaded, nn.Module):
        # A ".ckpt" file may hold a dict (e.g. {"state_dict": ...}) rather than
        # a callable model; fail at startup instead of on the first request.
        raise TypeError(
            f"{path!r} does not contain an nn.Module (got {type(loaded).__name__}); "
            "rebuild the model class and load its state_dict instead."
        )
    # Disable dropout / use running batch-norm statistics for inference.
    loaded.eval()
    return loaded


model = load_model()

# ImageNet channel statistics, as in the original code. ToTensor converts a
# PIL image (HWC, 0-255) into a float CHW tensor in [0, 1], which is the
# layout and range Normalize expects.
_TO_NORMALIZED_TENSOR = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def preprocess_input(input, size=(224, 224)):
    """Preprocess an image for the PyTorch image classification model.

    Args:
        input: A PIL Image object. (The name shadows the builtin but is kept
            for compatibility with keyword callers.)
        size: Target (width, height) passed to ``PIL.Image.resize``.

    Returns:
        A float tensor of shape (3, height, width), normalized per channel.
    """
    # Force 3 channels so grayscale, palette and RGBA uploads match the
    # 3-value normalization statistics.
    image = input.convert("RGB").resize(size)
    return _TO_NORMALIZED_TENSOR(image)


def _predict_class(image):
    """Run the model on one PIL image and return the arg-max class index."""
    batch = preprocess_input(image).unsqueeze(0)  # add a batch dimension of 1
    # no_grad: inference only, so skip building the autograd graph.
    with torch.no_grad():
        output = model(batch)
    # Assumes the model returns scores shaped (batch, num_classes), as the
    # original argmax(1) implied -- TODO confirm against the model definition.
    return int(output.argmax(dim=1).item())


@app.post("/predict", response_model=Prediction)
async def predict_endpoint(input: UploadFile = File(...)):
    """Predict the output of the PyTorch image classification model.

    Args:
        input: Uploaded image file (multipart form field ``input``).

    Returns:
        A ``Prediction`` with the index of the top predicted class.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    data = await input.read()
    try:
        image = Image.open(BytesIO(data))
        # Image.open is lazy; load() forces decoding so truncated or corrupt
        # files are caught here rather than during preprocessing.
        image.load()
    except (UnidentifiedImageError, OSError) as err:
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # Inference is CPU-bound; run it in a worker thread so it does not block
    # the event loop for other requests.
    predicted_class = await run_in_threadpool(_predict_class, image)
    return Prediction(prediction=predicted_class)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)