# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.general.utils import Conv1d, normalization, zero_module
from .basic import UNetBlock


class AttentionBlock(UNetBlock):
    r"""A spatial transformer encoder block that allows spatial positions to attend
    to each other. Reference from `latent diffusion repo
    <https://github.com/Stability-AI/generative-models/blob/main/sgm/modules/attention.py#L531>`_.

    Args:
        channels: Number of channels in the input.
        num_head_channels: Number of channels per attention head. If set to -1,
            it is derived from ``num_heads`` instead.
        num_heads: Number of attention heads. Only used when ``num_head_channels``
            is -1; otherwise derived as ``channels // num_head_channels``.
        encoder_channels: Number of channels in the encoder output for cross-attention.
            If ``None``, then self-attention is performed.
        use_self_attention: Whether to apply self-attention before the
            cross-attention; only used when ``encoder_channels`` is set.
        dims: Number of spatial dimensions: 1 for temporal signals, 2 for images.
        h_dim: Height of the input feature map; only used when ``dims`` is 2.
        encoder_hdim: Height of the encoder output feature map; only used when
            ``dims`` is 2.
        p_dropout: Dropout probability.
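
    Example:
        A minimal self-attention sketch (sizes are illustrative, not from the
        source)::

            block = AttentionBlock(channels=64, num_head_channels=32)
            x = torch.randn(2, 64, 100)  # [B, channels, length]
            y = block(x)  # same shape as x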
    """

    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        encoder_channels: Optional[int] = None,
        use_self_attention: bool = False,
        dims: int = 1,
        h_dim: int = 100,
        encoder_hdim: int = 384,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        self.p_dropout = p_dropout
        self.dims = dims

        if dims == 1:
            self.channels = channels
        elif dims == 2:
            # Fold the height into the channel dimension (C x H) so that 2-D
            # features (e.g. spectrograms) are attended over as 1-D sequences.
            self.channels = channels * h_dim
        else:
            raise ValueError(f"invalid number of dimensions: {dims}")

        if num_head_channels == -1:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if encoder_channels is not None:
            self.use_self_attention = use_self_attention

            if dims == 1:
                self.encoder_channels = encoder_channels
            elif dims == 2:
                self.encoder_channels = encoder_channels * encoder_hdim
            else:
                raise ValueError(f"invalid number of dimensions: {dims}")

            if use_self_attention:
                self.self_attention = BasicAttentionBlock(
                    self.channels,
                    self.num_head_channels,
                    self.num_heads,
                    p_dropout=self.p_dropout,
                )
            self.cross_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                self.encoder_channels,
                p_dropout=self.p_dropout,
            )
        else:
            self.encoder_channels = None
            self.use_self_attention = False
            self.self_attention = BasicAttentionBlock(
                self.channels,
                self.num_head_channels,
                self.num_heads,
                p_dropout=self.p_dropout,
            )

    def forward(self, x: torch.Tensor, encoder_output: Optional[torch.Tensor] = None):
        r"""
        Args:
            x: input tensor with shape [B x ``channels`` x ...]
            encoder_output: feature tensor with shape [B x ``encoder_channels`` x ...];
                if ``None``, self-attention is performed.

        Returns:
            output tensor with shape [B x ``channels`` x ...]
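
        Example:
            A cross-attention sketch with hypothetical sizes::

                block = AttentionBlock(channels=64, encoder_channels=80)
                x = torch.randn(2, 64, 100)
                cond = torch.randn(2, 80, 50)
                y = block(x, cond)  # same shape as x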
        """
        shape = x.size()
        x = x.reshape(shape[0], self.channels, -1).contiguous()

        if self.encoder_channels is None:
            assert (
                encoder_output is None
            ), "encoder_output must be None for self-attention."
            h = self.self_attention(x)

        else:
            assert (
                encoder_output is not None
            ), "encoder_output must be given for cross-attention."
            encoder_output = encoder_output.reshape(
                shape[0], self.encoder_channels, -1
            ).contiguous()

            if self.use_self_attention:
                x = self.self_attention(x)
            h = self.cross_attention(x, encoder_output)

        return h.reshape(*shape).contiguous()


class BasicAttentionBlock(nn.Module):
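    r"""Multi-head self-/cross-attention over ``[B, channels, L]`` sequences, with
    pre-normalization and two residual connections (one around the attention, one
    around the output projection). If ``context_channels`` is given, keys and
    values are computed from the context tensor (cross-attention); otherwise Q, K,
    and V all come from the input (self-attention).
    """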
    def __init__(
        self,
        channels: int,
        num_head_channels: int = 32,
        num_heads: int = -1,
        context_channels: Optional[int] = None,
        p_dropout: float = 0.0,
    ):
        super().__init__()

        self.channels = channels
        self.p_dropout = p_dropout
        self.context_channels = context_channels

        if num_head_channels == -1:
            assert (
                self.channels % num_heads == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_heads {num_heads}"
            self.num_heads = num_heads
            self.num_head_channels = self.channels // num_heads
        else:
            assert (
                self.channels % num_head_channels == 0
            ), f"q,k,v channels {self.channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = self.channels // num_head_channels
            self.num_head_channels = num_head_channels

        if context_channels is not None:
            self.to_q = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, self.channels, 1),
            )
            self.to_kv = Conv1d(context_channels, 2 * self.channels, 1)
        else:
            self.to_qkv = nn.Sequential(
                normalization(self.channels),
                Conv1d(self.channels, 3 * self.channels, 1),
            )

        # Pointwise projection applied to the attention output.
        self.linear = Conv1d(self.channels, self.channels, 1)

        self.proj_out = nn.Sequential(
            normalization(self.channels),
            Conv1d(self.channels, self.channels, 1),
            nn.GELU(),
            nn.Dropout(p=self.p_dropout),
            zero_module(Conv1d(self.channels, self.channels, 1)),
        )

    def forward(self, q: torch.Tensor, kv: Optional[torch.Tensor] = None):
        r"""
        Args:
            q: input tensor with shape [B, ``channels``, L]
            kv: feature tensor with shape [B, ``context_channels``, T];
                if ``None``, self-attention is performed.

        Returns:
            output tensor with shape [B, ``channels``, L]
        """
        N, C, L = q.size()
        residual = q  # keep the block input for the residual connections below

        if self.context_channels is not None:
            assert kv is not None, "kv must be given for cross-attention."

            # [B, C, L] -> [B, heads, L, head_channels]; keep the batch dimension
            # so that attention never mixes different batch items.
            q = (
                self.to_q(q)
                .reshape(N, self.num_heads, self.num_head_channels, L)
                .transpose(-1, -2)
                .contiguous()
            )
            # [B, 2C, T] -> 2 x [B, heads, T, head_channels]
            k, v = (
                self.to_kv(kv)
                .reshape(N, 2, self.num_heads, self.num_head_channels, -1)
                .transpose(-1, -2)
                .unbind(dim=1)
            )

        else:
            # [B, 3C, L] -> 3 x [B, heads, L, head_channels]
            q, k, v = (
                self.to_qkv(q)
                .reshape(N, 3, self.num_heads, self.num_head_channels, L)
                .transpose(-1, -2)
                .unbind(dim=1)
            )

        # Attention output: [B, heads, L, head_channels] -> [B, C, L]. Attention
        # dropout is only active in training mode.
        h = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.p_dropout if self.training else 0.0
        )
        h = h.transpose(-1, -2).reshape(N, -1, L).contiguous()
        h = self.linear(h)

        # First residual around the attention, second around the output projection.
        x = residual + h
        h = self.proj_out(x)

        return x + h
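

# A minimal shape-check sketch, not part of the original module. It assumes the
# ``modules.general.utils`` helpers resolve (e.g. when run from the repo root)
# and only verifies that both attention paths preserve the input shape; all
# sizes below are illustrative.
if __name__ == "__main__":
    x = torch.randn(2, 64, 100)  # [B, channels, length]

    # Self-attention path.
    self_attn = AttentionBlock(channels=64, num_head_channels=32)
    assert self_attn(x).shape == x.shape

    # Cross-attention path with a hypothetical 80-channel encoder output.
    cross_attn = AttentionBlock(channels=64, encoder_channels=80)
    cond = torch.randn(2, 80, 50)
    assert cross_attn(x, cond).shape == x.shape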