File size: 1,979 Bytes
e143977
3ec29bf
0f3e8d6
3ec29bf
06ba20f
c305981
 
 
 
 
 
bfbcab4
c305981
 
 
 
e9d559f
c305981
9cd2acf
 
 
 
5ed70cd
 
9cd2acf
719a218
5ed70cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ec29bf
 
bfbcab4
9cd2acf
719a218
 
 
 
 
e143977
367823f
719a218
60712d2
 
 
719a218
60712d2
 
719a218
5c503f2
bfbcab4
5c503f2
c1a7e91
5c503f2
 
c1a7e91
5c503f2
c1a7e91
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import fastapi
import numpy as np
import torch
from fastapi.concurrency import run_in_threadpool
from PIL import Image

class TorchTensor(torch.Tensor):
    """Alias subclass of ``torch.Tensor`` used as a type annotation.

    It defines no attributes or behaviour of its own; everything is inherited
    from ``torch.Tensor``.
    """

class Prediction:
    """Declared container for a model prediction.

    Only a class-level annotation is declared; there is no ``__init__`` and
    no default, so instances do not receive the attribute automatically.
    """

    # Annotation only: records the intended type, assigns no value.
    prediction: TorchTensor

app = fastapi.FastAPI(docs_url="/")
from transformers import ViTForImageClassification

# Human-readable names for the model's output indices. The order must match
# the label indices used when best_model.pth was trained.
class_names = [
    "Acral Lick Dermatitis",
    "Acute moist dermatitis",
    "Canine atopic dermatitis",
    "Cherry Eye",
    "Ear infections",
    "External Parasites",
    "Folliculitis",
    "Healthy",
    "Leishmaniasis",
    "Lupus",
    "Nuclear sclerosis",
    "Otitis externa",
    "Pruritus",
    "Pyoderma",
    "Rabies",
    "Ringworm",
    "Sarcoptic Mange",
    "Sebaceous adenitis",
    "Seborrhea",
    "Skin tumor",
]

# Number of output classes, derived from the label list so the head size and
# the names can never drift apart (20 for the list above).
num_classes = len(class_names)

# Build the ViT architecture with a classification head sized for our labels.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes,
)

# Load the fine-tuned weights. weights_only=True restricts unpickling to
# tensors and plain containers, so a tampered checkpoint cannot execute code.
model.load_state_dict(
    torch.load('best_model.pth', map_location='cpu', weights_only=True)
)

# This process only serves inference: make sure dropout is disabled.
model.eval()

# Convert an uploaded image into the tensor batch the ViT model consumes.
def preprocess_input(
    input: fastapi.UploadFile,
    *,
    image_mean=(0.5, 0.5, 0.5),
    image_std=(0.5, 0.5, 0.5),
):
    """Decode an uploaded image into a normalized (1, 3, 224, 224) batch.

    Processing steps:

    1. Convert the image to RGB.
    2. Resize it to 224x224.
    3. Scale pixel values to [0, 1].
    4. Normalize per channel as ``(x - image_mean) / image_std``.

    The defaults match the image-processor settings published for
    ``google/vit-base-patch16-224-in21k`` (rescale by 1/255, mean = std = 0.5),
    so outputs lie in [-1, 1].

    NOTE(review): the fine-tuned weights only behave correctly if these
    statistics match the transform used during training -- confirm against
    the training script.

    Args:
        input: Uploaded file; only its ``.file`` stream is read.
        image_mean: Per-channel (R, G, B) mean subtracted after scaling.
        image_std: Per-channel (R, G, B) divisor applied after centering.

    Returns:
        A float32 tensor of shape (1, 3, 224, 224).

    Raises:
        PIL.UnidentifiedImageError: If the stream is not a readable image.
    """
    with Image.open(input.file) as raw:
        # Convert before resizing: Pillow resizes palette ("P") and 1-bit
        # images with nearest-neighbour, which degrades the model input.
        image = raw.convert("RGB").resize((224, 224))

    pixels = np.asarray(image, dtype=np.float32) / 255.0  # (H, W, C) in [0, 1]
    mean = np.asarray(image_mean, dtype=np.float32)
    std = np.asarray(image_std, dtype=np.float32)
    pixels = (pixels - mean) / std  # broadcasts over the channel axis

    # HWC -> CHW, then add the batch dimension expected by the model.
    chw = np.ascontiguousarray(pixels.transpose(2, 0, 1))
    return torch.from_numpy(chw).unsqueeze(0)

def _classify_upload(upload: fastapi.UploadFile) -> int:
    """Decode *upload*, run the model on it, and return the top-1 class index.

    Runs synchronously; the endpoint dispatches it to a worker thread so the
    CPU-bound decode and forward pass do not block the event loop.

    Raises:
        fastapi.HTTPException: 400 if the upload cannot be decoded as an image.
    """
    try:
        input_data = preprocess_input(upload)
    except (OSError, Image.DecompressionBombError) as exc:
        # PIL.UnidentifiedImageError and truncated-image errors subclass
        # OSError; either way the client sent something we cannot decode.
        raise fastapi.HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from exc

    # inference_mode is thread-local, so it is entered on the thread that
    # runs the forward pass; it skips autograd graph construction.
    with torch.inference_mode():
        logits = model(input_data).logits
    return int(torch.argmax(logits, dim=1).item())


@app.post("/predict")
async def predict_endpoint(input: fastapi.UploadFile):
    """Classify an image uploaded by the user.

    The parameter name ``input`` is also the multipart form-field name
    clients must use, so it is kept as-is.

    Returns:
        ``{"prediction": <top-1 class index>, "label": <class name>}``.

    Raises:
        fastapi.HTTPException: 400 if the upload is not a decodable image.
    """
    # Offload blocking work so other requests keep being served meanwhile.
    predicted_class = await run_in_threadpool(_classify_upload, input)
    return {"prediction": predicted_class, "label": class_names[predicted_class]}