import functools

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras import layers

from .blocks.attentions import SAM
from .blocks.bottleneck import BottleneckBlock
from .blocks.misc_gating import CrossGatingBlock
from .blocks.others import UpSampleRatio
from .blocks.unet import UNetDecoderBlock, UNetEncoderBlock
from .layers import Resizing

# Convolution-layer shorthands. Each name is a `functools.partial` over a Keras
# layer class with kernel size / stride / padding pre-bound; callers still pass
# `filters`, `use_bias`, `name`, etc. when instantiating the layer.
Conv1x1 = functools.partial(
    layers.Conv2D,
    padding="same",
    kernel_size=(1, 1),
)
Conv3x3 = functools.partial(
    layers.Conv2D,
    padding="same",
    kernel_size=(3, 3),
)
# 2x2 transposed convolution with stride 2 (learned upsampling).
ConvT_up = functools.partial(
    layers.Conv2DTranspose,
    padding="same",
    strides=(2, 2),
    kernel_size=(2, 2),
)
# 4x4 convolution with stride 2 (learned downsampling).
Conv_down = functools.partial(
    layers.Conv2D,
    padding="same",
    strides=(2, 2),
    kernel_size=(4, 4),
)


def MAXIM(
    features: int = 64,
    depth: int = 3,
    num_stages: int = 2,
    num_groups: int = 1,
    use_bias: bool = True,
    num_supervision_scales: int = 1,
    lrelu_slope: float = 0.2,
    use_global_mlp: bool = True,
    use_cross_gating: bool = True,
    high_res_stages: int = 2,
    block_size_hr=(16, 16),
    block_size_lr=(8, 8),
    grid_size_hr=(16, 16),
    grid_size_lr=(8, 8),
    num_bottleneck_blocks: int = 1,
    block_gmlp_factor: int = 2,
    grid_gmlp_factor: int = 2,
    input_proj_factor: int = 2,
    channels_reduction: int = 4,
    num_outputs: int = 3,
    dropout_rate: float = 0.0,
):
    """Build the MAXIM model function with multi-stage and multi-scale supervision.

    For more model details, please check the CVPR paper:
    MAXIM: Multi-Axis MLP for Image Processing (https://arxiv.org/abs/2201.02973)

    Args:
      features: initial hidden dimension for the input resolution.
      depth: the number of downsampling levels of the encoder/decoder.
      num_stages: how many stages to use. It also determines the length of the
        output list.
      num_groups: how many blocks each stage contains.
      use_bias: whether to use bias in all the conv/mlp layers.
      num_supervision_scales: the number of desired supervision scales.
      lrelu_slope: the negative slope parameter in leaky_relu layers.
      use_global_mlp: whether to use the multi-axis gated MLP block (MAB) in each
        layer.
      use_cross_gating: whether to use the cross-gating MLP block (CGB) in the
        skip connections and multi-stage feature fusion layers.
      high_res_stages: how many levels are treated as high-res: level `i` is
        high-res when `i < high_res_stages`. The rest
        (depth - high_res_stages) are low-res.
      block_size_hr: the block_size parameter for high-res levels.
      block_size_lr: the block_size parameter for low-res levels and the
        bottleneck.
      grid_size_hr: the grid_size parameter for high-res levels.
      grid_size_lr: the grid_size parameter for low-res levels and the
        bottleneck.
      num_bottleneck_blocks: how many bottleneck blocks.
      block_gmlp_factor: the input projection factor for block_gMLP layers.
      grid_gmlp_factor: the input projection factor for grid_gMLP layers.
      input_proj_factor: the input projection factor for the MAB block.
      channels_reduction: the channel reduction factor for SE layer.
      num_outputs: the output channels.
      dropout_rate: Dropout rate.

    Returns:
      A function `apply(x)` that wires the model graph onto the input tensor `x`
      and returns a list of multi-stage, multi-scale outputs. For example, if
      num_stages = num_supervision_scales = 3 (the model used in the paper),
      the output specs are: outputs =
      [[output_stage1_scale1, output_stage1_scale2, output_stage1_scale3],
       [output_stage2_scale1, output_stage2_scale2, output_stage2_scale3],
       [output_stage3_scale1, output_stage3_scale2, output_stage3_scale3],]
      Within a stage, scales go from the smallest resolution to the full input
      resolution, so the final output can be retrieved by outputs[-1][-1].
    """

    def _level_sizes(i):
        """Return the (block_size, grid_size) pair for level / scale index `i`.

        Levels below `high_res_stages` use the high-res (larger) window sizes,
        all deeper levels use the low-res ones.
        """
        if i < high_res_stages:
            return block_size_hr, grid_size_hr
        return block_size_lr, grid_size_lr

    def apply(x):
        """Wire the multi-stage MAXIM graph onto `x` and return all outputs.

        `x` is expected to be a 4-D tensor whose dims 1 and 2 are height and
        width. These must be statically known when num_supervision_scales > 1,
        because they size the input image pyramid.
        """
        input_shape = K.int_shape(x)
        h, w = input_shape[1], input_shape[2]  # input image height / width

        # shortcuts[i] is the input image downscaled by 2**i.
        shortcuts = [x]

        # Get multi-scale input images
        for i in range(1, num_supervision_scales):
            resizing_layer = Resizing(
                height=h // (2 ** i),
                width=w // (2 ** i),
                method="nearest",
                antialias=True,  # Following `jax.image.resize()`.
                name=f"initial_resizing_{K.get_uid('Resizing')}",
            )
            shortcuts.append(resizing_layer(x))

        # store outputs from all stages and all scales
        # Eg, for a 256x256 input with num_supervision_scales = 3:
        #     [[(64, 64, 3), (128, 128, 3), (256, 256, 3)],   # Stage-1 outputs
        #      [(64, 64, 3), (128, 128, 3), (256, 256, 3)],]  # Stage-2 outputs
        outputs_all = []
        sam_features, encs_prev, decs_prev = [], [], []

        for idx_stage in range(num_stages):
            # Input convolution, get multi-scale input features
            x_scales = []
            for i in range(num_supervision_scales):
                x_scale = Conv3x3(
                    filters=(2 ** i) * features,
                    use_bias=use_bias,
                    name=f"stage_{idx_stage}_input_conv_{i}",
                )(shortcuts[i])

                # If later stages, fuse input features with SAM features from
                # prev stage. The previous decoder appended SAM features from the
                # coarsest to the finest scale, so pop() yields scale 0 first,
                # matching i = 0, 1, ...
                if idx_stage > 0:
                    if use_cross_gating:
                        # use larger blocksize at high-res stages
                        block_size, grid_size = _level_sizes(i)
                        x_scale, _ = CrossGatingBlock(
                            features=(2 ** i) * features,
                            block_size=block_size,
                            grid_size=grid_size,
                            dropout_rate=dropout_rate,
                            input_proj_factor=input_proj_factor,
                            upsample_y=False,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_input_fuse_sam_{i}",
                        )(x_scale, sam_features.pop())
                    else:
                        x_scale = Conv1x1(
                            filters=(2 ** i) * features,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_input_catconv_{i}",
                        )(tf.concat([x_scale, sam_features.pop()], axis=-1))

                x_scales.append(x_scale)

            # start encoder blocks
            encs = []
            x = x_scales[0]  # First full-scale input feature

            for i in range(depth):  # 0, 1, 2
                # use larger blocksize at high-res stages, vice versa.
                block_size, grid_size = _level_sizes(i)
                # Only later stages have previous encoder/decoder features to
                # cross-gate with.
                use_cross_gating_layer = idx_stage > 0

                # Multi-scale input if multi-scale supervision
                x_scale = x_scales[i] if i < num_supervision_scales else None

                # UNet Encoder block. encs_prev is stored reversed and decs_prev
                # coarse-to-fine, so pop() yields the level-i features in order.
                enc_prev = encs_prev.pop() if idx_stage > 0 else None
                dec_prev = decs_prev.pop() if idx_stage > 0 else None

                x, bridge = UNetEncoderBlock(
                    num_channels=(2 ** i) * features,
                    num_groups=num_groups,
                    downsample=True,
                    lrelu_slope=lrelu_slope,
                    block_size=block_size,
                    grid_size=grid_size,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    channels_reduction=channels_reduction,
                    use_global_mlp=use_global_mlp,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    use_cross_gating=use_cross_gating_layer,
                    name=f"stage_{idx_stage}_encoder_block_{i}",
                )(x, skip=x_scale, enc=enc_prev, dec=dec_prev)

                # Cache skip signals
                encs.append(bridge)

            # Global MLP bottleneck blocks (always at the low-res setting)
            for i in range(num_bottleneck_blocks):
                x = BottleneckBlock(
                    block_size=block_size_lr,
                    grid_size=grid_size_lr,
                    features=(2 ** (depth - 1)) * features,
                    num_groups=num_groups,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    channels_reduction=channels_reduction,
                    name=f"stage_{idx_stage}_global_block_{i}",
                )(x)
            # cache global feature for cross-gating
            global_feature = x

            # start cross gating. Use multi-scale feature fusion
            skip_features = []
            for i in reversed(range(depth)):  # 2, 1, 0
                # use larger blocksize at high-res stages
                block_size, grid_size = _level_sizes(i)

                # get additional multi-scale signals: encoder bridge j is
                # rescaled by 2**(j - i) towards level i (assumes UpSampleRatio
                # rescales spatial dims by `ratio`).
                signal = tf.concat(
                    [
                        UpSampleRatio(
                            num_channels=(2 ** i) * features,
                            ratio=2 ** (j - i),
                            use_bias=use_bias,
                            name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
                        )(enc)
                        for j, enc in enumerate(encs)
                    ],
                    axis=-1,
                )

                # Use cross-gating to cross modulate features
                if use_cross_gating:
                    skips, global_feature = CrossGatingBlock(
                        features=(2 ** i) * features,
                        block_size=block_size,
                        grid_size=grid_size,
                        input_proj_factor=input_proj_factor,
                        dropout_rate=dropout_rate,
                        upsample_y=True,
                        use_bias=use_bias,
                        name=f"stage_{idx_stage}_cross_gating_block_{i}",
                    )(signal, global_feature)
                else:
                    # Names must be unique per stage/level: Keras functional
                    # models reject duplicate layer names.
                    skips = Conv1x1(
                        filters=(2 ** i) * features,
                        use_bias=use_bias,
                        name=f"stage_{idx_stage}_skip_conv1x1_{i}",
                    )(signal)
                    skips = Conv3x3(
                        filters=(2 ** i) * features,
                        use_bias=use_bias,
                        name=f"stage_{idx_stage}_skip_conv3x3_{i}",
                    )(skips)

                skip_features.append(skips)

            # start decoder. Multi-scale feature fusion of cross-gated features
            outputs, decs, sam_features = [], [], []
            for i in reversed(range(depth)):
                # use larger blocksize at high-res stages
                block_size, grid_size = _level_sizes(i)

                # get multi-scale skip signals from cross-gating block.
                # skip_features[j] belongs to level depth - 1 - j (it was filled
                # deepest-first above), hence the ratio below.
                signal = tf.concat(
                    [
                        UpSampleRatio(
                            num_channels=(2 ** i) * features,
                            ratio=2 ** (depth - j - 1 - i),
                            use_bias=use_bias,
                            name=f"UpSampleRatio_{K.get_uid('UpSampleRatio')}",
                        )(skip)
                        for j, skip in enumerate(skip_features)
                    ],
                    axis=-1,
                )

                # Decoder block
                x = UNetDecoderBlock(
                    num_channels=(2 ** i) * features,
                    num_groups=num_groups,
                    lrelu_slope=lrelu_slope,
                    block_size=block_size,
                    grid_size=grid_size,
                    block_gmlp_factor=block_gmlp_factor,
                    grid_gmlp_factor=grid_gmlp_factor,
                    input_proj_factor=input_proj_factor,
                    channels_reduction=channels_reduction,
                    use_global_mlp=use_global_mlp,
                    dropout_rate=dropout_rate,
                    use_bias=use_bias,
                    name=f"stage_{idx_stage}_decoder_block_{i}",
                )(x, bridge=signal)

                # Cache decoder features for later-stage's usage
                decs.append(x)

                # output conv, if not final stage, use supervised-attention-block.
                if i < num_supervision_scales:
                    if idx_stage < num_stages - 1:  # not last stage, apply SAM
                        sam, output = SAM(
                            num_channels=(2 ** i) * features,
                            output_channels=num_outputs,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_supervised_attention_module_{i}",
                        )(x, shortcuts[i])
                        outputs.append(output)
                        sam_features.append(sam)
                    else:  # Last stage, apply output convolutions
                        output = Conv3x3(
                            num_outputs,
                            use_bias=use_bias,
                            name=f"stage_{idx_stage}_output_conv_{i}",
                        )(x)
                        # Residual connection to the matching-scale input image.
                        output = output + shortcuts[i]
                        outputs.append(output)
            # Cache encoder and decoder features for later-stage's usage
            encs_prev = encs[::-1]
            decs_prev = decs

            # Store outputs
            outputs_all.append(outputs)
        return outputs_all

    return apply