# Import torch and torchvision, installing them with pip first if either
# import fails in the current environment.
try:
    import torch
    import torchvision
except ImportError:
    # Install into the interpreter that is running this script
    # (sys.executable) and then retry the imports. check_call raises
    # CalledProcessError if pip exits with a non-zero status.
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "torchvision"])
    import torch
    import torchvision

import gradio as gr
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from transformers import pipeline
from scipy.ndimage import gaussian_filter

def preprocess_image(image, target_size=(512, 512)):
    """Load an image and resize/center-crop it to ``target_size``.

    The image is scaled, preserving its aspect ratio, until it covers the
    whole target area; whatever overflows is then cropped away around the
    center.

    Args:
        image: A file path, a NumPy array accepted by ``Image.fromarray``, or
            an already-loaded PIL image.
        target_size: ``(width, height)`` of the result in pixels.

    Returns:
        An RGB PIL image whose size is ``target_size``.
    """
    if isinstance(image, str):
        image = Image.open(image)
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Normalise grayscale, palette and RGBA inputs so callers always receive
    # a 3-channel RGB image.
    if image.mode != "RGB":
        image = image.convert("RGB")

    target_width, target_height = target_size
    width, height = image.size

    # Use the larger of the two per-axis ratios so that both resized
    # dimensions are at least as large as the target (a "cover" fit). Using a
    # single scale for both axes is what keeps the aspect ratio intact, also
    # for non-square targets.
    scale = max(target_width / width, target_height / height)
    new_width = max(target_width, int(width * scale))
    new_height = max(target_height, int(height * scale))

    preprocess = transforms.Compose([
        # torchvision sizes are (height, width), unlike PIL's (width, height).
        transforms.Resize((new_height, new_width)),
        transforms.CenterCrop((target_height, target_width)),
    ])

    return preprocess(image)

def estimate_depth(image, pipe):
    """Run the depth pipeline on ``image`` and return a scaled depth array.

    ``pipe`` is called with ``image`` and must return a mapping with a
    ``"depth"`` entry; that entry is converted to a NumPy array and divided
    by 16.67.

    NOTE(review): 16.67 is an unexplained scale factor. If the pipeline's
    depth output is 8-bit (0-255) this yields values of roughly 0-15.3;
    confirm the intent.
    """
    raw_depth = pipe(image)["depth"]
    return np.asarray(raw_depth) / 16.67

def apply_depth_aware_blur(image, depth_map, max_sigma, min_sigma):
    """Blur each pixel with a Gaussian whose sigma depends on its depth value.

    Depth values are mapped linearly onto sigmas: the smallest value in
    ``depth_map`` gets ``min_sigma`` and the largest gets ``max_sigma``.
    ``min_sigma`` may be greater than ``max_sigma``, which simply reverses
    the direction of the ramp. One Gaussian pass is run per distinct sigma.

    Args:
        image: PIL image or array-like of shape (H, W) or (H, W, C).
        depth_map: Array-like of shape (H, W) with one depth value per pixel.
        max_sigma: Sigma applied where ``depth_map`` is largest.
        min_sigma: Sigma applied where ``depth_map`` is smallest.

    Returns:
        A PIL image built from the uint8 result, same size as ``image``.

    Raises:
        ValueError: If ``depth_map`` does not match the image's height/width.
    """
    image_array = np.asarray(image)
    depth = np.asarray(depth_map, dtype=np.float64)
    if depth.shape != image_array.shape[:2]:
        raise ValueError(
            f"depth_map shape {depth.shape} does not match image size "
            f"{image_array.shape[:2]}"
        )

    source = image_array.astype(np.float32)

    depth_low, depth_high = depth.min(), depth.max()
    if depth_high > depth_low:
        sigmas = np.interp(depth, [depth_low, depth_high], [min_sigma, max_sigma])
    else:
        # A flat depth map would hand np.interp a zero-width x-range. Every
        # pixel sits at the minimum depth, so use min_sigma everywhere.
        sigmas = np.full(depth.shape, float(min_sigma))

    # Blur only along the two spatial axes; a sigma of 0 on any trailing
    # channel axis keeps the channels independent of each other.
    channel_sigmas = (0,) * (source.ndim - 2)

    # Start from the unblurred image so pixels with sigma <= 0 stay sharp.
    blurred = source.copy()
    for sigma in np.unique(sigmas):
        if sigma <= 0:
            continue
        # Build one blurred layer at a time and copy over only the pixels
        # assigned to this sigma, so a single extra layer is alive at once
        # instead of one per distinct sigma.
        layer = gaussian_filter(source, sigma=(sigma, sigma) + channel_sigmas)
        mask = sigmas == sigma
        blurred[mask] = layer[mask]

    return Image.fromarray(np.clip(blurred, 0, 255).astype(np.uint8))

