from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import ListConfig
from ...util import append_dims, instantiate_from_config


class StandardDiffusionLoss(nn.Module):
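    """EDM-style denoising loss.

    Samples a per-example noise level ``sigma``, corrupts the input as
    ``x + sigma * noise``, runs the denoiser, and reduces the sigma-weighted
    L1/L2 error between the denoised output and the clean input.
    """
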
    def __init__(
        self,
        sigma_sampler_config,
        type="l2",
        offset_noise_level=0.0,
        batch2model_keys: Optional[Union[str, List[str], ListConfig]] = None,
    ):
        super().__init__()

        assert type in ["l2", "l1"]

        self.sigma_sampler = instantiate_from_config(sigma_sampler_config)

        self.type = type
        self.offset_noise_level = offset_noise_level

        if not batch2model_keys:
            batch2model_keys = []

        if isinstance(batch2model_keys, str):
            batch2model_keys = [batch2model_keys]

        self.batch2model_keys = set(batch2model_keys)

    def __call__(self, network, denoiser, conditioner, input, batch, *args, **kwargs):
        cond = conditioner(batch)
        additional_model_inputs = {
            key: batch[key] for key in self.batch2model_keys.intersection(batch)
        }

        sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
        noise = torch.randn_like(input)
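        # Offset noise (if enabled): add a per-sample constant shift shared by all
        # pixels, commonly used to let the model shift global statistics such as
        # overall brightness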
        if self.offset_noise_level > 0.0:
            noise = noise + self.offset_noise_level * append_dims(
                torch.randn(input.shape[0], device=input.device), input.ndim
            )
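        # EDM-style corruption of the clean input: x_sigma = x + sigma * n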
        noised_input = input + noise * append_dims(sigmas, input.ndim)
        model_output = denoiser(
            network, noised_input, sigmas, cond, **additional_model_inputs
        )
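        # Per-sigma loss weighting supplied by the denoiser (e.g. EDM weighting)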
        w = append_dims(denoiser.w(sigmas), input.ndim)

        loss = self.get_diff_loss(model_output, input, w)
        loss = loss.mean()
        loss_dict = {"loss": loss}

        return loss, loss_dict

    def get_diff_loss(self, model_output, target, w):
        if self.type == "l2":
            return torch.mean(
                (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1
            )
        elif self.type == "l1":
            return torch.mean(
                (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1
            )
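
# Minimal usage sketch for StandardDiffusionLoss (hypothetical config target;
# the concrete sigma-sampler class depends on the surrounding codebase):
#
#   loss_fn = StandardDiffusionLoss(
#       sigma_sampler_config={
#           "target": "sgm.modules.diffusionmodules.sigma_sampling.EDMSampling"
#       },
#       type="l2",
#       offset_noise_level=0.05,
#   )
#   loss, loss_dict = loss_fn(network, denoiser, conditioner, input, batch)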


class FullLoss(StandardDiffusionLoss):
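    """Diffusion loss with auxiliary text-rendering objectives.

    On top of the standard diffusion loss, adds (a) a local attention loss that
    pushes each text token's cross-attention peak into its segmentation region,
    read from ``network.diffusion_model.attn_map_cache`` (a list of dicts with
    keys "heads", "size" and "attn_map"), and (b) an optional OCR loss computed
    on decoded crops of the model output.
    """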

    def __init__(
        self,
        seq_len=12,
        kernel_size=3,
        gaussian_sigma=0.5,
        min_attn_size=16,
        lambda_local_loss=0.0,
        lambda_ocr_loss=0.0,
        ocr_enabled=False,
        predictor_config=None,
        *args, **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.gaussian_kernel_size = kernel_size
        gaussian_kernel = self.get_gaussian_kernel(
            kernel_size=self.gaussian_kernel_size, sigma=gaussian_sigma, out_channels=seq_len
        )
        self.register_buffer("g_kernel", gaussian_kernel.requires_grad_(False))

        self.min_attn_size = min_attn_size
        self.lambda_local_loss = lambda_local_loss
        self.lambda_ocr_loss = lambda_ocr_loss

        self.ocr_enabled = ocr_enabled
        if ocr_enabled:
            assert predictor_config is not None, "ocr_enabled=True requires a predictor_config"
            self.predictor = instantiate_from_config(predictor_config)
    
    def get_gaussian_kernel(self, kernel_size=3, sigma=1, out_channels=3):
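        """Build a depthwise 2D Gaussian kernel of shape (out_channels, 1, k, k).

        For example, kernel_size=3 with sigma=0.5 yields approximately
            [[0.011, 0.084, 0.011],
             [0.084, 0.620, 0.084],
             [0.011, 0.084, 0.011]]
        per channel after normalization.
        """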
        # Create an (x, y) coordinate grid of shape (kernel_size, kernel_size, 2)
        x_coord = torch.arange(kernel_size)
        x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
        y_grid = x_grid.t()
        xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()

        mean = (kernel_size - 1) / 2.0
        variance = sigma ** 2.0

        # Calculate the 2-dimensional gaussian kernel which is
        # the product of two gaussian distributions for two different
        # variables (in this case called x and y)
        gaussian_kernel = (1.0 / (2.0 * torch.pi * variance)) * torch.exp(
            -torch.sum((xy_grid - mean) ** 2.0, dim=-1) / (2 * variance)
        )

        # Make sure sum of values in gaussian kernel equals 1.
        gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

        # Reshape to 2d depthwise convolutional weight
        gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
        gaussian_kernel = gaussian_kernel.tile(out_channels, 1, 1, 1)
        
        return gaussian_kernel

    def __call__(self, network, denoiser, conditioner, input, batch, first_stage_model, scaler):

        cond = conditioner(batch)

        sigmas = self.sigma_sampler(input.shape[0]).to(input.device)
        noise = torch.randn_like(input)
        if self.offset_noise_level > 0.0:
            noise = noise + self.offset_noise_level * append_dims(
                torch.randn(input.shape[0], device=input.device), input.ndim
            )

        noised_input = input + noise * append_dims(sigmas, input.ndim)
        model_output = denoiser(network, noised_input, sigmas, cond)
        w = append_dims(denoiser.w(sigmas), input.ndim)

        diff_loss = self.get_diff_loss(model_output, input, w)
        local_loss = self.get_local_loss(
            network.diffusion_model.attn_map_cache, batch["seg"], batch["seg_mask"]
        )
        diff_loss = diff_loss.mean()
        local_loss = local_loss.mean()

        if self.ocr_enabled:
            ocr_loss = self.get_ocr_loss(
                model_output, batch["r_bbox"], batch["label"], first_stage_model, scaler
            )
            ocr_loss = ocr_loss.mean()

        loss = diff_loss + self.lambda_local_loss * local_loss
        if self.ocr_enabled:
            loss += self.lambda_ocr_loss * ocr_loss

        loss_dict = {
            "loss/diff_loss": diff_loss,
            "loss/local_loss": local_loss,
            "loss/full_loss": loss
        }

        if self.ocr_enabled:
            loss_dict["loss/ocr_loss"] = ocr_loss

        return loss, loss_dict
    
    def get_ocr_loss(self, model_output, r_bbox, label, first_stage_model, scaler):
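        """Decode the predicted latents, crop each text region, and score the
        crops with the OCR predictor against their labels."""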

        # Undo the first-stage latent scaling before decoding to pixel space
        model_output = 1 / scaler * model_output
        model_output_decoded = first_stage_model.decode(model_output)
        model_output_crops = []
        
        for i, bbox in enumerate(r_bbox):
            # Crop each decoded image to its text-region bounding box
            m_top, m_bottom, m_left, m_right = bbox
            model_output_crops.append(
                model_output_decoded[i, :, m_top:m_bottom, m_left:m_right]
            )

        loss = self.predictor.calc_loss(model_output_crops, label)

        return loss

    def get_min_local_loss(self, attn_map_cache, mask, seg_mask):
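        """Worst-case variant of the local loss: for each sample, find the
        weakest (over valid tokens) peak attention inside the user-provided
        mask and maximize it, so every token attends inside the masked region.
        """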

        loss = 0
        count = 0

        for item in attn_map_cache:

            heads = item["heads"]
            size = item["size"]
            attn_map = item["attn_map"]

            if size < self.min_attn_size:
                continue  # skip low-resolution attention maps

            seg_l = seg_mask.shape[1]

            bh, n, l = attn_map.shape  # bh = batch * heads, n = pixels (h * w), l = token length
            attn_map = attn_map.reshape((-1, heads, n, l))  # b, h, n, l

            assert seg_l <= l
            attn_map = attn_map[..., :seg_l]
            attn_map = attn_map.permute(0, 1, 3, 2)  # b, h, l, n
            attn_map = attn_map.mean(dim=1)  # b, l, n

            attn_map = attn_map.reshape((-1, seg_l, size, size))  # b, l, s, s
            attn_map = F.conv2d(
                attn_map, self.g_kernel, padding=self.gaussian_kernel_size // 2, groups=seg_l
            )  # Gaussian blur on each channel
            attn_map = attn_map.reshape((-1, seg_l, n))  # b, l, n
            
            mask_map = F.interpolate(mask, (size, size))
            mask_map = mask_map.tile((1, seg_l, 1, 1))
            mask_map = mask_map.reshape((-1, seg_l, n))  # b, l, n

            p_loss = (mask_map * attn_map).max(dim=-1)[0]  # peak attention inside the mask per token: b, l
            p_loss = p_loss + (1 - seg_mask)  # offset padded tokens by 1 so the min below ignores them
            p_loss = p_loss.min(dim=-1)[0]  # weakest valid token per sample: b,

            loss -= p_loss  # maximize the weakest token's peak attention
            count += 1

        # Average over attention layers; assumes at least one map met min_attn_size
        loss = loss / count

        return loss

    def get_local_loss(self, attn_map_cache, seg, seg_mask):
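        """Attention-localization loss: for each valid text token, reward a high
        attention peak inside its segmentation region (``seg``) and penalize
        peaks outside it; ``seg_mask`` marks which token slots are real.
        """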

        loss = 0
        count = 0

        for item in attn_map_cache:

            heads = item["heads"]
            size = item["size"]
            attn_map = item["attn_map"]

            if size < self.min_attn_size:
                continue  # skip low-resolution attention maps

            seg_l = seg_mask.shape[1]

            bh, n, l = attn_map.shape  # bh = batch * heads, n = pixels (h * w), l = token length
            attn_map = attn_map.reshape((-1, heads, n, l))  # b, h, n, l

            assert seg_l <= l
            attn_map = attn_map[..., :seg_l]
            attn_map = attn_map.permute(0, 1, 3, 2)  # b, h, l, n
            attn_map = attn_map.mean(dim=1)  # b, l, n

            attn_map = attn_map.reshape((-1, seg_l, size, size))  # b, l, s, s
            attn_map = F.conv2d(
                attn_map, self.g_kernel, padding=self.gaussian_kernel_size // 2, groups=seg_l
            )  # Gaussian blur on each channel
            attn_map = attn_map.reshape((-1, seg_l, n))  # b, l, n

            seg_map = F.interpolate(seg, (size, size))
            seg_map = seg_map.reshape((-1, seg_l, n))  # b, l, n
            n_seg_map = 1 - seg_map

            p_loss = (seg_map * attn_map).max(dim=-1)[0]  # peak attention inside each token's region: b, l
            n_loss = (n_seg_map * attn_map).max(dim=-1)[0]  # peak attention outside it: b, l

            # Zero out padded tokens, then average over the valid ones
            p_loss = p_loss * seg_mask  # b, l
            n_loss = n_loss * seg_mask  # b, l

            p_loss = p_loss.sum(dim=-1) / seg_mask.sum(dim=-1)  # b,
            n_loss = n_loss.sum(dim=-1) / seg_mask.sum(dim=-1)  # b,

            f_loss = n_loss - p_loss  # reward in-region peaks, penalize out-of-region peaks: b,
            loss += f_loss
            count += 1

        # Average over attention layers; assumes at least one map met min_attn_size
        loss = loss / count

        return loss
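
# Minimal usage sketch for FullLoss (hypothetical values; `network` must expose
# `diffusion_model.attn_map_cache`, and `predictor_config` is only required
# when ocr_enabled=True):
#
#   loss_fn = FullLoss(
#       seq_len=12,
#       lambda_local_loss=0.01,
#       lambda_ocr_loss=0.001,
#       sigma_sampler_config={...},  # forwarded to StandardDiffusionLoss
#   )
#   loss, loss_dict = loss_fn(
#       network, denoiser, conditioner, input, batch, first_stage_model, scaler
#   )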