from typing import Tuple
import torch
from torch import nn, Tensor
from torch.nn import functional as F


class EncoderBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # 5x5 conv, stride 2; TensorFlow-style "SAME" padding is applied
        # manually in forward
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=2)
        self.bn = nn.BatchNorm2d(
            num_features=out_channels,
            track_running_stats=True,
            eps=1e-3,
            momentum=0.01,
        )
        self.relu = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, input: Tensor) -> Tuple[Tensor, Tensor]:
        # asymmetric zero padding (left=1, right=2, top=1, bottom=2) so the
        # stride-2 conv halves both spatial dims exactly
        down = self.conv(F.pad(input, (1, 2, 1, 2), "constant", 0))
        # return the pre-activation output (kept for skip connections) alongside
        # the activated output fed to the next encoder layer
        return down, self.relu(self.bn(down))
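
# Shape sketch (illustrative; the input size is an assumption): with kernel 5,
# stride 2, and asymmetric (1, 2) padding, an EncoderBlock halves both spatial
# dimensions, e.g. (B, 1, 512, 128) -> (B, 16, 256, 64),
# since (512 + 1 + 2 - 5) // 2 + 1 = 256.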


class DecoderBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, dropout_prob: float = 0.0
    ) -> None:
        super().__init__()
        # 5x5 transposed conv, stride 2; the output is cropped in forward to
        # undo the encoder's asymmetric padding
        self.tconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size=5, stride=2
        )
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(
            out_channels, track_running_stats=True, eps=1e-3, momentum=0.01
        )
        self.dropout = nn.Dropout(dropout_prob) if dropout_prob > 0 else nn.Identity()

    def forward(self, input: Tensor) -> Tensor:
        up = self.tconv(input)
        # crop to undo the encoder's asymmetric padding: dim 2 is height
        # (top/bottom), dim 3 is width (left/right)
        t, b, l, r = 1, 2, 1, 2
        up = up[:, :, t:-b, l:-r]
        # note that activation is applied before batch norm in this decoder
        return self.dropout(self.bn(self.relu(up)))
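
# Shape sketch (illustrative): the stride-2 transposed conv plus the (1, 2)
# crop exactly inverts an EncoderBlock, e.g. (B, 512, 8, 2) -> (B, 256, 16, 4),
# since (8 - 1) * 2 + 5 = 19 and 19 - 1 - 2 = 16.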


class UNet(nn.Module):
    def __init__(
        self,
        n_layers: int = 6,
        in_channels: int = 1,
    ) -> None:
        super().__init__()

        # Encoder (downsampling) path: channel counts grow as
        # in_channels -> 16 -> 32 -> ... -> 2 ** (n_layers + 3)
        down_set = [in_channels] + [2 ** (i + 4) for i in range(n_layers)]
        self.encoder_layers = nn.ModuleList(
            [
                EncoderBlock(in_channels=in_ch, out_channels=out_ch)
                for in_ch, out_ch in zip(down_set[:-1], down_set[1:])
            ]
        )

        # Decoder (upsampling) path: mirror of the encoder channel counts,
        # ending in a single channel
        up_set = [1] + [2 ** (i + 4) for i in range(n_layers)]
        up_set.reverse()
        self.decoder_layers = nn.ModuleList(
            [
                DecoderBlock(
                    # input channels are doubled (except at the bottleneck) to
                    # accommodate the concatenated skip connection
                    in_channels=in_ch if i == 0 else in_ch * 2,
                    out_channels=out_ch,
                    # 50% dropout on the first three decoder layers only
                    dropout_prob=0.5 if i < 3 else 0.0,
                )
                for i, (in_ch, out_ch) in enumerate(zip(up_set[:-1], up_set[1:]))
            ]
        )

        # final conv maps the single-channel decoder output back to the input
        # channel count; kernel 4 with dilation 2 and padding 3 preserves the
        # spatial size
        self.up_final = nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input: Tensor) -> Tensor:
        # encoder pass: keep each pre-activation output for the skip connections
        encoder_outputs_pre_act = []
        x = input
        for down in self.encoder_layers:
            conv, x = down(x)
            encoder_outputs_pre_act.append(conv)

        for i, up in enumerate(self.decoder_layers):
            if i == 0:
                # bottleneck output: no skip connection to merge
                x = up(encoder_outputs_pre_act.pop())
            else:
                # merge the matching encoder output via the skip connection
                x = up(torch.cat([encoder_outputs_pre_act.pop(), x], dim=1))
        # sigmoid produces a soft mask in [0, 1], applied multiplicatively to
        # the input
        mask = self.sigmoid(self.up_final(x))
        return mask * input
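

if __name__ == "__main__":
    # Minimal smoke test (a sketch; the 512x128 input size is an assumption,
    # any input whose spatial dims are divisible by 2 ** n_layers should work).
    model = UNet(n_layers=6, in_channels=1).eval()
    with torch.no_grad():
        spec = torch.randn(1, 1, 512, 128)
        out = model(spec)
    # the mask is applied multiplicatively, so output shape matches the input
    assert out.shape == spec.shape
    print(out.shape)  # torch.Size([1, 1, 512, 128])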