"""FastAPI service exposing a PyTorch image classifier at ``POST /predict``."""

from io import BytesIO
from typing import Any, Type

import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from torchvision import transforms


class TorchTensor(torch.Tensor):
    """Named ``torch.Tensor`` subclass used as a field type in ``Prediction``."""


class Prediction:
    """Container describing a model prediction."""

    # Raw prediction tensor.
    prediction: TorchTensor


app = FastAPI()

# Load the PyTorch model.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints that come from a trusted source.
# map_location="cpu" keeps the model on the same device as the input tensors
# built in preprocess_input, even if the checkpoint was saved from a GPU.
# NOTE(review): a ".ckpt" file is often a Lightning checkpoint dict rather
# than a pickled nn.Module - confirm that this object is actually callable.
model = torch.load("best_model-epoch=01-val_loss=3.00.ckpt", map_location="cpu")
if isinstance(model, nn.Module):
    # Inference mode: freezes dropout and batch-norm statistics.
    model.eval()

# Spatial size the image is resized to before being fed to the model.
_INPUT_SIZE = (224, 224)

# Per-channel mean/std (the standard ImageNet statistics) applied to RGB
# values that have already been scaled to [0, 1].
_NORMALIZE = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)


def preprocess_input(input):
    """Preprocess the input image for the PyTorch image classification model.

    Args:
        input: A PIL Image object.

    Returns:
        A float tensor of shape (3, 224, 224) holding the normalized image.
    """
    # Force three channels: grayscale, palette and RGBA images would otherwise
    # not match the 3-channel mean/std used for normalization.
    image = input.convert("RGB").resize(_INPUT_SIZE)

    # PIL yields HWC uint8 in [0, 255]; Normalize expects CHW floats in [0, 1].
    pixels = torch.from_numpy(np.array(image)).permute(2, 0, 1).float().div(255.0)

    return _NORMALIZE(pixels)


@app.post("/predict")
async def predict_endpoint(input: UploadFile = File(...)):
    """Predict the output of the PyTorch image classification model.

    Args:
        input: A file containing the input image, sent as multipart form data
            under the field name ``input`` (FastAPI needs ``python-multipart``
            installed to parse uploads).

    Returns:
        A JSON object containing the prediction.

    Raises:
        HTTPException: With status 400 if the upload cannot be decoded as an
            image.
    """
    payload = await input.read()

    # Image.open is lazy, so force a full decode here to surface corrupt or
    # truncated uploads as a client error instead of a 500 further down.
    try:
        image = Image.open(BytesIO(payload))
        image.load()
    except OSError as err:
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from err

    # Add the batch dimension the model expects.
    batch = preprocess_input(image).unsqueeze(0)

    # No gradients are needed for inference; skipping them saves memory.
    with torch.no_grad():
        prediction = model(batch)

    # Get the top predicted class.
    predicted_class = prediction.argmax(dim=1)

    return {"prediction": predicted_class.item()}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)