"""FastAPI service exposing a PyTorch image classifier at ``POST /predict``."""

from io import BytesIO
from typing import Any, Type

import numpy as np
import pydantic
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
from torchvision import transforms

# Checkpoint deserialised at import time; the path is relative to the
# process working directory.
MODEL_PATH = "best_model-epoch=01-val_loss=3.00.ckpt"

# Spatial size every image is resized to before it is fed to the model.
_INPUT_SIZE = (224, 224)

# Per-channel mean/std applied to RGB pixel values that have already been
# scaled to [0, 1].
_NORMALIZE = transforms.Normalize(
    [0.485, 0.456, 0.406],
    [0.229, 0.224, 0.225],
)


class TensorSchema(pydantic.BaseModel):
    """JSON-friendly description of a tensor's metadata (not its values).

    ``BaseModel`` already derives the validation/JSON schema from the
    annotated fields below, so no custom ``__get_pydantic_core_schema__``
    hook is required.
    """

    # Size of each dimension.
    shape: list[int]
    # Textual dtype name.
    dtype: str
    # Whether autograd tracks the tensor.
    requires_grad: bool


class TorchTensor(torch.Tensor):
    """``torch.Tensor`` subclass that can be used as a Pydantic field type."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        source_type: Type[Any],
        handler: pydantic.GetCoreSchemaHandler,
    ) -> Any:
        """Reuse the schema of :class:`TensorSchema` for this type.

        Pydantic v2 calls this hook with ``(source_type, handler)``.

        Args:
            source_type: The annotation Pydantic is building a schema for.
            handler: Pydantic's schema-generation callback.

        Returns:
            The core schema Pydantic generates for ``TensorSchema``.
        """
        # NOTE(review): a field annotated with ``TorchTensor`` is therefore
        # validated as tensor *metadata* (shape/dtype/requires_grad), which
        # appears to be the original intent -- confirm.
        return handler.generate_schema(TensorSchema)


class Prediction(BaseModel):
    """Response body of ``/predict``."""

    # Predicted class index for each image in the batch. Stored as plain
    # ints because a raw ``torch.Tensor`` field can neither be given a
    # Pydantic schema nor be serialised to JSON.
    prediction: list[int]

    @pydantic.field_validator("prediction", mode="before")
    @classmethod
    def _tensor_to_list(cls, value: Any) -> Any:
        """Accept a tensor and convert it to a flat list of Python numbers."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().reshape(-1).tolist()
        return value


app = FastAPI()

# Load the PyTorch model.
# SECURITY: ``torch.load`` unpickles the file; only load checkpoints from a
# trusted source.
# ``map_location="cpu"`` because request tensors are created on the CPU and
# are never moved to another device.
# NOTE(review): the file name looks like a PyTorch Lightning checkpoint, which
# would deserialise to a dict (with a ``state_dict`` entry) rather than a
# callable module -- confirm, and rebuild the module if that is the case.
model = torch.load(MODEL_PATH, map_location="cpu")
if isinstance(model, nn.Module):
    # Inference mode for dropout / batch-norm layers.
    model.eval()


def preprocess_input(input: Image.Image) -> torch.Tensor:
    """Preprocess the input image for the PyTorch image classification model.

    Args:
        input: A PIL Image object.

    Returns:
        A normalized float32 tensor of shape ``(3, 224, 224)``.
    """
    # Force three channels so grayscale, palette and RGBA uploads work with
    # the three-channel normalization, then resize to the expected size.
    rgb = input.convert("RGB").resize(_INPUT_SIZE)
    # PIL yields height x width x channel bytes in 0..255; scale to [0, 1].
    pixels = np.asarray(rgb, dtype=np.float32) / 255.0
    # ``Normalize`` expects channel-first layout.
    chw = torch.from_numpy(pixels).permute(2, 0, 1)
    return _NORMALIZE(chw)


@app.post("/predict", response_model=Prediction)
async def predict_endpoint(input: UploadFile = File(...)):
    """Predict the output of the PyTorch image classification model.

    Args:
        input: A file containing the input image (multipart form field
            named ``input``).

    Returns:
        A JSON object containing the prediction.

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded as
            an image.
    """
    payload = await input.read()

    # ``Image.open`` is lazy, so decoding errors may only surface while the
    # image is converted/resized; guard both steps.
    try:
        image = Image.open(BytesIO(payload))
        tensor = preprocess_input(image)
    except (UnidentifiedImageError, OSError) as err:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from err

    # NOTE(review): the forward pass runs on the event loop and blocks other
    # requests while it executes.
    # Add the batch dimension; gradients are not needed for inference.
    with torch.no_grad():
        prediction = model(tensor.unsqueeze(0))

    # Index of the highest-scoring class for each batch item.
    predicted_class = prediction.argmax(1)

    # The ``Prediction`` validator converts the tensor to a list of ints.
    return Prediction(prediction=predicted_class)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)