import gradio as gr
from tensorflow.keras.models import load_model
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO

# Load the trained model once at import time so every request reuses it.
# compile=False: this app only calls predict(), so the training configuration
# (optimizer, loss, metrics) is not needed. Skipping it also avoids load
# failures when the model was trained with a custom loss/metric that is not
# importable here.
model = load_model('model2.h5', compile=False)

def _preprocess(img, target_size=(224, 224)):
    """Return a (1, H, W, 3) float32 batch scaled to [0, 1] from a PIL image."""
    # convert("RGB") handles every PIL mode uniformly:
    # - grayscale is replicated to three channels and alpha is dropped;
    # - palette ("P") and CMYK images are mapped to real RGB colours instead of
    #   being read as raw palette indices or CMY channels.
    rgb = img.convert("RGB").resize(target_size)
    arr = np.asarray(rgb, dtype=np.float32) / 255.0
    return arr[np.newaxis, ...]  # add the batch dimension


def _postprocess(prediction, size):
    """Turn the raw model output into an 8-bit grayscale PIL mask of ``size``."""
    # Channel 0 of the first (only) batch item. Clip so values that stray
    # outside [0, 1] saturate instead of wrapping around in the uint8 cast.
    probs = np.clip(prediction[0, :, :, 0], 0.0, 1.0)
    mask = (probs * 255).astype(np.uint8)
    # A 2-D uint8 array is inferred as mode "L" by Image.fromarray.
    # NEAREST keeps the upscaled mask free of interpolated values.
    return Image.fromarray(mask).resize(size, Image.NEAREST)


def predict_and_visualize(img):
    """Predict a camouflage mask for ``img`` using the module-level ``model``.

    The image is converted to 3-channel RGB, resized to 224x224 and scaled
    to [0, 1] before being passed to ``model.predict``. Channel 0 of the
    output is turned into an 8-bit grayscale mask and resized back to the
    input's original size.

    Args:
        img: A PIL image (as delivered by ``gr.Image(type="pil")``) or a
            NumPy array accepted by ``Image.fromarray``.

    Returns:
        A grayscale ("L") PIL image with the same width and height as ``img``.

    Raises:
        gr.Error: If no image was supplied or if any processing step fails.
    """
    if img is None:
        raise gr.Error("Please upload an image")

    try:
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)

        original_size = img.size
        batch = _preprocess(img)
        # verbose=0: don't print a progress bar to the server log per request.
        prediction = model.predict(batch, verbose=0)
        return _postprocess(prediction, original_size)

    except Exception as e:
        # Surface any failure to the UI while keeping the original traceback.
        raise gr.Error(f"Error processing image: {str(e)}") from e

# Wire the predictor into a single-image-in / single-image-out Gradio UI.
# Both components use type="pil", so the callback receives and returns PIL
# images.
input_image = gr.Image(type="pil", label="Input Image")
mask_output = gr.Image(type="pil", label="Predicted Mask")

iface = gr.Interface(
    fn=predict_and_visualize,
    inputs=input_image,
    outputs=mask_output,
    title="MilitarEye: Military Stealth Camouflage Detector",
    description="Upload an image of a military personnel camouflaged in their surroundings. The model will predict the camouflage mask silhouette.",
    allow_flagging="never",  # no "Flag" button in the UI
)

# Start serving the app.
iface.launch()