Spaces:
Sleeping
Sleeping
File size: 4,568 Bytes
7c3ed50 37ad7bc 7c3ed50 37ad7bc 7c3ed50 37ad7bc 7c3ed50 37ad7bc 7c3ed50 37ad7bc 7c3ed50 37ad7bc 7c3ed50 37ad7bc 7c3ed50 37ad7bc 7c3ed50 37ad7bc 7c3ed50 37ad7bc 7c3ed50 37ad7bc 7c3ed50 37ad7bc 7c3ed50 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import os
import logging
import time
import traceback
import gradio as gr
import torch
from utils.inference_utils import preprocess_image, predict
from utils.train_utils import initialize_model
from utils.data import CLASS_NAMES
# Root logging setup: records at INFO and above go both to
# 'pokemon_classifier.log' (in the current working directory) and to the
# console, in that handler order.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[
        logging.FileHandler('pokemon_classifier.log'),
        logging.StreamHandler(),
    ],
)

# Logger used by every function in this module.
logger = logging.getLogger(__name__)
def setup_model(model_name="resnet", model_weights="./pokemon_resnet.pth", num_classes=150):
    """
    Build the classifier, load its trained weights and put it in eval mode.

    The defaults reproduce the original hard-coded configuration, so calling
    ``setup_model()`` with no arguments behaves exactly as before.

    Args:
        model_name (str): Architecture name passed to ``initialize_model``.
        model_weights (str): Path to the saved ``state_dict`` checkpoint.
        num_classes (int): Number of output classes to build the network
            with; it has to agree with the checkpoint for loading to succeed.

    Returns:
        tuple[torch.nn.Module, torch.device]: The model (moved to ``device``
        and switched to evaluation mode) and the device it was placed on.

    Raises:
        FileNotFoundError: If ``model_weights`` does not exist.
        Exception: Any failure while building or loading the model is logged
            together with its traceback and re-raised unchanged.
    """
    try:
        # Prefer the GPU when one is visible, otherwise fall back to the CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        # Fail early with a clear message instead of a deeper torch.load error.
        if not os.path.exists(model_weights):
            raise FileNotFoundError(f"Model weights file not found: {model_weights}")

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by wall-clock adjustments (unlike time.time()).
        start_time = time.perf_counter()
        model = initialize_model(model_name, num_classes).to(device)
        # map_location lets a checkpoint saved on one device load on another
        # (e.g. a GPU-saved checkpoint on a CPU-only host).
        # NOTE(review): torch.load unpickles the file. Because the file is
        # consumed as a plain state_dict, consider weights_only=True
        # (torch >= 1.13) so a tampered checkpoint cannot execute code.
        model.load_state_dict(torch.load(model_weights, map_location=device))
        model.eval()  # inference behaviour for layers such as dropout / batch-norm
        elapsed = time.perf_counter() - start_time
        logger.info(f"Model initialization completed in {elapsed:.2f} seconds")
        return model, device
    except Exception as e:
        # logger.exception logs at ERROR level and appends the traceback.
        logger.exception(f"Model initialization failed: {e}")
        raise
def classify_image(image, model, device):
    """
    Classify an uploaded image and return a human-readable result string.

    Args:
        image (PIL.Image.Image | None): Uploaded image; ``None`` when the user
            submitted without uploading anything.
        model (torch.nn.Module): Loaded model used for inference.
        device (torch.device): Device the model lives on; the preprocessed
            input tensor is moved there before inference.

    Returns:
        str: ``"Predicted class: <name>"`` on success, ``"No image uploaded"``
        when ``image`` is ``None``, or ``"Error processing image: <details>"``
        if any step fails. Errors are logged rather than raised, so the UI
        always receives a string.
    """
    if image is None:
        return "No image uploaded"
    try:
        # Monotonic clock: latency measurements are immune to wall-clock jumps.
        start_time = time.perf_counter()

        logger.info('Preprocessing image...')
        # (224, 224) is presumably the input size the network was trained
        # with — verify against the training pipeline.
        image_tensor = preprocess_image(image, (224, 224)).to(device)

        logger.info('Running inference...')
        with torch.no_grad():  # no autograd bookkeeping needed for inference
            outputs = predict(model, image_tensor)
            # Index of the highest score along dim 1 (assumed to be the class
            # axis). .item() requires exactly one element, i.e. one image.
            pred_index = torch.max(outputs, 1)[1].item()
        predicted_class = CLASS_NAMES[pred_index]

        inference_time = time.perf_counter() - start_time
        logger.info(f"Image classification completed in {inference_time:.4f} seconds")
        logger.info(f"Predicted class: {predicted_class}")
        return f"Predicted class: {predicted_class}"
    except Exception as e:
        # Logs at ERROR level with the traceback attached to the same record.
        logger.exception(f"Classification error: {e}")
        return f"Error processing image: {e}"
def create_gradio_app():
    """
    Build the Gradio UI around a model that is loaded exactly once.

    Returns:
        gr.Interface: An interface whose single image input is classified by
        ``classify_image`` using the model and device from ``setup_model``.

    Raises:
        Exception: Anything raised while loading the model or building the
            interface; it is logged at CRITICAL level and then re-raised.
    """
    try:
        # Load once here so every request reuses the same model instance.
        model, device = setup_model()

        def classify_wrapper(image):
            # Gradio only supplies the image; model and device are bound here.
            return classify_image(image, model, device)

        image_input = gr.components.Image(type="pil", label="Upload Pokemon Image")
        prediction_output = gr.components.Textbox(label="Prediction")
        return gr.Interface(
            fn=classify_wrapper,
            inputs=image_input,
            outputs=prediction_output,
            title="Pokemon Classifier",
            description="Upload an image of a Pokemon, and the model will predict its class.",
            allow_flagging="never",  # no flag button, keeping the UI minimal
        )
    except Exception as err:
        logger.critical(f"Failed to create Gradio app: {err}")
        logger.critical(traceback.format_exc())
        raise
def main(server_name="0.0.0.0", server_port=7860, share=False):
    """
    Build the app and start the Gradio server.

    The defaults reproduce the original hard-coded launch settings.

    Args:
        server_name (str): Address to bind; "0.0.0.0" listens on all
            interfaces, which is what a Docker container needs.
        server_port (int): Port to listen on (7860 is the standard Hugging
            Face Spaces port).
        share (bool): Passed through to ``demo.launch(share=...)``.

    Raises:
        SystemExit: With status 1 if the app cannot be created or launched,
            after the failure has been logged. Previously the error was
            swallowed and the process exited with status 0.
    """
    try:
        demo = create_gradio_app()
        demo.launch(
            server_name=server_name,
            server_port=server_port,
            share=share,
        )
    except Exception as exc:
        logger.critical(f"Application launch failed: {exc}", exc_info=True)
        # Exit non-zero so a container / Spaces supervisor sees the failure
        # instead of an apparently clean shutdown.
        raise SystemExit(1) from exc


if __name__ == "__main__":
    main()