Spaces:
Sleeping
Sleeping
import os | |
import logging | |
import time | |
import traceback | |
import gradio as gr | |
import torch | |
from utils.inference_utils import preprocess_image, predict | |
from utils.train_utils import initialize_model | |
from utils.data import CLASS_NAMES | |
# Configure logging: every record goes both to a local log file and to the
# console (StreamHandler defaults to sys.stderr). Note that basicConfig is a
# no-op if the root logger already has handlers configured.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8 so non-ASCII text in messages (exception text, class
        # names, paths) is written the same way regardless of the platform's
        # locale encoding, instead of triggering a logging encode error.
        logging.FileHandler('pokemon_classifier.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
def setup_model(
    model_name="resnet",
    model_weights="./pokemon_resnet.pth",
    num_classes=150,
):
    """
    Build the classifier, load its trained weights, and switch it to eval mode.

    The defaults reproduce the deployed Pokemon classifier; they are exposed as
    keyword parameters so a different checkpoint or architecture can be loaded
    without editing this function.

    Args:
        model_name (str): Architecture name forwarded to ``initialize_model``.
        model_weights (str): Path to a saved ``state_dict`` checkpoint. A
            relative path is resolved against the current working directory.
        num_classes (int): Number of output classes; must match the checkpoint
            or ``load_state_dict`` will fail.

    Returns:
        tuple[torch.nn.Module, torch.device]: The loaded model (already moved to
        the device and put in eval mode) and the device it was placed on.

    Raises:
        FileNotFoundError: If ``model_weights`` does not exist.
        Exception: Any other failure while building or loading the model is
            logged with its traceback and re-raised unchanged.
    """
    try:
        # Prefer the GPU when CUDA is available; otherwise run on CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        # Fail early with a clear message rather than a lower-level torch error.
        if not os.path.exists(model_weights):
            raise FileNotFoundError(f"Model weights file not found: {model_weights}")

        start_time = time.perf_counter()  # monotonic clock for durations
        model = initialize_model(model_name, num_classes).to(device)
        # weights_only=True (torch >= 1.13) restricts unpickling to tensors and
        # plain containers, so a tampered checkpoint cannot run arbitrary code.
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        state_dict = torch.load(model_weights, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()  # inference behaviour for layers such as dropout/batch-norm
        elapsed = time.perf_counter() - start_time
        logger.info(f"Model initialization completed in {elapsed:.2f} seconds")
        return model, device
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception(f"Model initialization failed: {e}")
        raise
def classify_image(image, model, device):
    """
    Classify a single uploaded image and return a human-readable result.

    Failures are never raised to the caller; they are logged with a traceback
    and turned into an error string so the UI can display them.

    Args:
        image (PIL.Image.Image | None): Uploaded image, or None if nothing
            was provided.
        model (torch.nn.Module): Loaded model, already on ``device``.
        device (torch.device): Device the input tensor is moved to.

    Returns:
        str: ``"Predicted class: <name>"`` on success, ``"No image uploaded"``
        when ``image`` is None, or ``"Error processing image: <reason>"`` if
        any step fails.
    """
    # Nothing to classify: short-circuit before doing any model work.
    if image is None:
        return "No image uploaded"
    try:
        start_time = time.perf_counter()  # monotonic clock for durations

        logger.info('Preprocessing image...')
        image_tensor = preprocess_image(image, (224, 224)).to(device)

        logger.info('Running inference...')
        with torch.no_grad():  # no autograd bookkeeping needed for inference
            outputs = predict(model, image_tensor)
            # torch.max over dim 1 returns (values, indices); keep the indices.
            preds = torch.max(outputs, 1)[1]
        # .item() requires exactly one element, i.e. a batch of one image.
        predicted_class = CLASS_NAMES[preds.item()]

        inference_time = time.perf_counter() - start_time
        logger.info(f"Image classification completed in {inference_time:.4f} seconds")
        logger.info(f"Predicted class: {predicted_class}")
        return f"Predicted class: {predicted_class}"
    except Exception as e:
        # Single ERROR record with the traceback attached.
        logger.exception(f"Classification error: {e}")
        return f"Error processing image: {e}"
def create_gradio_app():
    """
    Build the Gradio UI around a model that is loaded exactly once.

    ``setup_model()`` runs a single time here; the resulting model and device
    are captured by a closure, so each request reuses them instead of
    reloading the weights.

    Returns:
        gr.Interface: Interface wired to ``classify_image``; not yet launched.

    Raises:
        Exception: Any error from model setup or interface construction is
            logged at CRITICAL level (message plus traceback) and re-raised.
    """
    try:
        loaded_model, compute_device = setup_model()

        def classify_wrapper(image):
            # Gradio supplies only the image; bind the preloaded model/device.
            return classify_image(image, loaded_model, compute_device)

        image_input = gr.components.Image(type="pil", label="Upload Pokemon Image")
        text_output = gr.components.Textbox(label="Prediction")
        return gr.Interface(
            fn=classify_wrapper,
            inputs=image_input,
            outputs=text_output,
            title="Pokemon Classifier",
            description="Upload an image of a Pokemon, and the model will predict its class.",
            allow_flagging="never",  # keep the UI minimal: no flag button
        )
    except Exception as err:
        logger.critical(f"Failed to create Gradio app: {err}")
        logger.critical(traceback.format_exc())
        raise
def main(server_name="0.0.0.0", server_port=7860, share=False):
    """
    Build the Gradio app and serve it, exiting non-zero if startup fails.

    Args:
        server_name (str): Interface to bind; "0.0.0.0" listens on all
            interfaces, which is needed to be reachable from outside a Docker
            container.
        server_port (int): Port to serve on (7860 is the standard Hugging Face
            Spaces port).
        share (bool): Whether Gradio should create a public share link.

    Raises:
        SystemExit: With status 1 if building or launching the app fails,
            after the error and its traceback have been logged.
    """
    try:
        demo = create_gradio_app()
        demo.launch(
            server_name=server_name,
            server_port=server_port,
            share=share,
        )
    except Exception as e:
        logger.critical(f"Application launch failed: {e}", exc_info=True)
        # Exit non-zero so the container/Space reports the failure instead of
        # appearing to have stopped cleanly with status 0.
        raise SystemExit(1) from e


if __name__ == "__main__":
    main()