import gradio as gr
import numpy as np
import tensorflow as tf
from keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from keras.models import Model
import matplotlib.pyplot as plt
import logging
from skimage.transform import resize
from PIL import Image
from tqdm import tqdm

class SwarmAgent:
    """One swarm member: an image-shaped position, a velocity, and two zeroed buffers.

    ``m`` and ``v`` are allocated with the same shape/dtype as ``position``;
    they are presumably Adam-style moment accumulators (not used in this file).
    """

    def __init__(self, position, velocity):
        self.position, self.velocity = position, velocity
        # Two independent zero arrays shaped like the position.
        self.m, self.v = (np.zeros_like(position) for _ in range(2))

class SwarmNeuralNetwork:
    """A swarm of image-shaped agents that are iteratively pulled toward a target image.

    Every agent's ``position`` is a full array of ``image_shape``; the generated
    image is the mean of all agent positions, mapped from [-1, 1] to [0, 1].
    """

    def __init__(self, num_agents, image_shape, target_image):
        """Build the swarm.

        Args:
            num_agents: number of SwarmAgent instances to create (must be an int).
            image_shape: (height, width, channels) of every agent position.
            target_image: PIL image to approximate, or None when the target will be
                restored later via ``load_model``.
        """
        self.image_shape = image_shape
        self.resized_shape = (64, 64, 3)  # input size fed to MobileNetV2
        self.agents = [SwarmAgent(self.random_position(), self.random_velocity()) for _ in range(num_agents)]
        self.target_image = self.load_target_image(target_image)
        self.generated_image = np.random.randn(*image_shape)  # Start with noise
        self.mobilenet = self.load_mobilenet_model()
        self.current_epoch = 0
        self.noise_schedule = np.linspace(0.1, 0.002, 1000)  # Noise schedule

    def random_position(self):
        """Return a standard-normal array of ``image_shape``."""
        return np.random.randn(*self.image_shape)  # Use Gaussian noise

    def random_velocity(self):
        """Return a small (std 0.01) Gaussian array of ``image_shape``."""
        return np.random.randn(*self.image_shape) * 0.01

    def load_target_image(self, img):
        """Resize a PIL image to ``image_shape`` and normalize it to [-1, 1].

        Returns None when ``img`` is None, so a network can be built without a
        target and have one restored by ``load_model``.
        """
        if img is None:
            return None
        # Force 3-channel RGB: RGBA or grayscale uploads would otherwise produce an
        # array that does not broadcast against the (H, W, 3) agent positions.
        img = img.convert("RGB").resize((self.image_shape[1], self.image_shape[0]))  # PIL takes (width, height)
        img_array = np.array(img) / 127.5 - 1  # Normalize to [-1, 1]
        plt.imshow((img_array + 1) / 2)  # Convert back to [0, 1] for display
        plt.title('Target Image')
        plt.show()
        return img_array

    def resize_image(self, image):
        """Resize ``image`` to ``resized_shape`` with anti-aliasing."""
        return resize(image, self.resized_shape, anti_aliasing=True)

    def load_mobilenet_model(self):
        """Build an ImageNet MobileNetV2 (no top) truncated at ``block_13_expand_relu``."""
        mobilenet = MobileNetV2(weights='imagenet', include_top=False, input_shape=self.resized_shape)
        return Model(inputs=mobilenet.input, outputs=mobilenet.get_layer('block_13_expand_relu').output)

    def add_positional_encoding(self, image):
        """Return ``image`` plus a normalized (row, column) coordinate encoding.

        Channel 0 receives the row fraction i/h, channel 1 the column fraction j/w,
        and any further channels receive 0. ``image`` must be (h, w, c).
        """
        h, w, c = image.shape
        pos_enc = np.zeros_like(image)
        # Broadcast the coordinate ramps instead of looping over every pixel.
        if c >= 1:
            pos_enc[..., 0] = (np.arange(h) / h)[:, np.newaxis]
        if c >= 2:
            pos_enc[..., 1] = (np.arange(w) / w)[np.newaxis, :]
        return image + pos_enc

    def multi_head_attention(self, agent, num_heads=4):
        """Return per-pixel weights (summing to 1) favoring pixels close to the target.

        Weights are softmax(-squared distance) over pixels, with the squared
        distance summed over the last (channel) axis; the result gets a trailing
        singleton axis. Every "head" in this scheme is identical, so their
        average equals a single head; ``num_heads`` is kept only for interface
        compatibility.
        """
        sq_dist = np.sum((agent.position - self.target_image) ** 2, axis=-1)
        # Shift by the minimum distance before exponentiating: the normalized
        # weights are unchanged, but the sum can no longer underflow to 0 (0/0).
        similarity = np.exp(-(sq_dist - sq_dist.min()))
        attention = similarity / np.sum(similarity)
        return np.expand_dims(attention, axis=-1)

    def multi_scale_perceptual_loss(self, agent_positions):
        """Score each position by MobileNetV2 feature similarity to the target.

        Returns one value per position, 1 / (1 + feature MSE); higher means
        perceptually closer. Despite the name, features come from a single layer.
        """
        target_image_resized = self.resize_image((self.target_image + 1) / 2)  # Convert to [0, 1] for MobileNet
        target_image_preprocessed = preprocess_input(target_image_resized[np.newaxis, ...] * 255)  # MobileNet expects [0, 255]
        target_features = self.mobilenet.predict(target_image_preprocessed)

        losses = []
        for agent_position in agent_positions:
            agent_image_resized = self.resize_image((agent_position + 1) / 2)
            agent_image_preprocessed = preprocess_input(agent_image_resized[np.newaxis, ...] * 255)
            agent_features = self.mobilenet.predict(agent_image_preprocessed)

            loss = np.mean((target_features - agent_features)**2)
            losses.append(1 / (1 + loss))

        return np.array(losses)

    def update_agents(self, timestep):
        """Apply one noisy update step to every agent.

        ``timestep`` indexes ``noise_schedule`` and is clamped to its valid range;
        the schedule decreases with the index, so larger timesteps inject less noise.

        Raises:
            ValueError: if no target image is set.
        """
        if self.target_image is None:
            raise ValueError("No target image set; pass one to the constructor or call load_model() first.")
        # Clamp on both sides so a negative timestep cannot wrap to the end of the schedule.
        index = max(0, min(timestep, len(self.noise_schedule) - 1))
        noise_level = self.noise_schedule[index]

        for agent in self.agents:
            # Predict noise
            predicted_noise = agent.position - self.target_image

            # Denoise
            # NOTE(review): algebraically this equals
            # position + noise_level / (1 - noise_level) * target, i.e. it adds a
            # scaled copy of the target rather than interpolating toward it.
            # Confirm this is the intended update rule.
            denoised = (agent.position - noise_level * predicted_noise) / (1 - noise_level)

            # Add scaled noise for next step
            agent.position = denoised + np.random.randn(*self.image_shape) * np.sqrt(noise_level)

            # Clip to the [-1, 1] range the target was normalized to
            agent.position = np.clip(agent.position, -1, 1)

    def generate_image(self):
        """Set ``generated_image`` to the mean agent position, mapped to [0, 1]."""
        self.generated_image = np.mean([agent.position for agent in self.agents], axis=0)
        # Normalize to [0, 1] range for display
        self.generated_image = (self.generated_image + 1) / 2
        self.generated_image = np.clip(self.generated_image, 0, 1)

    def train(self, epochs):
        """Run ``epochs`` update steps, logging the MSE to the target after each one.

        The MSE goes to ``training.log`` every epoch; every 10th epoch it is also
        printed and the current image is displayed.
        """
        logging.basicConfig(filename='training.log', level=logging.INFO)

        for epoch in tqdm(range(epochs), desc="Training Epochs"):
            self.update_agents(epoch)
            self.generate_image()

            # Map the [0, 1] display image back to [-1, 1] to compare with the target.
            mse = np.mean(((self.generated_image * 2 - 1) - self.target_image)**2)
            logging.info(f"Epoch {epoch}, MSE: {mse}")

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, MSE: {mse}")
                self.display_image(self.generated_image, title=f'Epoch {epoch}')
            self.current_epoch += 1

    def display_image(self, image, title=''):
        """Show ``image`` with matplotlib, without axes."""
        plt.imshow(image)
        plt.title(title)
        plt.axis('off')
        plt.show()

    def display_agent_positions(self, epoch):
        """Plot the generated image with a scatter of agent position values on top."""
        fig, ax = plt.subplots()
        positions = np.array([agent.position for agent in self.agents])
        ax.imshow(self.generated_image, extent=[0, self.image_shape[1], 0, self.image_shape[0]])
        # NOTE(review): positions is (num_agents, H, W, C), so these slices are pixel
        # values in [-1, 1], not 2-D coordinates; the scatter probably does not show
        # what the title suggests.
        ax.scatter(positions[:, :, 0].flatten(), positions[:, :, 1].flatten(), s=1, c='red')
        plt.title(f'Agent Positions at Epoch {epoch}')
        plt.show()

    def save_model(self, filename):
        """Persist agents, generated image, epoch counter and target image with ``np.save``.

        The dict (containing SwarmAgent objects) is pickled, so it must be read
        back with ``allow_pickle=True``.
        """
        model_state = {
            'agents': self.agents,
            'generated_image': self.generated_image,
            'current_epoch': self.current_epoch,
            # Saved so load_model() can restore it: without a target,
            # update_agents() (and therefore generate_new_image()) cannot run.
            'target_image': self.target_image,
        }
        np.save(filename, model_state)

    def load_model(self, filename):
        """Restore state written by ``save_model``.

        SECURITY: this unpickles the file (``allow_pickle=True``), which can run
        arbitrary code; only load files you produced yourself.
        """
        model_state = np.load(filename, allow_pickle=True).item()
        self.agents = model_state['agents']
        self.generated_image = model_state['generated_image']
        self.current_epoch = model_state['current_epoch']
        # Older files may lack the target; keep the current one in that case.
        self.target_image = model_state.get('target_image', self.target_image)

    def generate_new_image(self, num_steps=1000):
        """Re-seed all agents with Gaussian noise, run ``num_steps`` updates, return the image.

        Returns ``generated_image`` in [0, 1].
        """
        for agent in self.agents:
            agent.position = np.random.randn(*self.image_shape)

        for step in tqdm(range(num_steps), desc="Generating Image"):
            # NOTE(review): timesteps run from num_steps-1 down to 0 and noise_schedule
            # decreases with the index, so noise is non-decreasing across this loop and
            # the final step uses the largest noise level, the opposite of train().
            # Confirm intent.
            self.update_agents(num_steps - step - 1)  # Reverse order

        self.generate_image()
        return self.generated_image

