import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Tuple, Literal
from functools import partial

from vit.vision_transformer import MemEffAttention


class MVAttention(nn.Module):
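    """Multi-view self-attention.

    Treats the batch as B instances of `num_frames` views each and attends
    over the tokens of all views jointly, so features are exchanged across
    the views of an instance.
    """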

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            proj_bias: bool = True,
            attn_drop: float = 0.0,
            proj_drop: float = 0.0,
            groups: int = 32,
            eps: float = 1e-5,
            residual: bool = True,
            skip_scale: float = 1,
            num_frames: int = 4,  # WARN: hardcoded default; must match the number of views per instance
    ):
        super().__init__()

        self.residual = residual
        self.skip_scale = skip_scale
        self.num_frames = num_frames

        self.norm = nn.GroupNorm(num_groups=groups,
                                 num_channels=dim,
                                 eps=eps,
                                 affine=True)
        self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias,
                                    attn_drop, proj_drop)

    def forward(self, x):
        # x: [B*V, C, H, W], with the V views stacked along the batch dimension
        BV, C, H, W = x.shape
        assert BV % self.num_frames == 0, 'batch must be a multiple of num_frames'
        B = BV // self.num_frames

        res = x
        x = self.norm(x)

        # [B*V, C, H, W] -> [B, V*H*W, C]: one token sequence over all views
        x = x.reshape(B, self.num_frames, C, H,
                      W).permute(0, 1, 3, 4, 2).reshape(B, -1, C)
        x = self.attn(x)
        # [B, V*H*W, C] -> [B*V, C, H, W]
        x = x.reshape(B, self.num_frames, H, W,
                      C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W)

        if self.residual:
            x = (x + res) * self.skip_scale
        return x


class ResnetBlock(nn.Module):
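    """Residual block: GroupNorm -> SiLU -> Conv, twice, with an optional
    nearest-neighbour upsample or average-pool downsample applied to both
    the main path and the residual."""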

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            resample: Literal['default', 'up', 'down'] = 'default',
            groups: int = 32,
            eps: float = 1e-5,
            skip_scale: float = 1,  # multiplied onto the residual sum
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip_scale = skip_scale

        self.norm1 = nn.GroupNorm(num_groups=groups,
                                  num_channels=in_channels,
                                  eps=eps,
                                  affine=True)
        self.conv1 = nn.Conv2d(in_channels,
                               out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1)

        self.norm2 = nn.GroupNorm(num_groups=groups,
                                  num_channels=out_channels,
                                  eps=eps,
                                  affine=True)
        self.conv2 = nn.Conv2d(out_channels,
                               out_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1)

        self.act = F.silu

        self.resample = None
        if resample == 'up':
            self.resample = partial(F.interpolate,
                                    scale_factor=2.0,
                                    mode="nearest")
        elif resample == 'down':
            self.resample = nn.AvgPool2d(kernel_size=2, stride=2)

        self.shortcut = nn.Identity()
        if self.in_channels != self.out_channels:
            self.shortcut = nn.Conv2d(in_channels,
                                      out_channels,
                                      kernel_size=1,
                                      bias=True)

    def forward(self, x):
        res = x

        x = self.norm1(x)
        x = self.act(x)

        if self.resample:
            res = self.resample(res)
            x = self.resample(x)

        x = self.conv1(x)
        x = self.norm2(x)
        x = self.act(x)
        x = self.conv2(x)

        x = (x + self.shortcut(res)) * self.skip_scale

        return x


class DownBlock(nn.Module):
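    """Down-path stage: `num_layers` ResnetBlocks (each optionally followed
    by MVAttention) plus an optional strided-conv downsample. Returns the
    stage output together with every intermediate feature, which the up path
    consumes as skip connections."""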

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        downsample: bool = True,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        super().__init__()

        nets = []
        attns = []
        for i in range(num_layers):
            cin = in_channels if i == 0 else out_channels
            nets.append(
                ResnetBlock(cin, out_channels, skip_scale=skip_scale))
            if attention:
                attns.append(
                    MVAttention(out_channels,
                                attention_heads,
                                skip_scale=skip_scale))
            else:
                attns.append(None)
        self.nets = nn.ModuleList(nets)
        self.attns = nn.ModuleList(attns)

        self.downsample = None
        if downsample:
            self.downsample = nn.Conv2d(out_channels,
                                        out_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=1)

    def forward(self, x):
        xs = []

        for attn, net in zip(self.attns, self.nets):
            x = net(x)
            if attn:
                x = attn(x)
            xs.append(x)

        if self.downsample:
            x = self.downsample(x)
            xs.append(x)

        return x, xs


class MidBlock(nn.Module):
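    """Bottleneck stage: one ResnetBlock followed by `num_layers` pairs of
    (optional MVAttention, ResnetBlock), all at constant channel width."""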

    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        super().__init__()

        nets = []
        attns = []
        # first layer
        nets.append(
            ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
        # more layers
        for i in range(num_layers):
            nets.append(
                ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
            if attention:
                attns.append(
                    MVAttention(in_channels,
                                attention_heads,
                                skip_scale=skip_scale))
            else:
                attns.append(None)
        self.nets = nn.ModuleList(nets)
        self.attns = nn.ModuleList(attns)

    def forward(self, x):
        x = self.nets[0](x)
        for attn, net in zip(self.attns, self.nets[1:]):
            if attn:
                x = attn(x)
            x = net(x)
        return x


class UpBlock(nn.Module):
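    """Up-path stage: `num_layers` ResnetBlocks that each consume one skip
    feature from the down path, plus an optional nearest-upsample + conv."""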

    def __init__(
        self,
        in_channels: int,
        prev_out_channels: int,
        out_channels: int,
        num_layers: int = 1,
        upsample: bool = True,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
    ):
        super().__init__()

        nets = []
        attns = []
        for i in range(num_layers):
            cin = in_channels if i == 0 else out_channels
            # the last layer consumes the skip from the previous down stage
            cskip = prev_out_channels if i == num_layers - 1 else out_channels

            nets.append(
                ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale))
            if attention:
                attns.append(
                    MVAttention(out_channels,
                                attention_heads,
                                skip_scale=skip_scale))
            else:
                attns.append(None)
        self.nets = nn.ModuleList(nets)
        self.attns = nn.ModuleList(attns)

        self.upsample = None
        if upsample:
            self.upsample = nn.Conv2d(out_channels,
                                      out_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1)

    def forward(self, x, xs):
        # consume the skip features in reverse order (deepest first)
        for attn, net in zip(self.attns, self.nets):
            res_x = xs[-1]
            xs = xs[:-1]
            x = torch.cat([x, res_x], dim=1)
            x = net(x)
            if attn:
                x = attn(x)

        if self.upsample:
            x = F.interpolate(x, scale_factor=2.0, mode='nearest')
            x = self.upsample(x)

        return x


class MVUNet(nn.Module):
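    """Multi-view UNet built from the blocks above. The down and up paths
    may be asymmetric: with the defaults (four downsamples, two upsamples)
    the output is 1/4 of the input resolution."""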

    def __init__(
            self,
            in_channels: int = 3,
            out_channels: int = 3,
            down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
            down_attention: Tuple[bool,
                                  ...] = (False, False, False, True, True),
            mid_attention: bool = True,
            up_channels: Tuple[int, ...] = (1024, 512, 256),
            up_attention: Tuple[bool, ...] = (True, True, False),
            layers_per_block: int = 2,
            skip_scale: float = np.sqrt(0.5),
    ):
        super().__init__()

        # first
        self.conv_in = nn.Conv2d(in_channels,
                                 down_channels[0],
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)

        # down
        down_blocks = []
        cout = down_channels[0]
        for i in range(len(down_channels)):
            cin = cout
            cout = down_channels[i]

            down_blocks.append(
                DownBlock(
                    cin,
                    cout,
                    num_layers=layers_per_block,
                    downsample=(i != len(down_channels) - 1),  # all but last
                    attention=down_attention[i],
                    skip_scale=skip_scale,
                ))
        self.down_blocks = nn.ModuleList(down_blocks)

        # mid
        self.mid_block = MidBlock(down_channels[-1],
                                  attention=mid_attention,
                                  skip_scale=skip_scale)

        # up
        up_blocks = []
        cout = up_channels[0]
        for i in range(len(up_channels)):
            cin = cout
            cout = up_channels[i]
            # index from the end so a shorter (asymmetric) up path still
            # pairs with the matching down-stage channels
            cskip = down_channels[max(-2 - i, -len(down_channels))]

            up_blocks.append(
                UpBlock(
                    cin,
                    cskip,
                    cout,
                    num_layers=layers_per_block + 1,  # one more layer for up
                    upsample=(i != len(up_channels) - 1),  # not final layer
                    attention=up_attention[i],
                    skip_scale=skip_scale,
                ))
        self.up_blocks = nn.ModuleList(up_blocks)

        # last
        self.norm_out = nn.GroupNorm(num_channels=up_channels[-1],
                                     num_groups=32,
                                     eps=1e-5)
        self.conv_out = nn.Conv2d(up_channels[-1],
                                  out_channels,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)

    def forward(self, x):
        # x: [B, Cin, H, W]

        # first
        x = self.conv_in(x)

        # down
        xss = [x]  # collect skip features, starting with the stem output
        for block in self.down_blocks:
            x, xs = block(x)
            xss.extend(xs)

        # mid
        x = self.mid_block(x)  # e.g. [B*V, 1024, 16, 16] for a 256x256 input

        # up
        for block in self.up_blocks:
            # hand each up block the skip features its layers will consume
            xs = xss[-len(block.nets):]
            xss = xss[:-len(block.nets)]
            x = block(x, xs)

        # last
        x = self.norm_out(x)
        x = F.silu(x)
        x = self.conv_out(x)  # [B, Cout, H', W']

        return x


class LGM_MVEncoder(MVUNet):
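    """Encoder-only variant of MVUNet: the up path is removed, the
    bottleneck features are projected to latent channels, and the
    `num_frames` views of each instance are fused into a single latent."""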

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
        down_attention: Tuple[bool, ...] = (False, False, False, True, True),
        mid_attention: bool = True,
        up_channels: Tuple[int, ...] = (1024, 512, 256),
        up_attention: Tuple[bool, ...] = (True, True, False),
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
        z_channels: int = 4,
        double_z: bool = True,
        add_fusion_layer: bool = True,
    ):
        super().__init__(in_channels, out_channels, down_channels,
                         down_attention, mid_attention, up_channels,
                         up_attention, layers_per_block, skip_scale)
        # encoder only: the decoding half of the UNet is not needed
        del self.up_blocks

        self.conv_out = nn.Conv2d(up_channels[0],
                                  2 * z_channels if double_z else z_channels,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)
        if add_fusion_layer:  # fuse the 4 per-view latents into one
            self.fusion_layer = nn.Conv2d(
                2 * z_channels * 4 if double_z else z_channels * 4,
                2 * z_channels if double_z else z_channels,
                kernel_size=3,
                stride=1,
                padding=1)

        self.num_frames = 4  # WARN: hardcoded; must match MVAttention's num_frames

    def forward(self, x):
        # first
        x = self.conv_in(x)

        # down (the skip features are unused in the encoder-only variant)
        for block in self.down_blocks:
            x, _ = block(x)

        # mid
        x = self.mid_block(x)  # e.g. [B*V, 1024, 16, 16] for a 256x256 input

        # multi-view aggregation (in the spirit of pixelNeRF): project to
        # latent channels, then fuse the views of each instance
        x = self.conv_out(x)  # [B*V, latent_channels, h, w]
        x = x.chunk(x.shape[0] // self.num_frames)  # B groups of num_frames views
        # x = [feat.max(keepdim=True, dim=0)[0] for feat in x]  # alternative: max pooling
        x = [
            self.fusion_layer(torch.cat(feat.chunk(feat.shape[0]), dim=1))
            for feat in x
        ]  # conv pooling: stack views along channels, fuse with a 3x3 conv
        return torch.cat(x, dim=0)  # [B, latent_channels, h, w]
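

if __name__ == '__main__':
    # Minimal smoke test: a sketch, not part of the original training code.
    # It assumes `vit.vision_transformer.MemEffAttention` is importable and
    # uses an assumed 256x256 input; all other settings are the defaults.
    views = torch.randn(4, 3, 256, 256)  # B=1 instance of num_frames=4 views

    unet = MVUNet()
    print(unet(views).shape)  # [4, 3, 64, 64]: 4 downsamples, only 2 upsamples

    encoder = LGM_MVEncoder()
    print(encoder(views).shape)  # [1, 8, 16, 16]: one fused 2*z_channels latent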