from typing import Callable, Optional

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from diffusers.utils.import_utils import is_xformers_available

from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention
# from t2v_enhanced.model.diffusers_conditional.controldiffusers.models.attention import Attention
from t2v_enhanced.model.pl_module_params_controlnet import AttentionMaskParams
if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


def set_use_memory_efficient_attention_xformers(
    model, num_frame_conditioning: int, num_frames: int, attention_mask_params: AttentionMaskParams, valid: bool = True, attention_op: Optional[Callable] = None
) -> None:
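    """Install `XFormersAttnProcessor` on every attention module of `model`.

    Note: the `valid` flag is accepted but not used in this implementation.
    """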
    # Recursively walk through all children; any child module that exposes a
    # `set_processor` method gets an XFormersAttnProcessor installed.
    def fn_recursive_set_mem_eff(module: torch.nn.Module):
        if hasattr(module, "set_processor"):
            module.set_processor(
                XFormersAttnProcessor(
                    attention_op=attention_op,
                    num_frame_conditioning=num_frame_conditioning,
                    num_frames=num_frames,
                    attention_mask_params=attention_mask_params,
                )
            )

        for child in module.children():
            fn_recursive_set_mem_eff(child)

    for module in model.children():
        if isinstance(module, torch.nn.Module):
            fn_recursive_set_mem_eff(module)
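
# Example usage (a minimal sketch; `unet` and `mask_params` are assumed to exist:
# a diffusers-style UNet whose attention blocks expose `set_processor`, and an
# `AttentionMaskParams` instance configured elsewhere):
#
#   set_use_memory_efficient_attention_xformers(
#       unet,
#       num_frame_conditioning=8,
#       num_frames=16,
#       attention_mask_params=mask_params,
#   )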


class XFormersAttnProcessor:
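    """xFormers-based attention processor with optional frame-conditioning modes.

    Depending on `AttentionMaskParams`, temporal attention of the generated frames
    can be restricted to a spatial neighborhood of the conditioning frames, and
    their spatial attention can be redirected to the last conditioning frame.
    """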
    def __init__(self,
                 attention_mask_params: AttentionMaskParams,
                 attention_op: Optional[Callable] = None,
                 num_frame_conditioning: Optional[int] = None,
                 num_frames: Optional[int] = None,
                 use_image_embedding: bool = False,
                 ):
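        """Keep the xFormers attention op, the frame-conditioning settings and the
        attention-mask flags for use in `__call__`."""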
        self.attention_op = attention_op
        self.num_frame_conditioning = num_frame_conditioning
        self.num_frames = num_frames
        self.temp_attend_on_neighborhood_of_condition_frames = attention_mask_params.temp_attend_on_neighborhood_of_condition_frames
        self.spatial_attend_on_condition_frames = attention_mask_params.spatial_attend_on_condition_frames
        self.use_image_embedding = use_image_embedding

    def __call__(self, attn: Attention, hidden_states, hidden_state_height=None, hidden_state_width=None, encoder_hidden_states=None, attention_mask=None):
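        """Compute attention for `attn`, dispatching to one of three paths:
        temporal attention over a spatial neighborhood of the conditioning frames,
        spatial attention onto the last conditioning frame, or plain xFormers
        memory-efficient attention.
        """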
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = repeat(
                attention_mask, "1 F D -> B F D", B=batch_size)

        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)

        is_cross_attention = encoder_hidden_states is not None

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states)
        default_attention = not hasattr(attn, "is_spatial_attention")
        if default_attention:
            assert not self.temp_attend_on_neighborhood_of_condition_frames, "special attention must be implemented with the new interface"
            assert not self.spatial_attend_on_condition_frames, "special attention must be implemented with the new interface"
        is_spatial_attention = getattr(attn, "is_spatial_attention", False)
        use_image_embedding = getattr(attn, "use_image_embedding", False)

        if is_spatial_attention and use_image_embedding and attn.cross_attention_mode:
            assert not self.spatial_attend_on_condition_frames, "Not implemented together with image embedding"

            # Mix the image-conditioned tokens into the text tokens, gated by the
            # learned scalar `alpha` (the first 77 tokens are assumed to be the
            # CLIP text embedding).
            alpha = attn.alpha
            encoder_hidden_states_txt = encoder_hidden_states[:, :77, :]

            encoder_hidden_states_mixed = attn.conv(encoder_hidden_states)
            encoder_hidden_states_mixed = attn.conv_ln(encoder_hidden_states_mixed)
            encoder_hidden_states = encoder_hidden_states_txt + encoder_hidden_states_mixed * F.silu(alpha)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if not default_attention and not is_spatial_attention and self.temp_attend_on_neighborhood_of_condition_frames and not attn.cross_attention_mode:
            # Conditioning frames: standard temporal attention over all frames.
            query_condition = query[:, :self.num_frame_conditioning]
            query_condition = attn.head_to_batch_dim(
                query_condition).contiguous()
            key_condition = key
            value_condition = value
            key_condition = attn.head_to_batch_dim(key_condition).contiguous()
            value_condition = attn.head_to_batch_dim(
                value_condition).contiguous()
            hidden_states_condition = xformers.ops.memory_efficient_attention(
                query_condition, key_condition, value_condition, attn_bias=None, op=self.attention_op, scale=attn.scale
            )
            hidden_states_condition = hidden_states_condition.to(query.dtype)
            hidden_states_condition = attn.batch_to_head_dim(
                hidden_states_condition)
            # Remaining frames: attend only to the conditioning frames, gathered
            # from their 3x3 spatial neighborhood (built below via torch.roll).
            query_uncondition = query[:, self.num_frame_conditioning:]

            key = key[:, :self.num_frame_conditioning]
            value = value[:, :self.num_frame_conditioning]
            key = rearrange(key, "(B W H) F C -> B W H F C",
                            H=hidden_state_height, W=hidden_state_width)
            value = rearrange(value, "(B W H) F C -> B W H F C",
                              H=hidden_state_height, W=hidden_state_width)

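            # Build the neighborhood by rolling the conditioning-frame keys/values
            # by -1/0/+1 along width and height and concatenating the shifts.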
            keys = []
            values = []
            for shifts_width in [-1, 0, 1]:
                for shifts_height in [-1, 0, 1]:
                    keys.append(torch.roll(key, shifts=(
                        shifts_width, shifts_height), dims=(1, 2)))
                    values.append(torch.roll(value, shifts=(
                        shifts_width, shifts_height), dims=(1, 2)))
            key = rearrange(torch.cat(keys, dim=3), "B W H F C -> (B W H) F C")
            value = rearrange(torch.cat(values, dim=3),
                              "B W H F C -> (B W H) F C")

            query = attn.head_to_batch_dim(query_uncondition).contiguous()
            key = attn.head_to_batch_dim(key).contiguous()
            value = attn.head_to_batch_dim(value).contiguous()

            hidden_states = xformers.ops.memory_efficient_attention(
                query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
            )
            hidden_states = hidden_states.to(query.dtype)
            hidden_states = attn.batch_to_head_dim(hidden_states)
            hidden_states = torch.cat(
                [hidden_states_condition, hidden_states], dim=1)
        elif not default_attention and is_spatial_attention and self.spatial_attend_on_condition_frames and not attn.cross_attention_mode:
            # Spatial self-attention with conditioning: split the batch back into
            # frames, (B F) S C -> B F S C, to separate conditioning frames from
            # generated frames.
            query_condition = rearrange(
                query, "(B F) S C -> B F S C", F=self.num_frames)
            query_condition = query_condition[:, :self.num_frame_conditioning]
            query_condition = rearrange(
                query_condition, "B F S C -> (B F) S C")
            query_condition = attn.head_to_batch_dim(
                query_condition).contiguous()

            key_condition = rearrange(
                key, "(B F) S C -> B F S C", F=self.num_frames)
            key_condition = key_condition[:, :self.num_frame_conditioning]
            key_condition = rearrange(key_condition, "B F S C -> (B F) S C")

            value_condition = rearrange(
                value, "(B F) S C -> B F S C", F=self.num_frames)
            value_condition = value_condition[:, :self.num_frame_conditioning]
            value_condition = rearrange(
                value_condition, "B F S C -> (B F) S C")

            key_condition = attn.head_to_batch_dim(key_condition).contiguous()
            value_condition = attn.head_to_batch_dim(
                value_condition).contiguous()
            hidden_states_condition = xformers.ops.memory_efficient_attention(
                query_condition, key_condition, value_condition, attn_bias=None, op=self.attention_op, scale=attn.scale
            )
            hidden_states_condition = hidden_states_condition.to(query.dtype)
            hidden_states_condition = attn.batch_to_head_dim(
                hidden_states_condition)

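            # Generated (non-conditioning) frames attend spatially to the tokens
            # of the last conditioning frame only.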
            query_uncondition = rearrange(
                query, "(B F) S C -> B F S C", F=self.num_frames)
            query_uncondition = query_uncondition[:,
                                                  self.num_frame_conditioning:]
            key_uncondition = rearrange(
                key, "(B F) S C -> B F S C", F=self.num_frames)
            value_uncondition = rearrange(
                value, "(B F) S C -> B F S C", F=self.num_frames)
            key_uncondition = key_uncondition[:,
                                              self.num_frame_conditioning-1, None]
            value_uncondition = value_uncondition[:,
                                                  self.num_frame_conditioning-1, None]
            query_uncondition = rearrange(
                query_uncondition, "B F S C -> (B F) S C")
            key_uncondition = repeat(rearrange(
                key_uncondition, "B F S C -> B (F S) C"), "B T C -> (B F) T C", F=self.num_frames-self.num_frame_conditioning)
            value_uncondition = repeat(rearrange(
                value_uncondition, "B F S C -> B (F S) C"), "B T C -> (B F) T C", F=self.num_frames-self.num_frame_conditioning)
            query_uncondition = attn.head_to_batch_dim(
                query_uncondition).contiguous()
            key_uncondition = attn.head_to_batch_dim(
                key_uncondition).contiguous()
            value_uncondition = attn.head_to_batch_dim(
                value_uncondition).contiguous()
            hidden_states_uncondition = xformers.ops.memory_efficient_attention(
                query_uncondition, key_uncondition, value_uncondition, attn_bias=None, op=self.attention_op, scale=attn.scale
            )
            hidden_states_uncondition = hidden_states_uncondition.to(
                query.dtype)
            hidden_states_uncondition = attn.batch_to_head_dim(
                hidden_states_uncondition)
            hidden_states = torch.cat([
                rearrange(hidden_states_condition, "(B F) S C -> B F S C",
                          F=self.num_frame_conditioning),
                rearrange(hidden_states_uncondition, "(B F) S C -> B F S C",
                          F=self.num_frames - self.num_frame_conditioning),
            ], dim=1)
            hidden_states = rearrange(hidden_states, "B F S C -> (B F) S C")
        else:
            query = attn.head_to_batch_dim(query).contiguous()
            key = attn.head_to_batch_dim(key).contiguous()
            value = attn.head_to_batch_dim(value).contiguous()

            hidden_states = xformers.ops.memory_efficient_attention(
                query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
            )

            hidden_states = hidden_states.to(query.dtype)
            hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
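
# Example of invoking an attention module that uses this processor (a minimal
# sketch; `attn` and `x` are assumed to exist: a custom `Attention` module whose
# forward passes `hidden_state_height`/`hidden_state_width` through to the
# processor, and a (batch, tokens, channels) hidden-state tensor):
#
#   out = attn(x, hidden_state_height=32, hidden_state_width=32)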