import re, math, torch
from collections import OrderedDict
from typing import Optional, Tuple

from torch import nn
from torch.nn.init import trunc_normal_, normal_
import torch.utils.checkpoint

from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel


class ClassInstantier(OrderedDict):
    """Ordered mapping whose lookups return a new instance instead of the stored class.

    Each value is either a class or a ``(class, kwargs)`` tuple. Indexing instantiates
    the class with those keyword arguments (or with none), so every lookup yields a
    fresh object.
    """

    def __getitem__(self, key):
        entry = super().__getitem__(key)
        if isinstance(entry, tuple):
            factory, init_kwargs = entry
        else:
            factory, init_kwargs = entry, {}
        return factory(**init_kwargs)


# Activation name -> nn.Module class (values may also be ``(class, kwargs)`` tuples).
ACT2CLS = dict(silu=nn.SiLU)
# Indexing ACT2FN constructs a new activation module on every lookup.
ACT2FN = ClassInstantier(ACT2CLS)


class WeightedNorm(nn.Module):
    def __init__(self, hidden_size):
        """
        WeightedNorm
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.norm = nn.LayerNorm(self.hidden_size)
        self.wheight = nn.Parameter(torch.ones(self.hidden_size))
        normal_(self.wheight, mean=1, std=.02)

    def forward(self, x):
        x = self.norm(x)
        return x * self.wheight


class PerceiverMLP(nn.Module):
    """Gated, bias-free feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``."""

    def __init__(
            self,
            hidden_size: int,
            intermediate_size: int,
            output_size: int,
            hidden_act: str,
    ):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
        # Looked up by name; ACT2FN returns a freshly constructed activation module.
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        projected = self.up_proj(x)
        return self.down_proj(gated * projected)


# Adapted from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``: the
    input of shape (batch, num_key_value_heads, seqlen, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seqlen, head_dim). When ``n_rep == 1`` the input
    tensor itself is returned.
    """
    # Unpacking first also enforces a 4-D input, even on the n_rep == 1 fast path.
    bsz, n_kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the heads, broadcast it, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(bsz, n_kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, n_kv_heads * n_rep, seq_len, dim)


