import torch
import torch.nn as nn


class Conv2d(nn.Module):
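    """Conv -> BatchNorm -> ReLU block with an optional identity shortcut.

    When ``residual=True`` the input is added to the block output before the
    ReLU, so the block must preserve the input shape (stride=1, matching
    in/out channels, 'same'-style padding).
    """
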
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, padding_mode='zeros', bias=True, residual=False):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                      padding, padding_mode=padding_mode, bias=bias),
            nn.BatchNorm2d(out_channels)
        )
        self.residual = residual
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        out = self.act(out)
        return out


class ResnetBlock(nn.Module):
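    """Single-conv residual block: 3x3 conv -> norm, identity skip, then ReLU.

    Only 'reflect' and 'zeros' padding modes are accepted here.
    """
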
    def __init__(self, channel, padding_mode, norm_layer=nn.BatchNorm2d, bias=False):
        super().__init__()
        if padding_mode not in ['reflect', 'zeros']:
            raise NotImplementedError(f"padding_mode '{padding_mode}' is not supported!")

        self.block = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size=3, padding=1, padding_mode=padding_mode, bias=bias),
            norm_layer(channel)
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.block(x)
        out = out + x
        out = self.act(out)
        return out


class ResidualBlocks(nn.Module):
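    """A stack of ``n_blocks`` ResnetBlocks with reflect padding."""
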
    def __init__(self, channel, n_blocks=6):
        super().__init__()
        model = []
        for _ in range(n_blocks):  # stack n_blocks ResNet blocks
            model.append(ResnetBlock(channel, padding_mode='reflect'))

        self.module = nn.Sequential(*model)

    def forward(self, x):
        return self.module(x)


class SelfAttentionBlock(nn.Module):
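    """Self-attention over spatial positions (as in SAGAN).

    Query and key are projected to ``in_dim // 8`` channels (C' in the shape
    comments below); the attended output is blended with the input through
    the learnable scalar ``gamma``, initialised to zero.
    """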

    def __init__(self, in_dim):
        super().__init__()
        self.feature_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.feature_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.feature_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        B, C, H, W = x.size()
        _query = self.query_conv(x).view(B, -1, H * W).permute(0, 2, 1)  # B x (H*W) x C'
        _key = self.key_conv(x).view(B, -1, H * W)  # B x C' x (H*W)
        attn_matrix = torch.bmm(_query, _key)  # B x (H*W) x (H*W)
        attention = self.softmax(attn_matrix)  # softmax over the key positions
        _value = self.value_conv(x).view(B, -1, H * W)  # B x C x (H*W)

        out = torch.bmm(_value, attention.permute(0, 2, 1))  # B x C x (H*W)
        out = out.view(B, C, H, W)

        out = self.gamma * out + x
        return out


class ContextAwareAttentionBlock(nn.Module):
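    """Context-aware attention: pools a feature map into a single vector.

    Self-attention first refines the map; a learned context vector then
    scores every spatial position, and the softmax-weighted sum of the input
    features over H x W yields one C-dimensional descriptor per sample.
    """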

    def __init__(self, in_channels, hidden_dim=128):
        super().__init__()
        self.self_attn = SelfAttentionBlock(in_channels)
        self.fc = nn.Linear(in_channels, hidden_dim)
        self.context_vector = nn.Linear(hidden_dim, 1, bias=False)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, style_features):
        B, C, H, W = style_features.size()
        h = self.self_attn(style_features)
        h = h.permute(0, 2, 3, 1).reshape(-1, C)  # (B*H*W) x C
        h = torch.tanh(self.fc(h))  # (B*H*W) x hidden_dim
        h = self.context_vector(h)  # (B*H*W) x 1
        attention_score = self.softmax(h.view(B, H * W)).view(B, 1, H, W)  # B x 1 x H x W
        return torch.sum(style_features * attention_score, dim=[2, 3])  # B x C


class LayerAttentionBlock(nn.Module):
    """from FTransGAN
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels
        self.width_feat = 4
        self.height_feat = 4
        self.fc = nn.Linear(self.in_channels * self.width_feat * self.height_feat, 3)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, style_features, style_features_1, style_features_2, style_features_3, B, K):
        style_features = torch.mean(
            style_features.view(B, K, self.in_channels, self.height_feat, self.width_feat), dim=1)
        style_features = style_features.view(B, -1)
        weight = self.softmax(self.fc(style_features))

        style_features_1 = torch.mean(style_features_1.view(B, K, self.in_channels), dim=1)
        style_features_2 = torch.mean(style_features_2.view(B, K, self.in_channels), dim=1)
        style_features_3 = torch.mean(style_features_3.view(B, K, self.in_channels), dim=1)

        style_features = (style_features_1 * weight.narrow(1, 0, 1) +
                          style_features_2 * weight.narrow(1, 1, 1) +
                          style_features_3 * weight.narrow(1, 2, 1))
        style_features = style_features.view(B, self.in_channels, 1, 1)
        return style_features


class StyleAttentionBlock(nn.Module):
    """from FTransGAN
    """

    def __init__(self, in_channels):
        super().__init__()
        self.num_local_attention = 3
        for module_idx in range(1, self.num_local_attention + 1):
            self.add_module(f"local_attention_{module_idx}",
                            ContextAwareAttentionBlock(in_channels))

        for module_idx in range(1, self.num_local_attention):
            self.add_module(f"downsample_{module_idx}",
                            Conv2d(in_channels, in_channels,
                                   kernel_size=3, stride=2, padding=1, bias=False))

        self.add_module(f"layer_attention", LayerAttentionBlock(in_channels))

    def forward(self, x, B, K):
        feature_1 = self.local_attention_1(x)

        x = self.downsample_1(x)
        feature_2 = self.local_attention_2(x)

        x = self.downsample_2(x)
        feature_3 = self.local_attention_3(x)

        out = self.layer_attention(x, feature_1, feature_2, feature_3, B, K)

        return out
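

if __name__ == "__main__":
    # Minimal smoke test (illustrative values, not from the original repo).
    # Any H = W = 16 input works here: the two stride-2 downsamples reduce
    # 16 -> 8 -> 4, matching the 4 x 4 grid LayerAttentionBlock hard-codes.
    B, K, C = 2, 3, 64  # batch size, style references per item, channels
    feats = torch.randn(B * K, C, 16, 16)
    block = StyleAttentionBlock(C)
    out = block(feats, B, K)
    print(out.shape)  # expected: torch.Size([2, 64, 1, 1])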