SANJAYV10's picture
Update app.py
719a218
raw
history blame
2.23 kB
import torch
import fastapi
import numpy as np
from PIL import Image
app = fastapi.FastAPI(docs_url="/")

# Number of labels in the fine-tuned classification head. This must match the
# head stored in best_model.pth (load_state_dict below is strict) and the
# number of class names the /predict endpoint maps indices to (20 entries).
num_classes = 20

# Build the ViT architecture with a classification head of num_classes
# outputs, then overwrite its weights with the fine-tuned checkpoint.
# NOTE(review): ViTForImageClassification is not imported anywhere in this
# file; presumably it should come from `transformers` -- confirm and add the
# import, otherwise this line raises NameError at startup.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes,  # Specify the number of classes
)
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
# Put the model in inference mode so train-only layers such as dropout
# behave deterministically while serving predictions.
model.eval()
# Define a function to preprocess the input image
def preprocess_input(input: fastapi.UploadFile):
    """Decode an uploaded image into a batched float tensor for the model.

    The upload is decoded with PIL, converted to 3-channel RGB, resized to
    224x224 and rearranged from height-width-channel to channel-first order.

    Args:
        input: The uploaded file; only its ``file`` stream is read. The name
            shadows the builtin but is kept because FastAPI uses the
            parameter name as the form-field name.

    Returns:
        A float32 tensor of shape ``(1, 3, 224, 224)`` holding raw pixel
        values in the range 0-255.
    """
    # Close the decoded image once its pixels have been copied out.
    with Image.open(input.file) as image:
        # Convert to RGB *before* resizing: Pillow always resizes palette
        # ("P") and 1-bit images with nearest-neighbour sampling, and
        # converting first guarantees exactly three channels for the
        # transpose below.
        rgb_image = image.convert("RGB").resize((224, 224))
        pixels = np.array(rgb_image)  # (224, 224, 3) uint8, writable copy
    # NOTE(review): Hugging Face ViT image processors usually rescale and
    # normalize pixels; this feeds raw 0-255 values. Confirm that matches how
    # best_model.pth was trained.
    channels_first = np.transpose(pixels, (2, 0, 1))
    batch = torch.from_numpy(channels_first).float().unsqueeze(0)
    return batch
# Human-readable name for each output index of the classifier head, in index
# order. Its length must equal the model's number of labels.
CLASS_NAMES = (
    "Acral Lick Dermatitis",
    "Acute moist dermatitis",
    "Canine atopic dermatitis",
    "Cherry Eye",
    "Ear infections",
    "External Parasites",
    "Folliculitis",
    "Healthy",
    "Leishmaniasis",
    "Lupus",
    "Nuclear sclerosis",
    "Otitis externa",
    "Pruritus",
    "Pyoderma",
    "Rabies",
    "Ringworm",
    "Sarcoptic Mange",
    "Sebaceous adenitis",
    "Seborrhea",
    "Skin tumor",
)


@app.post("/predict")
async def predict_endpoint(input: fastapi.UploadFile):
    """Make a prediction on an image uploaded by the user.

    Args:
        input: The uploaded image file (the multipart form field ``input``).

    Returns:
        ``{"predictions": [{"class_name": str, "probability": float}, ...]}``
        for the top three classes, most likely first. Each probability is
        that class's share of the softmax over *all* classes, so the listed
        values need not sum to 1.
    """
    # NOTE(review): model inference is CPU-bound and runs synchronously inside
    # this async handler, blocking the event loop; consider a plain `def`
    # handler or a threadpool.
    input_data = preprocess_input(input)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(input_data).logits
    # Softmax over every class first, then take the top-k. Applying softmax to
    # only the top-k logits would renormalize them among themselves and
    # overstate the model's confidence.
    probabilities = torch.softmax(logits, dim=1)
    num_top_predictions = min(3, probabilities.shape[1])
    top_predictions = torch.topk(probabilities, k=num_top_predictions, dim=1)
    # Index row 0 (the single image in the batch) instead of squeeze(), which
    # would collapse to a bare scalar when k == 1.
    top_indices = top_predictions.indices[0].tolist()
    top_probabilities = top_predictions.values[0].tolist()
    # Return the top N class names and their probabilities in JSON format
    response_data = [
        {"class_name": CLASS_NAMES[idx], "probability": prob}
        for idx, prob in zip(top_indices, top_probabilities)
    ]
    return {"predictions": response_data}