class PerceiverAttention(nn.Module):
    """Perceiver cross-attention: `latents` query over `context` concatenated with `latents`.

    Supports grouped K/V heads: ``num_key_value_heads`` key/value heads are repeated to match the
    ``resampler_n_heads`` query heads. No attention mask is applied (``is_causal`` is False).
    """

    def __init__(self, connector_config, layer_idx: Optional[int] = None) -> None:
        """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`

        Args:
            connector_config: Config providing ``text_hidden_size``, ``resampler_n_heads``,
                ``resampler_head_dim`` and ``num_key_value_heads``.
            layer_idx: Index of this layer; passed to ``past_key_value.update`` when a cache is supplied.
        """
        super().__init__()

        # Keep the caller's index (it used to be discarded and hard-coded to None), so that
        # `past_key_value.update` receives the real layer index.
        self.layer_idx = layer_idx
        self.hidden_size = connector_config.text_hidden_size
        self.num_heads = connector_config.resampler_n_heads
        self.head_dim = connector_config.resampler_head_dim
        self.num_key_value_heads = connector_config.num_key_value_heads
        # How many query heads share each key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.is_causal = False

    def forward(
            self,
            latents: torch.Tensor,
            context: torch.Tensor,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            output_attentions: bool = False,
            use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Run Perceiver cross-attention. Queries come from `latents`; keys and values come from
        `context` and `latents` concatenated along the `seq` dimension.

        Args:
            latents (`torch.Tensor`): Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
            context (`torch.Tensor`): Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
            past_key_value (*optional*): Object with an ``update(key_states, value_states, layer_idx)``
                method that returns the key/value states to attend over (presumably a transformers
                ``Cache`` — confirm against callers).
            output_attentions (`bool`, *optional*, defaults to `False`): Whether to return attention weights.
            use_cache (`bool`, *optional*, defaults to `False`): Accepted for interface compatibility but
                not read here; the cache is used whenever `past_key_value` is given.

        Returns:
            ``(attn_output, attn_weights, past_key_value)``. ``attn_output`` has shape
            [bsz, n_latents, embed_dim]. ``attn_weights`` has shape
            [bsz, num_heads, n_latents, kv_seq_len] when `output_attentions` is set, else None.

        Raises:
            ValueError: if the attention weights or attention output have an unexpected shape.
        """
        bsz, q_len, _ = latents.size()
        kv_seq_len = q_len + context.size()[1]

        # Keys/values cover [context; latents], so every latent also attends to the latent set.
        hidden_states = torch.concat([context, latents], dim=-2)

        query_states = self.q_proj(latents)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # A `past_key_value` attribute set on the module, if present, overrides the argument.
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
            # The returned states may include previously cached positions, so the key length used
            # by the shape check below must come from the updated tensors.
            kv_seq_len = key_states.size(-2)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


# Maps a config's ``_attn_implementation`` string to an attention class. Only "eager" is
# registered; any other value raises KeyError when looked up.
PERCEIVER_ATTENTION_CLASSES = dict(eager=PerceiverAttention)


class PerceiverLayer(nn.Module):
    """One pre-norm Perceiver block with residual connections on the latent stream.

    The latents first cross-attend over (normalized context + normalized latents), then pass
    through an MLP. Only the updated latents are returned; the context is read but not updated.
    """

    def __init__(self, connector_config, layer_idx: int):
        super().__init__()
        self.hidden_size = connector_config.text_hidden_size
        # n_latents / depth are stored on the module but not read by this layer.
        self.n_latents = connector_config.num_output_tokens
        self.depth = connector_config.resampler_depth
        self.ff_multi = connector_config.ff_multi

        self.input_latents_norm = WeightedNorm(self.hidden_size)
        self.input_context_norm = WeightedNorm(self.hidden_size)
        self.self_attn = PERCEIVER_ATTENTION_CLASSES[connector_config._attn_implementation](connector_config,
                                                                                            layer_idx=layer_idx)
        self.post_attention_layernorm = WeightedNorm(self.hidden_size)
        self.mlp = PerceiverMLP(
            hidden_size=connector_config.text_hidden_size,
            intermediate_size=connector_config.text_hidden_size * self.ff_multi,
            output_size=connector_config.text_hidden_size,
            hidden_act=connector_config.hidden_act,
        )

    def forward(
            self,
            latents: torch.Tensor,
            context: torch.Tensor,
            past_key_value: Optional[Tuple[torch.Tensor]] = None,
            output_attentions: Optional[bool] = False,
            use_cache: Optional[bool] = False,
            **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection
                states, forwarded to the attention module.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            **kwargs: accepted and ignored.

        Returns:
            Tuple starting with the updated latents. If `output_attentions` is set, the attention weights
            are appended. If `use_cache` is set, the attention module's returned cache is appended.
        """
        residual = latents

        latents = self.input_latents_norm(latents)
        context = self.input_context_norm(context)

        # Forward the caller's flags; previously they were dropped, so `output_attentions=True`
        # always yielded None weights and `past_key_value` was never consulted.
        latents, self_attn_weights, present_key_value = self.self_attn(
            latents=latents,
            context=context,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        latents = residual + latents
        residual = latents

        latents = self.post_attention_layernorm(latents)
        latents = self.mlp(latents)
        latents = residual + latents

        outputs = (latents,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class PerceiverResampler(nn.Module):
    """Perceiver Resampler that compresses input embeddings into a fixed number of latents."""

    def __init__(self, connector_config) -> None:
        super().__init__()
        cfg = connector_config
        self.hidden_size = cfg.text_hidden_size
        self.hidden_act = cfg.hidden_act
        self.n_latents = cfg.num_output_tokens
        self.depth = cfg.resampler_depth

        # Learned latent queries shared across the batch (zero-initialized here).
        self.latents = nn.Parameter(torch.zeros(self.n_latents, self.hidden_size))

        # `depth` stacked Perceiver blocks, followed by a final norm.
        self.layers = nn.ModuleList(
            [PerceiverLayer(cfg, layer_idx) for layer_idx in range(self.depth)]
        )
        self.norm = WeightedNorm(self.hidden_size)
        # Stored as a flag only; forward() does not branch on it.
        self._use_flash_attention_2 = cfg._attn_implementation == "flash_attention_2"

    def forward(
            self,
            context: torch.Tensor,
            attention_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Compress `context` into ``n_latents`` vectors per batch element.

        Args:
            context: tensor of shape [bsz, seq, text_hidden_size].
            attention_mask: accepted for interface compatibility but currently ignored;
                no masking is applied to the context.

        Returns:
            Tensor of shape [bsz, n_latents, text_hidden_size].
        """
        batch_size = context.shape[0]
        # Broadcast the shared latents over the batch as a view (no copy).
        hidden = self.latents[None].expand(batch_size, -1, -1)
        for layer in self.layers:
            hidden = layer(
                hidden,
                context,
                past_key_value=None,
                output_attentions=False,
                use_cache=False,
            )[0]
        return self.norm(hidden)


def build_mm_projector(
    input_dim,
    output_dim,
    projector_type,
    hidden_act='silu',
    delay_load=False,
    token_input_shape=0,
    **kwargs,
) -> nn.Sequential:
    """Build the vision-to-text projector.

    A ``projector_type`` matching ``.*mlp<N>x_gelu`` yields N Linear layers separated by
    GELU: the first maps ``input_dim -> output_dim`` and the rest map
    ``output_dim -> output_dim``. Any other ``projector_type`` yields a single Linear.
    ``hidden_act``, ``delay_load``, ``token_input_shape`` and extra kwargs are accepted
    but not used.
    """
    layers = [nn.Linear(input_dim, output_dim)]
    depth_match = re.match(r'.*mlp(\d+)x_gelu$', projector_type)
    n_linear = int(depth_match.group(1)) if depth_match is not None else 1
    for _ in range(1, n_linear):
        layers += [nn.GELU(), nn.Linear(output_dim, output_dim)]
    return nn.Sequential(*layers)


class MMConnector(PreTrainedModel):
    """Vision-to-text connector: projects vision features, then resamples them to a fixed token count.

    ``proj`` maps ``config.vision_hidden_size`` to ``config.text_hidden_size``. ``resampler``
    compresses the projected sequence into ``config.num_output_tokens`` latent vectors.
    """

    config_class = PretrainedConfig

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config)
        self.proj = build_mm_projector(
            config.vision_hidden_size,
            config.text_hidden_size,
            config.projector_type,
            token_input_shape=config.token_input_shape,
        )
        self.resampler = PerceiverResampler(config)

    def forward(self, x):
        projected = self.proj(x)
        return self.resampler(projected)