import os
import logging
import time
import traceback

import gradio as gr
import torch

from utils.inference_utils import preprocess_image, predict
from utils.train_utils import initialize_model
from utils.data import CLASS_NAMES

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('pokemon_classifier.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


def setup_model():
    """
    Initialize and load the model with comprehensive error handling.

    Returns:
        tuple: (torch.nn.Module, torch.device) -- the loaded model in eval mode
        and the device it was placed on
    """
    try:
        # Configure model parameters
        model_name = "resnet"
        model_weights = "./pokemon_resnet.pth"
        num_classes = 150
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Log device information
        logger.info(f"Using device: {device}")

        # Validate that the model weights file exists
        if not os.path.exists(model_weights):
            raise FileNotFoundError(f"Model weights file not found: {model_weights}")

        # Initialize and load model
        start_time = time.time()
        model = initialize_model(model_name, num_classes).to(device)
        model.load_state_dict(torch.load(model_weights, map_location=device))
        model.eval()  # Set the model to evaluation mode

        logger.info(f"Model initialization completed in {time.time() - start_time:.2f} seconds")
        return model, device
    except Exception as e:
        logger.error(f"Model initialization failed: {e}")
        logger.error(traceback.format_exc())
        raise


def classify_image(image, model, device):
    """
    Classify an uploaded image with comprehensive error handling and logging.

    Args:
        image (PIL.Image): Uploaded image
        model (torch.nn.Module): Loaded model
        device (torch.device): Computation device

    Returns:
        str: Prediction result or error message
    """
    if image is None:
        return "No image uploaded"

    try:
        start_time = time.time()

        # Preprocess image
        logger.info('Preprocessing image...')
        image_tensor = preprocess_image(image, (224, 224)).to(device)

        # Perform inference
        logger.info('Running inference...')
        with torch.no_grad():  # Disable gradient computation for inference
            preds = torch.max(predict(model, image_tensor), 1)[1]
        predicted_class = CLASS_NAMES[preds.item()]

        # Log performance metrics
        inference_time = time.time() - start_time
        logger.info(f"Image classification completed in {inference_time:.4f} seconds")
        logger.info(f"Predicted class: {predicted_class}")

        return f"Predicted class: {predicted_class}"
    except Exception as e:
        logger.error(f"Classification error: {e}")
        logger.error(traceback.format_exc())
        return f"Error processing image: {str(e)}"


def create_gradio_app():
    """
    Create and configure the Gradio interface.

    Returns:
        gr.Interface: Configured Gradio interface
    """
    try:
        # Initialize model once
        model, device = setup_model()

        # Create a wrapper function that closes over the model and device
        def classify_wrapper(image):
            return classify_image(image, model, device)

        demo = gr.Interface(
            fn=classify_wrapper,
            inputs=gr.components.Image(type="pil", label="Upload Pokemon Image"),
            outputs=gr.components.Textbox(label="Prediction"),
            title="Pokemon Classifier",
            description="Upload an image of a Pokemon, and the model will predict its class.",
            allow_flagging="never"  # Disable flagging to simplify the UI
        )
        return demo
    except Exception as e:
        logger.critical(f"Failed to create Gradio app: {e}")
        logger.critical(traceback.format_exc())
        raise


def main():
    try:
        demo = create_gradio_app()
        demo.launch(
            server_name="0.0.0.0",  # Important for Docker
            server_port=7860,       # Standard Hugging Face Spaces port
            share=False
        )
    except Exception as e:
        logger.critical(f"Application launch failed: {e}")
        logger.critical(traceback.format_exc())


if __name__ == "__main__":
    main()
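
# Example usage (a sketch; the file name "app.py" is an assumption, and the
# weights file "pokemon_resnet.pth" must be present in the working directory
# as configured in setup_model()):
#
#   python app.py
#   # then open http://localhost:7860 in a browser
#   # (or the port mapped to 7860 when running inside Docker)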