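# NOTE: the three lines below are residue from the repository web page
# (author avatar caption, commit message, commit hash), kept as comments
# so that the module parses as Python.
# AnishKumbhar's picture
# Update app.py
# d1f3905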
import torch
import fastapi
import numpy as np
from PIL import Image
class TorchTensor(torch.Tensor):
    """Marker subclass of ``torch.Tensor`` that adds no behaviour of its own.

    In the visible part of this file it is referenced only as the annotation
    of ``Prediction.prediction``.
    """
class Prediction:
    """Declares the shape of a prediction result: one tensor-valued field.

    Only the annotation is declared; no default value is assigned, and
    nothing in the visible part of this file instantiates the class.
    """

    # Tensor carrying the prediction (annotation only, no class attribute).
    prediction: TorchTensor
# ASGI application object; the route handler further down registers on it.
# docs_url="/" asks FastAPI to serve its interactive docs page at the root path.
app = fastapi.FastAPI(docs_url="/")
from transformers import ViTForImageClassification
# Number of output labels for the classifier; equals the number of entries
# listed in class_names (20).
num_classes = 20

# Build a ViT image classifier from the ImageNet-21k base checkpoint, with
# num_labels telling transformers how many classes the head should predict.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=num_classes,
)
# Human-readable label for each class index. The prediction endpoint looks a
# predicted index up in this list, so the order is significant
# (presumably it mirrors the label order used in training; TODO confirm).
class_names = [
    "Acral Lick Dermatitis",
    "Acute moist dermatitis",
    "Canine atopic dermatitis",
    "Cherry Eye",
    "Ear infections",
    "External Parasites",
    "Folliculitis",
    "Healthy",
    "Leishmaniasis",
    "Lupus",
    "Nuclear sclerosis",
    "Otitis externa",
    "Pruritus",
    "Pyoderma",
    "Rabies",
    "Ringworm",
    "Sarcoptic Mange",
    "Sebaceous adenitis",
    "Seborrhea",
    "Skin tumor",
]
# Replace the model's weights with the checkpoint stored in 'best_model.pth',
# mapping every tensor to CPU while loading.
# NOTE(review): torch.load unpickles the file, so the checkpoint must come from
# a trusted source; consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
# Define a function to preprocess the input image
def preprocess_input(input: fastapi.UploadFile):
    """Turn an uploaded image into a single-item batch tensor for the model.

    Args:
        input: Uploaded file whose ``.file`` attribute is a binary file
            object that Pillow can open. The name shadows the builtin but is
            kept so existing keyword callers continue to work.

    Returns:
        A float32 tensor of shape ``(1, 3, 224, 224)`` in channel-first
        layout, holding raw pixel values in the range 0-255.

    Raises:
        OSError: If Pillow cannot identify or decode the image
            (``PIL.UnidentifiedImageError`` is a subclass of ``OSError``).
    """
    # Convert to RGB *before* resizing: Pillow forces nearest-neighbour
    # resampling for palette ("P") and 1-bit images, so resizing first would
    # degrade GIF / indexed-PNG uploads.
    image = Image.open(input.file).convert("RGB").resize((224, 224))

    # np.array makes a writable (224, 224, 3) uint8 copy for torch.from_numpy.
    pixels = np.array(image)

    # HWC -> CHW, then cast to float32.
    # NOTE(review): pixel values are neither rescaled nor mean/std normalised
    # here; confirm this matches the preprocessing used when the checkpoint
    # was trained.
    tensor = torch.from_numpy(pixels).permute(2, 0, 1).float()

    # Add the leading batch dimension expected by the model.
    return tensor.unsqueeze(0)
# Define an endpoint to make predictions
@app.post("/predict")
async def predict_endpoint(input:fastapi.UploadFile):
    """Classify an uploaded image and return the most likely classes.

    Args:
        input: Image uploaded by the client. FastAPI derives the request
            field name from this parameter name, so it is kept as ``input``
            even though it shadows the builtin.

    Returns:
        ``{"predictions": [...]}`` with up to three items ordered from most
        to least likely. Each item has ``"class_index"`` (despite the key
        name, this is the class *name* taken from ``class_names``; the key is
        kept for existing clients) and ``"probability"``, the softmax
        probability of that class over all classes.

    Raises:
        fastapi.HTTPException: With status 400 when the upload cannot be
            opened or decoded as an image.
    """
    # A broken or non-image upload is a client error, not a server fault.
    try:
        batch = preprocess_input(input)
    except OSError as err:
        raise fastapi.HTTPException(
            status_code=400,
            detail="Uploaded file could not be read as an image.",
        ) from err

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(batch).logits
        # Softmax over ALL classes first. Normalising only the top-k logits
        # would make the reported values always sum to 1 and overstate the
        # model's confidence.
        probabilities = torch.softmax(logits, dim=1)

    # Never request more entries than there are classes.
    num_top_predictions = min(3, probabilities.shape[1])
    top_predictions = torch.topk(probabilities, k=num_top_predictions, dim=1)

    # Index batch row 0 explicitly so the result is always a list, which a
    # bare squeeze() would not guarantee when k == 1.
    top_indices = top_predictions.indices[0].tolist()
    top_probabilities = top_predictions.values[0].tolist()

    # Return the top N class names and their probabilities in JSON format
    response_data = [
        {"class_index": class_names[idx], "probability": prob}
        for idx, prob in zip(top_indices, top_probabilities)
    ]
    return {"predictions": response_data}