import os
import spaces
import gradio as gr
import numpy as np
import tensorflow as tf
from keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from keras.models import Model
import matplotlib.pyplot as plt
import logging
from skimage.transform import resize
from PIL import Image, ImageEnhance, ImageFilter
from tqdm import tqdm

# Disable GPU usage by default: an empty CUDA_VISIBLE_DEVICES hides every
# CUDA device from this process.
os.environ.update({'CUDA_VISIBLE_DEVICES': ''})

class SwarmAgent:
    """One swarm member: an image-shaped ``position`` plus its ``velocity``.

    ``m`` and ``v`` are zero-filled buffers with ``position``'s shape and dtype
    (named like Adam moment estimates; not read anywhere in the visible code).
    """

    def __init__(self, position, velocity):
        self.position, self.velocity = position, velocity
        # Two independent zero buffers, never aliased to each other.
        self.m, self.v = (np.zeros_like(position) for _ in range(2))

class SwarmNeuralNetwork:
    """Swarm of noisy image "agents" that are iteratively pulled toward a target image.

    Each agent holds a full image-shaped array (``position``).  ``update_agents``
    denoises every agent toward ``target_image``; ``generate_image`` averages the
    swarm into ``generated_image``.  Agents and the target live in [-1, 1];
    ``generated_image`` is in [0, 1] once ``generate_image`` has run.
    """

    def __init__(self, num_agents, image_shape, target_image_path):
        """Create the swarm.

        Args:
            num_agents: Number of agents to create.
            image_shape: ``(height, width, channels)`` of every agent and the target.
            target_image_path: Path (or file object) of the target image, or ``None``
                to start without a target (e.g. before ``load_model`` restores one).
        """
        self.image_shape = image_shape
        self.resized_shape = (128, 128, 3)  # Reduced resolution fed to MobileNet
        self.agents = [SwarmAgent(self.random_position(), self.random_velocity()) for _ in range(num_agents)]
        self.target_image = None if target_image_path is None else self.load_target_image(target_image_path)
        self.generated_image = np.random.randn(*image_shape)  # Start with noise
        # MobileNet is only needed by the perceptual loss, so it is built on first
        # access (see the ``mobilenet`` property) instead of on every construction.
        self._mobilenet = None
        self.current_epoch = 0
        self.noise_schedule = np.linspace(0.1, 0.002, 1000)  # index 0 is the noisiest level

    @property
    def mobilenet(self):
        """MobileNetV2 feature extractor, built by ``load_mobilenet_model`` on first access."""
        if getattr(self, '_mobilenet', None) is None:
            self._mobilenet = self.load_mobilenet_model()
        return self._mobilenet

    @mobilenet.setter
    def mobilenet(self, model):
        self._mobilenet = model

    def random_position(self):
        """Return a standard-normal array of shape ``image_shape``."""
        return np.random.randn(*self.image_shape)  # Use Gaussian noise

    def random_velocity(self):
        """Return a small (std 0.01) Gaussian array of shape ``image_shape``."""
        return np.random.randn(*self.image_shape) * 0.01

    def load_target_image(self, img_path):
        """Load *img_path* as RGB, resize it to ``image_shape`` and scale it to [-1, 1].

        *img_path* may be a path or a binary file object (anything ``Image.open``
        accepts).  The loaded image is also shown with matplotlib.
        """
        with Image.open(img_path) as img:
            # Force RGB: RGBA, grayscale ("L") and palette ("P") images would
            # otherwise give 4-channel or 2-D arrays that don't broadcast against
            # the agents.  NOTE(review): assumes image_shape has 3 channels.
            img = img.convert('RGB').resize((self.image_shape[1], self.image_shape[0]))  # PIL takes (width, height)
        img_array = np.asarray(img) / 127.5 - 1  # Normalize to [-1, 1]
        plt.imshow((img_array + 1) / 2)  # Convert back to [0, 1] for display
        plt.title('Target Image')
        plt.show()
        plt.close()  # release the figure so repeated loads don't accumulate open figures
        return img_array

    def resize_image(self, image):
        """Resize *image* to ``resized_shape`` with anti-aliasing."""
        return resize(image, self.resized_shape, anti_aliasing=True)

    def load_mobilenet_model(self):
        """Build an ImageNet MobileNetV2 truncated at ``block_13_expand_relu``."""
        mobilenet = MobileNetV2(weights='imagenet', include_top=False, input_shape=self.resized_shape)
        return Model(inputs=mobilenet.input, outputs=mobilenet.get_layer('block_13_expand_relu').output)

    def add_positional_encoding(self, image):
        """Return *image* plus a fixed positional ramp.

        Channel 0 receives the normalised row index ``i / h`` and channel 1 the
        normalised column index ``j / w``; any further channels receive 0.
        """
        h, w, c = image.shape
        pos_enc = np.zeros_like(image)
        # Broadcast the row/column ramps instead of filling pixel by pixel.
        pos_enc[..., 0] = (np.arange(h) / h)[:, np.newaxis]
        if c > 1:
            pos_enc[..., 1] = (np.arange(w) / w)[np.newaxis, :]
        return image + pos_enc

    def multi_head_attention(self, agent, num_heads=4):
        """Return an ``(H, W, 1)`` attention map for *agent* that sums to 1.

        Each pixel's weight is a softmax over pixels of the negative squared
        distance (summed over channels) between the agent and the target.  The
        heads have no per-head parameters and are all identical, so *num_heads*
        does not affect the result; it is kept for backward compatibility.
        """
        distance = np.sum((agent.position - self.target_image) ** 2, axis=-1)
        # Subtracting the minimum distance cancels out in the normalisation, but
        # keeps the largest term at exp(0) = 1 so the map cannot underflow to 0/0.
        similarity = np.exp(-(distance - distance.min()))
        attention = similarity / np.sum(similarity)
        return np.expand_dims(attention, axis=-1)

    def multi_scale_perceptual_loss(self, agent_positions):
        """Score each agent image by MobileNet feature similarity to the target.

        Returns one value per position, ``1 / (1 + mse)`` of the feature maps, so
        identical features score 1.0.  Despite the name, only a single scale
        (``resized_shape``) is used.
        """
        target_image_resized = self.resize_image((self.target_image + 1) / 2)  # Convert to [0, 1] for MobileNet
        target_image_preprocessed = preprocess_input(target_image_resized[np.newaxis, ...] * 255)  # MobileNet expects [0, 255]
        target_features = self.mobilenet.predict(target_image_preprocessed)

        losses = []
        for agent_position in agent_positions:
            agent_image_resized = self.resize_image((agent_position + 1) / 2)
            agent_image_preprocessed = preprocess_input(agent_image_resized[np.newaxis, ...] * 255)
            agent_features = self.mobilenet.predict(agent_image_preprocessed)

            loss = np.mean((target_features - agent_features) ** 2)
            losses.append(1 / (1 + loss))

        return np.array(losses)

    @spaces.GPU(duration=120)
    def update_agents(self, timestep):
        """Run one denoising step on every agent at ``noise_schedule[timestep]``.

        *timestep* is clamped to the last schedule entry.  Each agent becomes
        ``(position - noise_level * (position - target)) / (1 - noise_level)``
        plus Gaussian noise with std ``sqrt(noise_level)``, clipped to [-1, 1].

        Raises:
            ValueError: if no target image has been loaded.
        """
        if self.target_image is None:
            raise ValueError("No target image: pass target_image_path or load a saved model first.")
        noise_level = self.noise_schedule[min(timestep, len(self.noise_schedule) - 1)]

        for agent in self.agents:
            # Predict noise
            predicted_noise = agent.position - self.target_image

            # Denoise.  NOTE(review): algebraically this equals
            # position + noise_level / (1 - noise_level) * target, so a position
            # equal to the target is not a fixed point -- confirm this is intended.
            denoised = (agent.position - noise_level * predicted_noise) / (1 - noise_level)

            # Add scaled noise for next step
            agent.position = denoised + np.random.randn(*self.image_shape) * np.sqrt(noise_level)

            # Clip values
            agent.position = np.clip(agent.position, -1, 1)

    @spaces.GPU(duration=120)
    def generate_image(self):
        """Average the agents into ``generated_image`` (in [0, 1]) and sharpen it."""
        self.generated_image = np.mean([agent.position for agent in self.agents], axis=0)
        # Normalize to [0, 1] range for display
        self.generated_image = (self.generated_image + 1) / 2
        self.generated_image = np.clip(self.generated_image, 0, 1)

        # Apply sharpening filter (round-trips through 8-bit PIL)
        image_pil = Image.fromarray((self.generated_image * 255).astype(np.uint8))
        image_pil = image_pil.filter(ImageFilter.SHARPEN)
        self.generated_image = np.array(image_pil) / 255.0

    @spaces.GPU(duration=120)
    def train(self, epochs):
        """Run *epochs* update/average iterations.

        Logs the MSE against the target to ``training.log`` every epoch, and
        prints and displays the image on even epochs.  ``logging.basicConfig``
        only takes effect the first time the root logger is configured.
        """
        logging.basicConfig(filename='training.log', level=logging.INFO)

        for epoch in tqdm(range(epochs), desc="Training Epochs"):
            self.update_agents(epoch)
            self.generate_image()

            # generated_image is in [0, 1]; map back to [-1, 1] to compare with the target.
            mse = np.mean(((self.generated_image * 2 - 1) - self.target_image) ** 2)
            logging.info(f"Epoch {epoch}, MSE: {mse}")

            if epoch % 2 == 0:  # Display more frequently for faster feedback
                print(f"Epoch {epoch}, MSE: {mse}")
                self.display_image(self.generated_image, title=f'Epoch {epoch}')
            self.current_epoch += 1

    def display_image(self, image, title=''):
        """Show *image* with matplotlib (axes hidden), then release the figure."""
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')
        plt.show()
        # Without this, non-interactive backends keep drawing into the same open
        # figure, piling up images on every call.
        plt.close()

    def display_agent_positions(self, epoch):
        """Plot ``generated_image`` overlaid with a scatter of agent values."""
        fig, ax = plt.subplots()
        positions = np.array([agent.position for agent in self.agents])
        ax.imshow(self.generated_image, extent=[0, self.image_shape[1], 0, self.image_shape[0]])
        # NOTE(review): with 3-D image_shape, positions is (num_agents, H, W, C), so
        # [:, :, 0] / [:, :, 1] select image rows 0 and 1 of every agent (all
        # columns and channels) -- the points are pixel values, not 2-D
        # coordinates.  Confirm what was intended.
        ax.scatter(positions[:, :, 0].flatten(), positions[:, :, 1].flatten(), s=1, c='red')
        plt.title(f'Agent Positions at Epoch {epoch}')
        plt.show()
        plt.close(fig)  # release the figure so repeated calls don't accumulate

    def save_model(self, filename):
        """Pickle the swarm state to *filename* with ``np.save``.

        ``np.save`` appends ``.npy`` when *filename* lacks that extension.
        """
        model_state = {
            'agents': self.agents,
            'generated_image': self.generated_image,
            'current_epoch': self.current_epoch,
            # Persist the target so a reloaded swarm can keep denoising toward it.
            'target_image': self.target_image,
        }
        np.save(filename, model_state)

    def load_model(self, filename):
        """Restore state written by ``save_model``.

        Replaces ``agents``, ``generated_image`` and ``current_epoch``.  It also
        restores ``target_image`` when the file holds one; older files may not,
        and the current target is then kept.

        Security: ``allow_pickle=True`` unpickles arbitrary objects, which can
        execute code.  Only load files this application wrote itself, never
        user-supplied uploads.
        """
        model_state = np.load(filename, allow_pickle=True).item()
        self.agents = model_state['agents']
        self.generated_image = model_state['generated_image']
        self.current_epoch = model_state['current_epoch']
        if model_state.get('target_image') is not None:
            self.target_image = model_state['target_image']

    @spaces.GPU(duration=120)
    def generate_new_image(self, num_steps=200):  # Reduced number of steps
        """Re-seed every agent with Gaussian noise and denoise for *num_steps* steps.

        Returns ``generated_image`` (values in [0, 1]).
        """
        for agent in self.agents:
            agent.position = np.random.randn(*self.image_shape)

        # NOTE(review): noise_schedule decreases with the index (0.1 -> 0.002), so
        # walking timesteps num_steps-1 .. 0 raises the noise level each step and
        # finishes at the noisiest entry.  train() walks 0, 1, 2, ... with noise
        # falling.  Confirm the intended direction.
        for step in tqdm(range(num_steps), desc="Generating Image"):
            self.update_agents(num_steps - step - 1)

        self.generate_image()
        return self.generated_image

    def adjust_limbs(self, arm_position, leg_position):
        """Shift fixed pixel bands of every agent by the slider values.

        *arm_position* and *leg_position* are expected in [-100, 100] and are
        scaled to shifts in [-0.2, 0.2].  This edits the agents' positions; the
        target image is not modified.

        NOTE(review): the bands are hard-coded pixel ranges (rows 50:100 and
        150:200, columns 50:200).  For images with at most 150 rows -- e.g. the
        128x128 resolution used by train_snn -- the leg band is empty, so
        *leg_position* has no effect.
        """
        arm_shift = arm_position / 100.0 * 0.2  # Scale to a reasonable range
        leg_shift = leg_position / 100.0 * 0.2  # Scale to a reasonable range

        for agent in self.agents:
            agent.position[50:100, 50:200, :] += arm_shift  # Example adjustment
            agent.position[150:200, 50:200, :] += leg_shift  # Example adjustment

