import math

import numpy as np
import torch
import xformers
import xformers.ops

class DummyController:
    """No-op attention controller.

    Selecting this controller makes patched attention layers run plain
    (non-decomposed) attention. Calling it hands back its first positional
    argument untouched.
    """

    def __init__(self):
        # Number of patched attention layers; starts at zero and may be
        # overwritten by the registration routine.
        self.num_att_layers = 0

    def __call__(self, *args):
        """Return the first positional argument unchanged (IndexError if none)."""
        passthrough = args[0]
        return passthrough

class GroupedCAController:
    """Controller that splits cross-attention into spatially masked groups.

    Each entry of ``mask_list`` is a spatial mask paired, by position, with one
    key/value set in ``ca_forward_decom``. Query positions where a mask is zero
    receive no contribution from that mask's key/value set; the per-mask
    results are summed.

    Attributes:
        mask_list: the masks given to the constructor, or ``None``.
        is_decom: ``True`` when ``mask_list`` was provided.
    """

    def __init__(self, mask_list=None):
        self.mask_list = mask_list
        self.is_decom = mask_list is not None

    def mask_img_to_mask_vec(self, mask, length):
        """Resize ``mask`` to ``length x length`` and flatten it.

        Uses ``torch.nn.functional.interpolate`` with its default mode.
        Assumes ``mask`` is a 2-D floating-point tensor (TODO confirm with callers).
        Returns a 1-D tensor of ``length * length`` elements.
        """
        mask_vec = torch.nn.functional.interpolate(
            mask.unsqueeze(0).unsqueeze(0), (length, length)
        ).squeeze()
        return mask_vec.flatten()

    def ca_forward_decom(self, q, k_list, v_list, scale, place_in_unet):
        """Masked, per-group cross-attention.

        Shapes (Bh = batch * heads, N = query tokens, P = key tokens), e.g.
        q [8, 4096, 40], k/v [8, 77, 40]:
            q:        [Bh, N, d]
            k_list:   list of [Bh, P, d], one per mask
            v_list:   list of [Bh, P, d], one per mask
            returns:  [Bh, N, d]

        ``N`` is treated as a square ``sqrt(N) x sqrt(N)`` grid when resizing the
        masks. ``place_in_unet`` is accepted for interface compatibility and is
        not used here.

        Raises:
            ValueError: if ``mask_list``, ``k_list`` and ``v_list`` differ in
                length (``zip`` would otherwise silently drop the extras).
        """
        if not (len(self.mask_list) == len(k_list) == len(v_list)):
            raise ValueError(
                f"mask_list, k_list and v_list must have the same length, got "
                f"{len(self.mask_list)}, {len(k_list)} and {len(v_list)}"
            )

        N = q.shape[1]
        side = int(math.sqrt(N))
        out = 0
        for mask, k, v in zip(self.mask_list, k_list, v_list):
            # [1, N, 1] so it broadcasts over heads and key tokens; moved to
            # q's device so a CPU mask works with GPU activations.
            mask_vec = self.mask_img_to_mask_vec(mask, side).to(q.device)
            mask_vec = mask_vec.unsqueeze(0).unsqueeze(-1)

            sim = torch.einsum("b i d, b j d -> b i j", q, k) * scale  # [Bh, N, P]
            attn = sim.softmax(dim=-1)
            # Zero the whole attention row for queries outside this mask.
            attn = attn.masked_fill(mask_vec == 0, 0)
            out += torch.einsum("b i j, b j d -> b i d", attn, v)  # [Bh, N, d]
        return out

