File size: 2,223 Bytes
4ecc856
fbb5992
4ecc856
 
8b7cda3
4ecc856
 
8b7cda3
4ecc856
 
fbb5992
ff48a9b
8da9e7a
 
 
 
 
 
 
 
 
 
 
14ba05a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a1929d
f9f1904
4ecc856
d1f3905
 
302c39e
14735ff
673a00c
14735ff
 
4ecc856
29c50fd
4ecc856
29c50fd
d1f3905
4ecc856
 
 
 
8b7cda3
fbb5992
4ecc856
29c50fd
7c36a49
f12310d
157e4f0
 
 
 
f12310d
157e4f0
14ba05a
157e4f0
29c50fd
157e4f0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import fastapi
import numpy as np
from PIL import Image

class TorchTensor(torch.Tensor):
    """Marker subclass of ``torch.Tensor``; it adds no attributes or behaviour."""

class Prediction:
    """Annotation-only record type with a single ``prediction`` field.

    No constructor or default value is defined; the field exists only as a
    class-level type annotation.
    """

    # Annotation only (no value is assigned); presumably meant to hold a model
    # output tensor — NOTE(review): this class is not used in the visible code.
    prediction: TorchTensor

app = fastapi.FastAPI(docs_url="/")
from transformers import ViTForImageClassification

# Human-readable label for each output index of the classifier head: logit i
# is reported as class_names[i] by the /predict endpoint.
# NOTE(review): this order presumably matches the label indices used when
# best_model.pth was trained — confirm against the training code.
class_names = [
    "Acral Lick Dermatitis",
    "Acute moist dermatitis",
    "Canine atopic dermatitis",
    "Cherry Eye",
    "Ear infections",
    "External Parasites",
    "Folliculitis",
    "Healthy",
    "Leishmaniasis",
    "Lupus",
    "Nuclear sclerosis",
    "Otitis externa",
    "Pruritus",
    "Pyoderma",
    "Rabies",
    "Ringworm",
    "Sarcoptic Mange",
    "Sebaceous adenitis",
    "Seborrhea",
    "Skin tumor"
]

# Number of classes in the custom dataset. Derived from ``class_names`` so the
# size of the classification head and the label list can never drift apart.
num_classes = len(class_names)

# Build the ViT backbone with a classification head sized for our labels.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes,
)

# Overwrite the weights with the fine-tuned checkpoint, loaded onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints (``weights_only=True`` is safer on torch versions that support it).
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
# The server only runs inference. from_pretrained already returns the model in
# eval mode; this call makes that requirement explicit (disables dropout).
model.eval()
def preprocess_input(input: fastapi.UploadFile):
    """Turn an uploaded image file into a single-image batch tensor.

    The image is resized to 224x224, converted to RGB, and reordered to
    channels-first with a leading batch axis. The result is a float32 tensor
    of shape ``(1, 3, 224, 224)`` whose pixel values stay in the raw 0-255 range.

    NOTE(review): no rescaling or mean/std normalisation is applied here.
    Confirm this matches the preprocessing used when ``best_model.pth`` was
    trained.
    """
    rgb_image = Image.open(input.file).resize((224, 224)).convert("RGB")
    # numpy yields height x width x channels; the model expects channels first.
    channels_first = np.array(rgb_image).transpose(2, 0, 1)
    return torch.from_numpy(channels_first).float()[None, ...]

# Define an endpoint to make predictions
@app.post("/predict")
async def predict_endpoint(input: fastapi.UploadFile):
    """Make a prediction on an image uploaded by the user.

    Returns the three most likely classes with their probabilities.
    \f
    Args:
        input: The uploaded image file (any format PIL can decode).

    Returns:
        ``{"predictions": [...]}``, where each entry is
        ``{"class_index": <class name>, "probability": <float>}``, ordered
        from most to least likely. Each probability comes from a softmax over
        all classes. Despite its key name, ``class_index`` holds the
        human-readable name from ``class_names``; the key is kept for
        compatibility with existing clients.

    Raises:
        fastapi.HTTPException: 400 if the upload cannot be decoded as an image.
    """
    # NOTE(review): inference is synchronous CPU work inside an ``async``
    # handler and blocks the event loop while it runs; consider a plain
    # ``def`` handler or ``run_in_threadpool`` if request concurrency matters.
    try:
        pixel_values = preprocess_input(input)
    except OSError as exc:
        # PIL raises UnidentifiedImageError (an OSError subclass) for non-image
        # or corrupt uploads. That is a client error, not a server error.
        raise fastapi.HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        logits = model(pixel_values).logits

    # Softmax over *all* classes before selecting the top-k. Applying softmax
    # to only the top-k logits would renormalise them to sum to 1 and overstate
    # every reported probability.
    probabilities = torch.softmax(logits, dim=-1)[0]  # single image in batch
    num_top_predictions = min(3, probabilities.numel())
    top = torch.topk(probabilities, k=num_top_predictions)

    response_data = [
        {"class_index": class_names[idx], "probability": prob}
        for idx, prob in zip(top.indices.tolist(), top.values.tolist())
    ]
    return {"predictions": response_data}