# Gradio Interface
def train_snn(image_path, num_agents, epochs, arm_position, leg_position, brightness, contrast, color):
    """Gradio callback: train a swarm on the uploaded image and return the result.

    Args:
        image_path: Filesystem path of the uploaded target image.
        num_agents: Swarm size (cast to int; slider values may arrive as floats).
        epochs: Number of training epochs (cast to int for the same reason).
        arm_position, leg_position: Slider values in [-100, 100] passed to
            ``adjust_limbs``.
        brightness, contrast, color: ``ImageEnhance`` factors (1.0 = unchanged)
            applied to the target image before training.

    Returns:
        The generated image as a float array in [0, 1].

    Side effects:
        Writes the trained state to ``snn_model.npy``.
    """
    import io  # local import: only this callback needs an in-memory image buffer

    if image_path is None:
        raise gr.Error("Please upload a target image.")
    num_agents = int(num_agents)
    epochs = int(epochs)

    snn = SwarmNeuralNetwork(num_agents=num_agents, image_shape=(128, 128, 3), target_image_path=image_path)  # Reduced resolution

    # Apply user-specified adjustments to the target image.  Normalise to RGB
    # first so the enhancers act on colour values (not palette indices / alpha).
    # The reloaded target then always has three channels.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    image = ImageEnhance.Brightness(image).enhance(brightness)
    image = ImageEnhance.Contrast(image).enhance(contrast)
    image = ImageEnhance.Color(image).enhance(color)

    # load_target_image opens a path or file object, so hand it the *enhanced*
    # image through a lossless in-memory PNG rather than the untouched file.
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    buffer.seek(0)
    snn.target_image = snn.load_target_image(buffer)

    # Adjust limb positions based on slider values
    snn.adjust_limbs(arm_position, leg_position)

    snn.train(epochs=epochs)
    snn.save_model('snn_model.npy')
    return snn.generated_image