def reshape_heads_to_batch_dim(self):
    """Return a function that folds attention heads into the batch dimension.

    ``self`` is any object with a ``num_heads`` attribute (read on every call).
    The returned function maps ``[B, S, D]`` to
    ``[B * num_heads, S, D // num_heads]``; head ``h`` of sample ``b`` lands at
    batch index ``b * num_heads + h``.
    """
    def func(tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        # Without this return the function would always yield None.
        return tensor
    return func

def reshape_batch_dim_to_heads(self):
    """Return a function that unfolds heads from the batch dimension.

    Inverse of the heads-to-batch reshape: ``self`` is any object with a
    ``num_heads`` attribute (read on every call). The returned function maps
    ``[B * num_heads, S, d]`` to ``[B, S, d * num_heads]``; batch index
    ``b * num_heads + h`` becomes channels ``h * d : (h + 1) * d`` of sample ``b``.
    """
    def func(tensor):
        batch_size, seq_len, dim = tensor.shape
        head_size = self.num_heads
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        # Without this return the function would always yield None.
        return tensor
    return func

def register_attention_disentangled_control(unet, controller):
    """Replace ``forward`` on the cross-attention layers of ``unet``.

    A module is patched when its class is named ``Attention`` and its ``to_k``
    projection has ``unet.ca_dim`` input features; a patched module's own
    children are not searched further. Only direct children of ``unet`` whose
    names contain ``"down"``, ``"up"`` or ``"mid"`` (checked in that order) are
    searched.

    Args:
        unet: module exposing ``ca_dim`` and ``named_children()``.
        controller: ``None`` or a ``DummyController`` selects full attention via
            ``xformers.ops.memory_efficient_attention``. Any other object selects
            the decomposed path and must provide
            ``ca_forward_decom(q, k_list, v_list, scale, place_in_unet)``.

    Side effects:
        Overwrites ``forward`` on every matched module and sets
        ``controller.num_att_layers`` to the number of patched modules.
    """
    if controller is None:
        controller = DummyController()

    def ca_forward(self, place_in_unet):
        """Build the replacement ``forward`` for attention module ``self``."""
        # When to_out is a ModuleList only its first entry is applied.
        if isinstance(self.to_out, torch.nn.ModuleList):
            to_out = self.to_out[0]
        else:
            to_out = self.to_out

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            # NOTE(review): attention_mask is accepted for call compatibility
            # but is not applied on either path.
            q = self.head_to_batch_dim(self.to_q(x))

            if isinstance(controller, DummyController):
                # Full attention: cross-attention when a context is supplied,
                # otherwise self-attention over x.
                context = x if encoder_hidden_states is None else encoder_hidden_states
                k = self.head_to_batch_dim(self.to_k(context))
                v = self.head_to_batch_dim(self.to_v(context))
                out = xformers.ops.memory_efficient_attention(
                    q, k, v, attn_bias=None, op=None, scale=self.scale
                )
            else:
                # Decomposed path: only cross-attention is supported.
                if encoder_hidden_states is None:
                    raise NotImplementedError("decomposing SA!")
                if not isinstance(encoder_hidden_states, (list, tuple)):
                    raise TypeError(
                        "decomposed cross-attention expects encoder_hidden_states "
                        f"to be a list of tensors, got {type(encoder_hidden_states).__name__}"
                    )
                # One key/value set per context, in the same order as the list.
                k_list = [self.head_to_batch_dim(self.to_k(ctx)) for ctx in encoder_hidden_states]
                v_list = [self.head_to_batch_dim(self.to_v(ctx)) for ctx in encoder_hidden_states]
                out = controller.ca_forward_decom(q, k_list, v_list, self.scale, place_in_unet)

            return to_out(self.batch_to_head_dim(out))

        return forward

    def register_recr(net_, count, place_in_unet):
        """Patch matching attention modules under ``net_``; return the updated count."""
        if net_.__class__.__name__ == 'Attention' and net_.to_k.in_features == unet.ca_dim:
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, 'children'):
            for child in net_.children():
                count = register_recr(child, count, place_in_unet)
        return count

    cross_att_count = 0
    for name, module in unet.named_children():
        # First matching tag wins, mirroring a down/up/mid if-elif chain.
        for place in ("down", "up", "mid"):
            if place in name:
                cross_att_count += register_recr(module, 0, place)
                break
    controller.num_att_layers = cross_att_count