"""Gradio web app that classifies uploaded Pokemon images with a trained model."""

import logging

import gradio as gr
import torch
from utils.inference_utils import preprocess_image, predict
from utils.train_utils import initialize_model
from utils.data import CLASS_NAMES

logger = logging.getLogger(__name__)

# Load the model once during app initialization so every request reuses it.
model_name = "resnet"
model_weights = "./pokemon_resnet.pth"
# Must match the output size the checkpoint was trained with.
# NOTE(review): presumably equals len(CLASS_NAMES) — confirm.
num_classes = 150
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the architecture, then load the trained weights onto the chosen device.
model = initialize_model(model_name, num_classes).to(device)
model.load_state_dict(torch.load(model_weights, map_location=device))
model.eval()  # Set the model to evaluation mode


def classify_image(image):
    """Preprocess an uploaded image and return its predicted class as display text.

    Args:
        image: The uploaded image. The input component below is configured with
            ``type="pil"``, so this is a PIL image, or ``None`` when the user
            submits without uploading anything.

    Returns:
        ``"Predicted class: <name>"`` on success, otherwise an ``"Error: ..."``
        message suitable for showing in the UI.
    """
    if image is None:
        return "Error: please upload an image."
    try:
        image_tensor = preprocess_image(image, (224, 224)).to(device)
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            scores = predict(model, image_tensor)
        # assumes scores is (batch=1, num_classes) — TODO confirm against predict()
        class_index = scores.argmax(dim=1).item()
        predicted_class = CLASS_NAMES[class_index]
    except Exception as e:
        # UI boundary: keep the app alive, but log the full traceback for the
        # operator instead of silently discarding it.
        logger.exception("Image classification failed")
        return f"Error: {e}"
    return f"Predicted class: {predicted_class}"


# Create a Gradio interface. ``gr.inputs.*`` was deprecated in Gradio 3 and
# removed in Gradio 4; ``gr.Image`` is the supported input component.
demo = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs="text",
    title="Pokemon Classifier",
    description="Upload an image of a Pokemon, and the model will predict its class.",
)

if __name__ == "__main__":
    # Launch the Gradio app
    demo.launch()