# Gradio Interface
def train_snn(image, num_agents, epochs):
    """Gradio callback: train a SwarmNeuralNetwork on ``image`` and return its output.

    Args:
        image: PIL image from the upload widget (None if nothing was uploaded).
        num_agents: swarm size from the slider; Gradio passes slider values as floats.
        epochs: number of training epochs from the slider; also passed as a float.

    Returns:
        ``snn.generated_image`` after training. The trained state is also written
        to ``snn_model.npy``.

    Raises:
        gr.Error: if no image was uploaded (shown to the user by Gradio).
    """
    if image is None:
        raise gr.Error("Please upload a target image.")
    # range() inside the network rejects floats, so convert the slider values.
    snn = SwarmNeuralNetwork(num_agents=int(num_agents), image_shape=(64, 64, 3), target_image=image)
    snn.train(epochs=int(epochs))
    snn.save_model('snn_model.npy')
    return snn.generated_image

def generate_new_image(model_path='snn_model.npy'):
    """Load a saved swarm from ``model_path`` and generate a new image from fresh noise.

    Returns the result of ``SwarmNeuralNetwork.generate_new_image()``.
    """
    # load_model() replaces the agent list, so build the network with no agents
    # instead of allocating 2000 random (64, 64, 3) agents only to discard them.
    # NOTE(review): this requires SwarmNeuralNetwork to accept target_image=None and
    # to restore the target in load_model(); confirm both hold.
    snn = SwarmNeuralNetwork(num_agents=0, image_shape=(64, 64, 3), target_image=None)
    snn.load_model(model_path)
    return snn.generate_new_image()

def _build_interface():
    """Assemble the Gradio UI: a target-image upload plus agent/epoch sliders feeding train_snn."""
    image_input = gr.Image(type="pil", label="Upload Target Image")
    agents_slider = gr.Slider(minimum=500, maximum=3000, value=2000, label="Number of Agents")
    epochs_slider = gr.Slider(minimum=10, maximum=200, value=100, label="Number of Epochs")
    result_output = gr.Image(type="numpy", label="Generated Image")
    return gr.Interface(
        fn=train_snn,
        inputs=[image_input, agents_slider, epochs_slider],
        outputs=result_output,
        title="Swarm Neural Network Image Generation",
        description="Upload an image and set the number of agents and epochs to train the Swarm Neural Network to generate a new image.",
    )


interface = _build_interface()
interface.launch()