Spaces:
Runtime error
Runtime error
File size: 2,550 Bytes
e143977 3ec29bf 0f3e8d6 3ec29bf 06ba20f c305981 bfbcab4 c305981 e9d559f c305981 9cd2acf 5ed70cd 9cd2acf 719a218 5ed70cd 3ec29bf bfbcab4 9cd2acf 719a218 e143977 367823f 719a218 60712d2 719a218 60712d2 719a218 5c503f2 bfbcab4 5c503f2 719a218 5c503f2 719a218 5c503f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import torch
import fastapi
import numpy as np
from PIL import Image
class TorchTensor(torch.Tensor):
    """Thin subclass of ``torch.Tensor`` used only as a type annotation.

    It defines no attributes or methods of its own.
    """
class Prediction:
    """Declares one annotated field, ``prediction``, typed as ``TorchTensor``.

    Only the annotation is declared: there is no default value and no
    constructor, so an instance has no ``prediction`` attribute until one is
    assigned.

    NOTE(review): this class does not appear to be used by the endpoints in
    this file; confirm before relying on or removing it.
    """

    prediction: TorchTensor
# The FastAPI application. The interactive API docs are served at the site
# root ("/") rather than FastAPI's default "/docs" path.
app = fastapi.FastAPI(
    docs_url="/",
)
from transformers import ViTForImageClassification

# Size of the fine-tuned classification head. This must equal len(class_names)
# (defined later in this file): prediction indices are looked up in that list.
num_classes = 20

# Build the ViT architecture from the ImageNet-21k base checkpoint with a new
# num_classes-way head; all weights are then overwritten by the fine-tuned
# state_dict loaded below.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes,  # Specify the number of classes
)

# Load the fine-tuned weights onto the CPU. weights_only=True restricts
# unpickling to tensors and plain containers (the default from PyTorch 2.6),
# which is all a state_dict needs and avoids executing arbitrary pickled code.
model.load_state_dict(
    torch.load('best_model.pth', map_location='cpu', weights_only=True)
)

# from_pretrained already returns the model in eval mode; calling it again
# makes the inference-only intent explicit and survives future refactors.
model.eval()
# Human-readable label for each output index of the model: class_names[i] is
# the name reported for logit i. The list length must match the model's
# num_classes.
# NOTE(review): the names are in case-sensitive alphabetical order, which
# suggests they mirror a sorted folder-per-class training layout — confirm
# against the training label mapping before reordering or editing.
class_names = [
    "Acral Lick Dermatitis",     # 0
    "Acute moist dermatitis",    # 1
    "Canine atopic dermatitis",  # 2
    "Cherry Eye",                # 3
    "Ear infections",            # 4
    "External Parasites",        # 5
    "Folliculitis",              # 6
    "Healthy",                   # 7
    "Leishmaniasis",             # 8
    "Lupus",                     # 9
    "Nuclear sclerosis",         # 10
    "Otitis externa",            # 11
    "Pruritus",                  # 12
    "Pyoderma",                  # 13
    "Rabies",                    # 14
    "Ringworm",                  # 15
    "Sarcoptic Mange",           # 16
    "Sebaceous adenitis",        # 17
    "Seborrhea",                 # 18
    "Skin tumor",                # 19
]
def preprocess_input(input: fastapi.UploadFile):
    """Decode an uploaded image into a batched tensor for the model.

    The image is converted to 3-channel RGB, resized to 224x224, and returned
    as a float32 tensor of shape (1, 3, 224, 224) holding raw 0-255 pixel
    values (channels first, with a leading batch dimension of 1).

    NOTE(review): no rescaling or mean/std normalisation is applied here. The
    base checkpoint's image processor presumably rescales to [0, 1] and
    normalises; if best_model.pth was fine-tuned on normalised inputs, these
    raw values are out of distribution. Confirm against the training pipeline.

    Args:
        input: The uploaded file; its ``file`` stream is decoded with Pillow.

    Returns:
        A float32 ``torch.Tensor`` of shape (1, 3, 224, 224).

    Raises:
        fastapi.HTTPException: 400 if the upload cannot be decoded as an image.
    """
    try:
        with Image.open(input.file) as image:
            # Convert before resizing: Pillow always resizes "P" (palette) and
            # "1" (bilevel) images with nearest-neighbour, so converting first
            # lets the resize use its normal smooth filter. convert() also
            # forces the pixel data to be loaded while the image is open.
            rgb_image = image.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # OSError covers unidentifiable and truncated image files; report bad
        # client input as 400 rather than letting it surface as a 500.
        raise fastapi.HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from err

    rgb_image = rgb_image.resize((224, 224))
    # (H, W, C) uint8 pixels -> (C, H, W) float32 -> (1, C, H, W).
    pixels = torch.from_numpy(np.asarray(rgb_image, dtype=np.float32))
    return pixels.permute(2, 0, 1).unsqueeze(0)
@app.post("/predict")
async def predict_endpoint(input: fastapi.UploadFile):
    """Make a prediction on an image uploaded by the user.

    Returns ``{"predictions": [...]}`` where each entry is a string of the form
    ``"<class name>: <probability>%"`` (two decimals), ordered from most to
    least likely. Probabilities come from a softmax over *all* class logits, so
    they are true class probabilities even when fewer than all classes are
    listed.
    """
    # Preprocess the input image
    input_data = preprocess_input(input)

    # Inference only: skip autograd bookkeeping for the forward pass.
    # NOTE(review): this call is synchronous inside an ``async def`` endpoint,
    # so it blocks the event loop while the model runs.
    with torch.no_grad():
        logits = model(input_data).logits

    # Softmax over every class first, then rank. Taking the softmax over only
    # the top-k logits would renormalise them and inflate the probabilities
    # whenever k < number of classes. [0] selects the single image in the batch.
    probabilities = torch.softmax(logits, dim=-1)[0]

    # Get the top N predictions, never asking topk for more than exist.
    num_top_predictions = min(20, probabilities.numel())
    top_predictions = torch.topk(probabilities, k=num_top_predictions)

    # Indexing (rather than .squeeze()) keeps these as lists even when k == 1.
    top_probabilities = top_predictions.values.tolist()
    top_indices = top_predictions.indices.tolist()

    return {
        "predictions": [
            f"{class_names[idx]}: {probability * 100:.2f}%"
            for probability, idx in zip(top_probabilities, top_indices)
        ]
    }
|