SANJAYV10's picture
Update app.py
0f3e8d6
raw
history blame
No virus
999 Bytes
import numpy as np
import torch
from fastapi import FastAPI, HTTPException, UploadFile
from PIL import Image, UnidentifiedImageError
class TorchTensor(torch.Tensor):
    """Marker subclass of ``torch.Tensor`` that adds no behavior of its own.

    Used as the type annotation of ``Prediction.prediction``.
    """
class Prediction:
    """Annotated holder for a prediction tensor.

    Only declares the ``prediction`` annotation: there is no ``__init__``
    and no default value, so instances do not get the attribute
    automatically.
    """

    # NOTE(review): presumably meant to hold the model output; it is not
    # instantiated or returned by the endpoint in this file.
    prediction: TorchTensor
# Serve the interactive Swagger UI at the root path. ``FastAPI`` is imported
# by name; the previous ``fastapi.FastAPI`` raised NameError because the
# ``fastapi`` module itself was never imported.
app = FastAPI(docs_url="/")
# torch.load unpickles arbitrary Python objects, so best_model.pth must be a
# trusted file. It is called directly as ``model(batch)`` below, i.e. it is
# presumably a whole pickled nn.Module rather than a state_dict.
# NOTE(review): torch>=2.6 defaults to weights_only=True, which refuses to
# unpickle a full module; pass weights_only=False there if loading fails.
model = torch.load("best_model.pth", map_location='cpu')
# An unpickled module keeps whatever train/eval mode it was saved in; force
# inference behavior for dropout / batch-norm layers before serving.
model.eval()
def preprocess_input(input: UploadFile):
    """Turn an uploaded image into a batched float tensor for ``model``.

    The upload is decoded with Pillow, converted to 3-channel RGB, resized
    to 224x224 and returned as a float32 tensor with a leading batch axis.

    Args:
        input: The uploaded file; only its ``.file`` stream is read. The
            name shadows the ``input`` builtin but is kept so keyword
            callers keep working.

    Returns:
        A ``torch.float32`` tensor of shape ``(1, 224, 224, 3)`` (batch,
        height, width, channel, as produced by ``np.array`` on an RGB PIL
        image) holding raw 0-255 pixel values.

    Raises:
        PIL.UnidentifiedImageError: if the upload is not a readable image.
    """
    # Close the image's file handle deterministically instead of leaving it
    # open until garbage collection.
    with Image.open(input.file) as img:
        # Normalize the mode first: RGBA, grayscale ("L") or palette ("P")
        # uploads would otherwise yield 4 channels or a 2-D array, making
        # the tensor shape depend on the uploaded file format.
        rgb = img.convert("RGB").resize((224, 224))
    # NOTE(review): layout stays channels-last and unscaled, exactly as
    # before. Many torchvision-style CNNs expect (N, C, H, W) input in
    # [0, 1] with mean/std normalization; confirm against how
    # best_model.pth was trained before changing this.
    # np.array (not np.asarray) copies into a writable buffer, which
    # torch.from_numpy requires to avoid a non-writable-array warning.
    pixels = np.array(rgb)
    return torch.from_numpy(pixels).float().unsqueeze(0)
# Define an endpoint to make predictions
@app.post("/predict")
async def predict_endpoint(input: UploadFile):
    """Classify an uploaded image and return the winning class index.

    Args:
        input: The multipart upload. FastAPI uses the parameter name as the
            form field name, so it must stay ``input`` even though it
            shadows the builtin.

    Returns:
        ``{"prediction": <int>}``: the index of the maximum model output
        along dimension 1 for the single-image batch.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    try:
        batch = preprocess_input(input)
    except UnidentifiedImageError as err:
        # A non-image upload is a client error, not a server crash (500).
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err
    # Inference only: skip autograd graph construction to save memory/time.
    # NOTE(review): this CPU-bound call runs on the event loop inside an
    # ``async def`` and blocks other requests while it executes; a plain
    # ``def`` endpoint would let FastAPI run it in its threadpool instead.
    with torch.no_grad():
        output = model(batch)
    predicted_class = output.argmax(1).item()
    return {"prediction": predicted_class}