File size: 4,568 Bytes
7c3ed50
 
 
 
 
37ad7bc
 
 
 
 
 
7c3ed50
 
 
 
 
 
 
 
 
 
37ad7bc
7c3ed50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37ad7bc
7c3ed50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37ad7bc
7c3ed50
 
 
 
37ad7bc
7c3ed50
37ad7bc
7c3ed50
 
 
 
37ad7bc
7c3ed50
 
 
 
 
 
37ad7bc
7c3ed50
 
 
 
 
37ad7bc
7c3ed50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37ad7bc
7c3ed50
 
 
37ad7bc
7c3ed50
 
 
 
 
 
 
 
 
 
 
37ad7bc
 
7c3ed50
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import logging
import time
import traceback

import gradio as gr
import torch
from utils.inference_utils import preprocess_image, predict
from utils.train_utils import initialize_model
from utils.data import CLASS_NAMES

# Logging setup: records at INFO level and above go to two handlers, a file
# named 'pokemon_classifier.log' and a stream handler.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('pokemon_classifier.log'),
        logging.StreamHandler(),
    ],
    level=logging.INFO,
)

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)

def setup_model(model_name="resnet", model_weights="./pokemon_resnet.pth", num_classes=150):
    """
    Build the classifier, load its saved weights, and put it in eval mode.

    The device is CUDA when ``torch.cuda.is_available()`` reports it,
    otherwise CPU. Calling with no arguments reproduces the previous
    hard-coded configuration.

    Args:
        model_name (str): Architecture identifier passed to ``initialize_model``.
        model_weights (str): Path to the saved state dict to load.
        num_classes (int): Class count passed to ``initialize_model``.

    Returns:
        tuple: ``(model, device)`` - the loaded model in evaluation mode and
        the ``torch.device`` it was moved to.

    Raises:
        FileNotFoundError: If ``model_weights`` is not an existing file.
        Exception: Any other failure is logged with its traceback and re-raised.
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Log device information
        logger.info(f"Using device: {device}")

        # Fail early with a clear message; isfile (rather than exists) also
        # rejects a directory that happens to sit at the weights path.
        if not os.path.isfile(model_weights):
            raise FileNotFoundError(f"Model weights file not found: {model_weights}")

        # Initialize and load model
        start_time = time.time()
        model = initialize_model(model_name, num_classes).to(device)
        # NOTE(review): torch.load is pickle-based; depending on the installed
        # torch version's `weights_only` default it may unpickle arbitrary
        # objects, so only point this at a trusted checkpoint.
        model.load_state_dict(torch.load(model_weights, map_location=device))
        model.eval()  # Set the model to evaluation mode

        logger.info(f"Model initialization completed in {time.time() - start_time:.2f} seconds")
        return model, device

    except Exception as e:
        logger.error(f"Model initialization failed: {e}")
        logger.error(traceback.format_exc())
        raise

def classify_image(image, model, device):
    """
    Run the classifier on one uploaded image and report the predicted class.

    Failures during preprocessing or inference are logged with their
    traceback and turned into an error string instead of being raised.

    Args:
        image (PIL.Image): Uploaded image, or None when nothing was uploaded
        model (torch.nn.Module): Loaded model
        device (torch.device): Device the input tensor is moved to

    Returns:
        str: "Predicted class: <name>", "No image uploaded" when ``image``
        is None, or "Error processing image: <message>" on failure
    """
    # Guard clause: nothing to classify.
    if image is None:
        return "No image uploaded"

    try:
        started_at = time.time()

        # Step 1: turn the image into a tensor on the target device.
        logger.info('Preprocessing image...')
        image_tensor = preprocess_image(image, (224, 224))
        image_tensor = image_tensor.to(device)

        # Step 2: forward pass without gradient tracking; keep only the
        # index of the highest score along dimension 1.
        logger.info('Running inference...')
        with torch.no_grad():
            outputs = predict(model, image_tensor)
            _, preds = torch.max(outputs, 1)

        # Step 3: map the winning index to its class name.
        class_index = preds.item()
        predicted_class = CLASS_NAMES[class_index]

        # Step 4: report timing and result.
        inference_time = time.time() - started_at
        logger.info(f"Image classification completed in {inference_time:.4f} seconds")
        logger.info(f"Predicted class: {predicted_class}")

        return f"Predicted class: {predicted_class}"

    except Exception as e:
        logger.error(f"Classification error: {e}")
        logger.error(traceback.format_exc())
        return f"Error processing image: {str(e)}"

def create_gradio_app():
    """
    Load the model and build the Gradio interface around it.

    Returns:
        gr.Interface: Interface with a PIL image input and a text output

    Raises:
        Exception: Any failure during model setup or interface construction
        is logged at CRITICAL level with its traceback and re-raised.
    """
    try:
        # Load the model a single time; the closure below reuses it per call.
        model, device = setup_model()

        def classify_wrapper(image):
            # Adapts classify_image to the one-argument callable the
            # interface is given as `fn`.
            return classify_image(image, model, device)

        image_input = gr.components.Image(type="pil", label="Upload Pokemon Image")
        prediction_output = gr.components.Textbox(label="Prediction")

        return gr.Interface(
            fn=classify_wrapper,
            inputs=image_input,
            outputs=prediction_output,
            title="Pokemon Classifier",
            description="Upload an image of a Pokemon, and the model will predict its class.",
            allow_flagging="never",  # Disable flagging to simplify UI
        )

    except Exception as e:
        logger.critical(f"Failed to create Gradio app: {e}")
        logger.critical(traceback.format_exc())
        raise

def main():
    """
    Build the Gradio app and launch it on 0.0.0.0:7860.

    Any exception raised while creating or launching the app is logged at
    CRITICAL level with its traceback, and the process then exits with a
    non-zero status so that a supervisor does not mistake the failure for a
    clean shutdown.

    Raises:
        SystemExit: With exit code 1 if the app cannot be created or launched.
    """
    try:
        demo = create_gradio_app()
        demo.launch(
            server_name="0.0.0.0",  # Important for Docker
            server_port=7860,        # Standard Hugging Face Spaces port
            share=False
        )
    except Exception as e:
        logger.critical(f"Application launch failed: {e}")
        logger.critical(traceback.format_exc())
        # Previously the error was swallowed here and the process exited
        # with status 0; report the failure through the exit code instead.
        raise SystemExit(1) from e
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()