def apply_gaussian_blur(image, sigma):
    """Apply a uniform Gaussian blur to every channel of ``image``.

    Args:
        image: PIL image or array-like of shape (H, W) or (H, W, C).
        sigma: Standard deviation of the Gaussian kernel, in pixels. A value
            of 0 leaves the image unchanged.

    Returns:
        A PIL image built from the uint8 result, same size as ``image``.
    """
    image_array = np.asarray(image)

    # Filter in float32 rather than on the raw uint8 data: scipy stores the
    # intermediate result of each 1-D pass in the output dtype, so filtering
    # uint8 directly loses precision between the two axis passes.
    source = image_array.astype(np.float32)

    # Blur only along the two spatial axes; a sigma of 0 on any trailing
    # channel axis keeps the channels independent and covers all of them
    # (including an alpha channel, if present).
    spatial_sigma = (sigma, sigma) + (0,) * (source.ndim - 2)
    blurred = gaussian_filter(source, sigma=spatial_sigma)

    return Image.fromarray(np.clip(blurred, 0, 255).astype(np.uint8))

# The depth-estimation pipeline is created lazily, on first use inside the
# processing function rather than at import time (to avoid CUDA issues), and
# is then reused so the model is not reloaded on every request.
_DEPTH_PIPELINE = None


def get_depth_pipeline():
    """Return the shared depth-estimation pipeline, creating it on first call.

    The pipeline runs on GPU 0 in float16 when CUDA is available and on the
    CPU in float32 otherwise.

    Returns:
        The ``transformers`` depth-estimation pipeline for
        ``depth-anything/Depth-Anything-V2-Small-hf``.
    """
    global _DEPTH_PIPELINE
    if _DEPTH_PIPELINE is None:
        use_cuda = torch.cuda.is_available()
        _DEPTH_PIPELINE = pipeline(
            task="depth-estimation",
            model="depth-anything/Depth-Anything-V2-Small-hf",
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device=0 if use_cuda else -1,
        )
    return _DEPTH_PIPELINE

def process_image(image, blur_type, gaussian_sigma, lens_min_sigma, lens_max_sigma):
    """Gradio callback: preprocess ``image`` and apply the selected blur.

    Returns ``None`` when no image was supplied. For "Gaussian Blur" the
    image is blurred uniformly with ``gaussian_sigma``; for any other
    ``blur_type`` a depth map is estimated and the depth-aware blur is
    applied, with ``lens_min_sigma`` used at the smallest depth value and
    ``lens_max_sigma`` at the largest.
    """
    if image is None:
        return None

    prepared = preprocess_image(image)

    if blur_type == "Gaussian Blur":
        return apply_gaussian_blur(prepared, gaussian_sigma)

    # Lens Blur: the depth model is only needed on this path.
    depth_pipe = get_depth_pipeline()
    depth = estimate_depth(prepared, depth_pipe)
    return apply_depth_aware_blur(prepared, depth, lens_max_sigma, lens_min_sigma)

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Image Blur Effects Demo")
    gr.Markdown("Apply Gaussian or Lens (Depth-aware) blur to your images")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="numpy")
            blur_type = gr.Radio(
                choices=["Gaussian Blur", "Lens Blur"],
                label="Blur Effect",
                value="Gaussian Blur"
            )

            with gr.Column(visible=True) as gaussian_controls:
                gaussian_sigma = gr.Slider(
                    minimum=0, maximum=20, value=5,
                    label="Gaussian Blur Sigma",
                    step=0.5
                )

            # Hidden at start because "Gaussian Blur" is the default choice.
            with gr.Column(visible=False) as lens_controls:
                # NOTE: "min"/"max" in these two names refer to the depth-map
                # value, not to the amount of blur. lens_min_sigma is the
                # sigma used at the smallest depth value (labelled as the far
                # end of the scene) and lens_max_sigma the sigma used at the
                # largest depth value (labelled as the near end).
                lens_min_sigma = gr.Slider(
                    minimum=0, maximum=20, value=15,
                    label="Maximum Blur (Far)",
                    step=0.5
                )
                lens_max_sigma = gr.Slider(
                    minimum=0, maximum=10, value=0,
                    label="Minimum Blur (Near)",
                    step=0.5
                )

            process_btn = gr.Button("Apply Blur")

        with gr.Column():
            output_image = gr.Image(label="Output Image")

    # Handle visibility of controls based on blur type selection
    def update_controls(blur_type):
        """Show only the control group that belongs to the selected blur type."""
        # Visibility is a component property, so it has to be changed through
        # gr.update(...); a bare bool is not a property update.
        return {
            gaussian_controls: gr.update(visible=blur_type == "Gaussian Blur"),
            lens_controls: gr.update(visible=blur_type == "Lens Blur"),
        }

    blur_type.change(
        fn=update_controls,
        inputs=[blur_type],
        outputs=[gaussian_controls, lens_controls]
    )

    # Process image when button is clicked
    process_btn.click(
        fn=process_image,
        inputs=[
            input_image,
            blur_type,
            gaussian_sigma,
            lens_min_sigma,
            lens_max_sigma
        ],
        outputs=output_image
    )

# Launch the demo
demo.launch()