def generate_new_image():
    """Load the swarm saved in ``snn_model.npy`` and sample a fresh image from it.

    Returns:
        The generated image as a float array in [0, 1].
    """
    # load_model replaces the agent list wholesale, so start with no agents
    # instead of allocating 1000 full-size random images only to discard them.
    # NOTE(review): this passes target_image_path=None.  It works only if the
    # constructor accepts a missing path and load_model restores a target image,
    # because update_agents reads self.target_image.
    snn = SwarmNeuralNetwork(num_agents=0, image_shape=(128, 128, 3), target_image_path=None)
    snn.load_model('snn_model.npy')
    return snn.generate_new_image()

# Gradio UI wrapping train_snn.  Without an explicit step, Gradio derives one
# from the slider range (0.1 for the 5-20 epoch range), so the epoch slider
# could send fractional values.  Integer-valued inputs therefore pin step=1.
interface = gr.Interface(
    fn=train_snn,
    inputs=[
        gr.Image(type="filepath", label="Upload Target Image"),
        gr.Slider(minimum=100, maximum=1000, value=500, step=1, label="Number of Agents"),  # Further reduced range for number of agents
        gr.Slider(minimum=5, maximum=20, value=10, step=1, label="Number of Epochs"),  # Further reduced range for number of epochs
        gr.Slider(minimum=-100, maximum=100, value=0, label="Arm Position"),
        gr.Slider(minimum=-100, maximum=100, value=0, label="Leg Position"),
        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Brightness"),
        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Contrast"),
        gr.Slider(minimum=0.5, maximum=2.0, value=1.0, label="Color Balance"),
    ],
    outputs=gr.Image(type="numpy", label="Generated Image"),
    title="Swarm Neural Network Image Generation",
    description="Upload an image and set the number of agents and epochs to train the Swarm Neural Network to generate a new image. Adjust arm and leg positions, brightness, contrast, and color balance for personalization."
)

interface.launch()