# Captured from a Hugging Face Space page (Space status at capture time: Runtime error).
# File size: 1,618 bytes.
from io import BytesIO
from typing import Any, Type

import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from torchvision import transforms
class TorchTensor(torch.Tensor):
    """Thin subclass of ``torch.Tensor`` that adds no attributes or methods.

    In this file it is referenced as the type of ``Prediction.prediction``.
    """
class Prediction:
    """Plain class declaring a single ``prediction`` field of type ``TorchTensor``.

    The field is annotation-only: no default value is assigned and this is not
    a dataclass, so instances do not get the attribute automatically.
    NOTE(review): the visible ``/predict`` endpoint returns a plain dict and does
    not use this class — confirm whether it is still needed.
    """

    prediction: TorchTensor  # annotation only
app = FastAPI()

# Load the PyTorch model once at import time so every request reuses it.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host and
# keeps the model on the same device as the CPU tensors built per request.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# NOTE(review): a ".ckpt" file (the name resembles PyTorch Lightning's checkpoint
# naming) may contain a dict with a "state_dict" entry instead of a pickled
# nn.Module; if so, instantiate the model class and load that state dict — confirm.
model = torch.load("best_model-epoch=01-val_loss=3.00.ckpt", map_location="cpu")
if isinstance(model, nn.Module):
    # Inference only: disable dropout and make BatchNorm use its running statistics.
    model.eval()
# Per-channel normalization statistics (the standard ImageNet values).
# NOTE(review): presumably the model was trained with these — confirm.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def preprocess_input(input):
    """Preprocess an image for the PyTorch image classification model.

    Args:
        input: A PIL Image object (any mode; it is converted to RGB).

    Returns:
        A float tensor of shape (3, 224, 224): pixel values scaled to [0, 1]
        and then normalized per channel with the ImageNet mean/std.
    """
    # Force 3 channels so grayscale / RGBA / palette images match the
    # 3-element mean/std below. PIL's resize takes (width, height).
    image = input.convert("RGB").resize((224, 224))
    # ToTensor turns an HWC uint8 image in [0, 255] into a CHW float tensor in
    # [0, 1] — the layout and range Normalize expects.
    tensor = transforms.ToTensor()(image)
    return transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD)(tensor)
@app.post("/predict")
async def predict_endpoint(input: UploadFile = File(...)):
    """Classify an uploaded image with the PyTorch model.

    Args:
        input: The image file, sent as a multipart/form-data field named
            ``input``. (FastAPI requires the ``python-multipart`` package to
            parse file uploads.)

    Returns:
        A JSON object ``{"prediction": <int>}`` containing the argmax index of
        the model output along dim 1.

    Raises:
        HTTPException: 400 if the uploaded bytes cannot be decoded as an image.
    """
    data = await input.read()
    try:
        image = Image.open(BytesIO(data))
        # Image.open is lazy; decode now so corrupt or truncated uploads are
        # reported as a client error rather than failing during preprocessing.
        image.load()
    except OSError as exc:  # PIL.UnidentifiedImageError subclasses OSError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    tensor = preprocess_input(image)
    # Inference only: no autograd graph is needed. unsqueeze(0) adds a batch
    # dimension of size 1.
    with torch.no_grad():
        output = model(tensor.unsqueeze(0))
    # NOTE(review): assumes output is (1, num_classes) scores, so the argmax
    # along dim 1 yields a single class index — confirm against the model.
    predicted_class = output.argmax(dim=1)
    return {"prediction": predicted_class.item()}
if __name__ == "__main__":
    # Script entry point: serve the app directly with uvicorn on all interfaces.
    # uvicorn is imported here so it is only required when the file is run as a
    # script, not when the module is imported by another server process.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)