import torch
import fastapi
import numpy as np
from PIL import Image
from transformers import ViTForImageClassification


class TorchTensor(torch.Tensor):
    pass


class Prediction:
    prediction: TorchTensor


app = fastapi.FastAPI(docs_url="/")

# Define the number of classes in your custom dataset
num_classes = 20
# Initialize the ViTForImageClassification model
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes  # Specify the number of classes
)
class_names = [
    "Acral Lick Dermatitis",
    "Acute moist dermatitis",
    "Canine atopic dermatitis",
    "Cherry Eye",
    "Ear infections",
    "External Parasites",
    "Folliculitis",
    "Healthy",
    "Leishmaniasis",
    "Lupus",
    "Nuclear sclerosis",
    "Otitis externa",
    "Pruritus",
    "Pyoderma",
    "Rabies",
    "Ringworm",
    "Sarcoptic Mange",
    "Sebaceous adenitis",
    "Seborrhea",
    "Skin tumor",
]
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()  # disable dropout so inference is deterministic
# Define a function to preprocess the input image
def preprocess_input(input: fastapi.UploadFile):
    # Load the uploaded file, force RGB, and resize to the 224x224 ViT input size
    image = Image.open(input.file)
    image = image.convert("RGB").resize((224, 224))
    # HWC uint8 array -> CHW float tensor with a leading batch dimension
    array = np.transpose(np.array(image), (2, 0, 1))
    tensor = torch.from_numpy(array).float().unsqueeze(0)
    return tensor
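# Note: this manual preprocessing skips the mean/std normalization the ViT
# checkpoint was trained with. A sketch of an alternative, assuming the
# transformers image processor for the same checkpoint is available:
#
#   from transformers import ViTImageProcessor
#   processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
#   inputs = processor(images=image, return_tensors="pt")
#   pixel_values = inputs["pixel_values"]  # normalized (1, 3, 224, 224) tensor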
# Define an endpoint to make predictions
@app.post("/predict")
async def predict_endpoint(input:fastapi.UploadFile):
"""Make a prediction on an image uploaded by the user."""
# Preprocess the input image
input = preprocess_input(input)
# Make a prediction
prediction = model(input)
logits = prediction.logits
num_top_predictions = 3
top_predictions = torch.topk(logits, k=num_top_predictions, dim=1)
top_indices = top_predictions.indices.squeeze().tolist()
top_probabilities = torch.softmax(top_predictions.values, dim=1).squeeze().tolist()
# Return the top N class indices and their probabilities in JSON format
response_data = [{"class_index": class_names[idx], "probability": prob} for idx, prob in zip(top_indices, top_probabilities)]
return {"predictions": response_data}