Spaces:
Runtime error
Runtime error
import torch | |
import fastapi | |
import numpy as np | |
from PIL import Image | |
app = fastapi.FastAPI(docs_url="/")

# Number of output classes of the fine-tuned classifier head.
# NOTE(review): `num_classes` was not defined anywhere in the visible code. 20 matches
# the length of the class-name list used by predict_endpoint. It must also match the
# classifier head stored in best_model.pth, otherwise load_state_dict raises.
num_classes = 20

# Load the ViT backbone with a classifier head sized for our classes, then overwrite
# its weights with the fine-tuned checkpoint (loaded onto CPU).
# NOTE(review): ViTForImageClassification is not imported in the visible code --
# presumably `from transformers import ViTForImageClassification` is missing above.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes,
)
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
# This service only runs inference: make sure dropout-style layers are disabled.
model.eval()
def preprocess_input(input: fastapi.UploadFile):
    """Decode an uploaded image into a batched tensor for the model.

    The image is converted to RGB and resized to 224x224. It is returned as a
    float32 tensor of shape (1, 3, 224, 224) holding raw pixel values in [0, 255].

    Args:
        input: The uploaded image file; only its ``file`` object is read.

    Returns:
        A ``torch.Tensor`` of shape (1, 3, 224, 224), dtype float32.

    Raises:
        fastapi.HTTPException: status 400 if the upload cannot be decoded as an image.
    """
    try:
        with Image.open(input.file) as image:
            # Convert before resizing: Pillow always resamples "1"/"P" (palette)
            # images with nearest-neighbour, so resizing first degrades them.
            rgb_image = image.convert("RGB").resize((224, 224))
    except OSError as err:  # covers PIL.UnidentifiedImageError and truncated files
        raise fastapi.HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # HWC uint8 array -> CHW layout expected by the model.
    pixels = np.transpose(np.array(rgb_image), (2, 0, 1))
    # NOTE(review): values are left in [0, 255]. Hugging Face's ViT image processor
    # normally rescales and normalises inputs. This is only correct if best_model.pth
    # was fine-tuned on identically prepared tensors -- confirm.
    return torch.from_numpy(pixels).float().unsqueeze(0)  # add batch dimension
# Human-readable labels, indexed by the model's output class id.
CLASS_NAMES = [
    "Acral Lick Dermatitis",
    "Acute moist dermatitis",
    "Canine atopic dermatitis",
    "Cherry Eye",
    "Ear infections",
    "External Parasites",
    "Folliculitis",
    "Healthy",
    "Leishmaniasis",
    "Lupus",
    "Nuclear sclerosis",
    "Otitis externa",
    "Pruritus",
    "Pyoderma",
    "Rabies",
    "Ringworm",
    "Sarcoptic Mange",
    "Sebaceous adenitis",
    "Seborrhea",
    "Skin tumor",
]


async def predict_endpoint(input: fastapi.UploadFile):
    """Make a prediction on an image uploaded by the user.

    Args:
        input: The uploaded image file (name kept: FastAPI uses it as the form field).

    Returns:
        ``{"predictions": [{"class_name": str, "probability": float}, ...]}`` with up
        to three entries, ordered from most to least likely. Each probability is the
        model's softmax probability over all classes.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
    # function, so FastAPI will not expose it. Also, the model call below is blocking
    # CPU work inside an ``async def``, which stalls the event loop while it runs.
    input_data = preprocess_input(input)

    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        logits = model(input_data).logits

    # Softmax over *all* classes, then take the top k. Applying softmax only to the
    # top-k logits would renormalise them to sum to 1 and overstate confidence.
    probabilities = torch.softmax(logits, dim=-1)
    num_top_predictions = min(3, probabilities.shape[-1])
    top_predictions = torch.topk(probabilities, k=num_top_predictions, dim=-1)

    # Take the single batch row explicitly; squeeze() would collapse to a scalar if k=1.
    top_indices = top_predictions.indices[0].tolist()
    top_probabilities = top_predictions.values[0].tolist()

    response_data = [
        {"class_name": CLASS_NAMES[idx], "probability": prob}
        for idx, prob in zip(top_indices, top_probabilities)
    ]
    return {"predictions": response_data}