"""FastAPI service that classifies uploaded images with a fine-tuned ViT model.

``POST /predict`` accepts an image upload and returns the highest-scoring
classes together with their softmax probabilities.
"""
from typing import Any, Dict, List, Sequence

import fastapi
import numpy as np
import torch
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import ViTForImageClassification


class TorchTensor(torch.Tensor):
    """Plain ``torch.Tensor`` subclass (not referenced elsewhere in this module)."""


class Prediction:
    """Annotation-only container (not referenced elsewhere in this module)."""

    prediction: TorchTensor


app = fastapi.FastAPI(docs_url="/")

# Number of classes in the custom dataset the checkpoint was fine-tuned on.
num_classes = 20

# How many of the highest-scoring classes /predict returns.
num_top_predictions = 3

# Build the ViT architecture with a classification head sized for
# `num_classes`, then overwrite its weights with the fine-tuned checkpoint.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=num_classes,
)
# NOTE(review): consider torch.load(..., weights_only=True) (torch >= 1.13) so
# loading the checkpoint cannot execute arbitrary pickled code.
model.load_state_dict(torch.load("best_model.pth", map_location="cpu"))
# Switch to evaluation mode so train-only behaviour (e.g. dropout) is disabled.
model.eval()

# Human-readable class names, indexed by the model's output label id.
# NOTE(review): no list of disease names is defined in this file, so this
# falls back to the labels stored in the model config (generic "LABEL_<i>"
# names unless the config was customized). Replace with the real disease
# names, in training label order.
disease_names = [
    model.config.id2label.get(i, f"LABEL_{i}") for i in range(num_classes)
]


def preprocess_input(input: fastapi.UploadFile) -> torch.Tensor:
    """Decode an uploaded image into a model-ready tensor.

    Args:
        input: The uploaded image file.

    Returns:
        A float32 tensor of shape ``(1, 3, 224, 224)`` holding raw RGB pixel
        values in ``[0, 255]``.

    Raises:
        fastapi.HTTPException: 400 if the upload cannot be decoded as an image.
    """
    try:
        with Image.open(input.file) as image:
            # Convert before resizing: Pillow always uses NEAREST resampling
            # for palette ("P") and bilevel ("1") images, so resizing those
            # before conversion degrades quality.
            rgb = image.convert("RGB").resize((224, 224))
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        raise fastapi.HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # NOTE(review): pixels are fed unnormalized, in [0, 255]. The Hugging Face
    # image processor for this base checkpoint rescales to [0, 1] and
    # normalizes (mean=std=0.5) — confirm this matches how best_model.pth was
    # trained.
    pixels = np.asarray(rgb, dtype=np.float32)  # (H, W, C)
    tensor = torch.from_numpy(pixels).permute(2, 0, 1)  # (C, H, W)
    return tensor.unsqueeze(0)  # (1, C, H, W)


def _top_k_predictions(
    logits: torch.Tensor, names: Sequence[str], k: int
) -> List[Dict[str, Any]]:
    """Convert a ``(1, num_classes)`` logit tensor into top-``k`` records.

    Softmax runs over *all* classes before selecting the top ``k``, so each
    reported probability is a true class probability rather than a score
    renormalized among the ``k`` winners. ``k`` is clamped to the number of
    classes.
    """
    probabilities = torch.softmax(logits, dim=-1)[0]
    top = torch.topk(probabilities, k=min(k, probabilities.numel()))
    return [
        {"disease_name": names[idx], "probability": prob}
        for idx, prob in zip(top.indices.tolist(), top.values.tolist())
    ]


def _classify(upload: fastapi.UploadFile) -> List[Dict[str, Any]]:
    """Preprocess an upload and run the model on it (blocking, CPU-bound)."""
    pixels = preprocess_input(upload)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(pixels).logits
    return _top_k_predictions(logits, disease_names, num_top_predictions)


@app.post("/predict")
async def predict_endpoint(input: fastapi.UploadFile):
    """Make a prediction on an image uploaded by the user.

    Returns the top predicted disease names and their probabilities.
    """
    # `input` shadows the builtin, but FastAPI uses the parameter name as the
    # multipart form field name, so renaming it would break clients.
    # Decoding and the forward pass block, so run them in a worker thread to
    # keep the event loop free for other requests.
    predictions = await run_in_threadpool(_classify, input)
    return {"predictions": predictions}