# app.py by SANJAYV10 (commit 9cd2acf, "Update app.py"), 1.96 kB.
import torch
import fastapi
import numpy as np
from PIL import Image
class TorchTensor(torch.Tensor):
    """Thin ``torch.Tensor`` subclass; in this file it only appears as a type annotation."""
class Prediction:
    """Annotation-only record declaring a ``prediction`` tensor attribute.

    No constructor is defined, so instances do not receive a ``prediction``
    value automatically; the attribute exists only as a class annotation.
    """

    prediction: TorchTensor  # NOTE(review): presumably the model output; not used by the endpoint
# FastAPI application; the interactive API docs are served at the root path.
app = fastapi.FastAPI(
    docs_url="/",
)
from transformers import ViTForImageClassification
# Number of output classes in the custom dataset the head was fine-tuned on.
num_classes = 20

# Build the ViT backbone with a classification head sized for `num_classes`
# outputs, then replace its weights with the fine-tuned checkpoint
# (load_state_dict is strict by default, so every key must match).
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=num_classes,
)
_checkpoint = torch.load("best_model.pth", map_location="cpu")
model.load_state_dict(_checkpoint)
del _checkpoint  # drop the module-level reference to the loaded state dict
# Define a function to preprocess the input image
def preprocess_input(input: fastapi.UploadFile):
    """Decode an uploaded image into a batched tensor for the model.

    Args:
        input: The uploaded file; its ``.file`` stream must contain an image
            that PIL can decode.

    Returns:
        A float32 tensor of shape ``(1, 3, 224, 224)`` holding raw 0-255 RGB
        values.

    Raises:
        fastapi.HTTPException: 400 if the upload cannot be decoded as an image.
    """
    try:
        # The context manager releases PIL's decoder resources; it does not
        # close the caller-owned ``input.file`` stream.
        with Image.open(input.file) as image:
            # Convert first so resizing always works on 3-channel RGB data
            # (PIL resizes palette-mode images with nearest-neighbour only).
            rgb = image.convert("RGB").resize((224, 224))
    except OSError as err:  # includes PIL.UnidentifiedImageError / truncated files
        raise fastapi.HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # NOTE(review): no mean/std normalisation is applied here; confirm this
    # matches the preprocessing used when best_model.pth was trained.
    pixels = np.array(rgb)  # (224, 224, 3) uint8, writable copy
    tensor = torch.from_numpy(np.transpose(pixels, (2, 0, 1))).float()  # HWC -> CHW
    return tensor.unsqueeze(0)  # add batch dimension
# Define an endpoint to make predictions
@app.post("/predict")
async def predict_endpoint(input: fastapi.UploadFile):
    """Classify an uploaded image and return its three most likely classes.

    Args:
        input: Image uploaded as multipart form data. The parameter name is
            also the form-field name clients must use.

    Returns:
        ``{"predictions": [{"disease_name": str, "probability": float}, ...]}``,
        ordered from most to least likely.
    """
    pixel_values = preprocess_input(input)

    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        logits = model(pixel_values).logits

    # Softmax over *all* classes first, then take the top k. Renormalising only
    # the top-k logits (as before) inflated the reported probabilities.
    probabilities = torch.softmax(logits, dim=-1)[0]
    num_top_predictions = min(3, probabilities.shape[-1])
    top = torch.topk(probabilities, k=num_top_predictions)

    # Map class indices to names through the model config. The previous code
    # read an undefined ``disease_names`` list and failed on every request.
    # NOTE(review): unless the config was populated with real disease names,
    # these are transformers' placeholder labels ("LABEL_0", ...); set
    # model.config.id2label to the dataset's class names.
    id2label = model.config.id2label
    response_data = [
        {"disease_name": id2label.get(idx, str(idx)), "probability": prob}
        for idx, prob in zip(top.indices.tolist(), top.values.tolist())
    ]
    return {"predictions": response_data}