import fastapi
import numpy as np
import torch
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import ViTForImageClassification


class TorchTensor(torch.Tensor):
    pass


class Prediction:
    prediction: TorchTensor


app = fastapi.FastAPI(docs_url="/")

# Human-readable names for the classifier's outputs. The list order is assumed
# to match the label indices used when best_model.pth was fine-tuned — TODO
# confirm against the training script.
class_names = [
    "Acral Lick Dermatitis",
    "Acute moist dermatitis",
    "Canine atopic dermatitis",
    "Cherry Eye",
    "Ear infections",
    "External Parasites",
    "Folliculitis",
    "Healthy",
    "Leishmaniasis",
    "Lupus",
    "Nuclear sclerosis",
    "Otitis externa",
    "Pruritus",
    "Pyoderma",
    "Rabies",
    "Ringworm",
    "Sarcoptic Mange",
    "Sebaceous adenitis",
    "Seborrhea",
    "Skin tumor",
]

# Number of classes in the custom dataset. Derived from class_names (20 entries)
# so the classification head and the label list cannot drift apart.
num_classes = len(class_names)

# Build the ViT classifier from the base checkpoint with a head sized for our
# classes, then replace its weights with the fine-tuned state dict.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=num_classes,
)
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Because this is a plain state dict, weights_only=True (torch >= 1.13) should
# work here and would be safer.
model.load_state_dict(torch.load("best_model.pth", map_location="cpu"))
# This module only serves inference, so make sure dropout is disabled.
model.eval()


def preprocess_input(input: fastapi.UploadFile) -> torch.Tensor:
    """Decode an uploaded image into a single-image batch tensor for the model.

    The image is converted to RGB and resized to 224x224. It is returned as a
    float32 tensor of shape (1, 3, 224, 224) whose values are the raw pixel
    intensities (0-255).

    Args:
        input: The uploaded file. ``input.file`` must hold bytes Pillow can decode.

    Returns:
        A float32 tensor of shape (1, 3, 224, 224).

    Raises:
        OSError: If the upload cannot be decoded as an image
            (``PIL.UnidentifiedImageError`` is a subclass).
    """
    with Image.open(input.file) as image:
        # Convert before resizing. Pillow always uses NEAREST resampling for
        # palette ("P") and bilevel ("1") images, so resizing first would
        # degrade them.
        rgb = image.convert("RGB").resize((224, 224))
    # NOTE(review): pixels are passed unscaled (0-255) and unnormalized. Hugging
    # Face's ViT preprocessing normally rescales to [0, 1] and normalizes
    # (presumably mean = std = 0.5 for this checkpoint). Confirm this matches
    # how best_model.pth was trained.
    pixels = np.asarray(rgb, dtype=np.float32)  # (224, 224, 3), channels last
    channels_first = np.ascontiguousarray(pixels.transpose(2, 0, 1))
    return torch.from_numpy(channels_first).unsqueeze(0)


def _predict_class_index(input_data: torch.Tensor) -> int:
    """Run the classifier on a preprocessed batch and return the top-1 class index."""
    # No autograd graph is needed for inference; skipping it saves memory and time.
    with torch.no_grad():
        logits = model(pixel_values=input_data).logits
    return int(torch.argmax(logits, dim=1).item())


@app.post("/predict")
async def predict_endpoint(input: fastapi.UploadFile):
    """Classify an image uploaded by the user.

    Returns:
        ``{"prediction": <int class index>, "label": <class name>}``, where
        ``label`` is ``class_names[prediction]``.

    Raises:
        fastapi.HTTPException: 400 if the upload cannot be decoded as an image.
    """
    # Decoding and the model forward pass are synchronous and CPU-bound. Run them
    # in the threadpool so they do not block the event loop for other requests.
    try:
        input_data = await run_in_threadpool(preprocess_input, input)
    except (OSError, Image.DecompressionBombError) as err:
        raise fastapi.HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err

    predicted_class = await run_in_threadpool(_predict_class_index, input_data)
    return {"prediction": predicted_class, "label": class_names[predicted_class]}