from typing import List, NoReturn, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def init_embedding(layer: nn.Module) -> None:
    r"""Initialize an embedding-style layer in place.

    ``layer.weight`` is filled from a uniform distribution on [-1, 1], and
    ``layer.bias`` is zeroed when the layer has one.

    Args:
        layer: module exposing a ``weight`` parameter and optionally a
            ``bias`` parameter (which may be ``None``).
    """
    nn.init.uniform_(layer.weight, -1.0, 1.0)

    # Layers without a ``bias`` attribute, or built with ``bias=None``,
    # are left untouched.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)


def init_layer(layer: nn.Module) -> None:
    r"""Initialize a Linear or Convolutional layer in place.

    ``layer.weight`` gets Xavier (Glorot) uniform initialization, and
    ``layer.bias`` is zeroed when the layer has one.

    Args:
        layer: module exposing a ``weight`` parameter and optionally a
            ``bias`` parameter (which may be ``None``).
    """
    nn.init.xavier_uniform_(layer.weight)

    # Skip layers that have no bias attribute or were built with bias=None.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)


def init_bn(bn: nn.Module) -> None:
    r"""Initialize a Batchnorm layer in place.

    Sets ``bias`` to 0, ``weight`` to 1, ``running_mean`` to 0 and
    ``running_var`` to 1. Any of these that is missing or ``None`` (e.g. a
    layer built with ``affine=False`` or ``track_running_stats=False``) is
    skipped instead of raising ``AttributeError``.

    Args:
        bn: batch-normalization module.
    """
    # Same fill order as before: bias, weight, running_mean, running_var.
    for name, value in (
        ("bias", 0.0),
        ("weight", 1.0),
        ("running_mean", 0.0),
        ("running_var", 1.0),
    ):
        tensor = getattr(bn, name, None)
        if tensor is not None:
            tensor.data.fill_(value)


def act(x: torch.Tensor, activation: str) -> torch.Tensor:
    r"""Apply a named activation function.

    Args:
        x: input tensor. NOTE: "relu" and "leaky_relu" modify ``x`` in place
            (``F.relu_`` / ``F.leaky_relu_``); "swish" returns a new tensor.
        activation: one of "relu", "leaky_relu" (negative slope 0.01), or
            "swish" (``x * sigmoid(x)``).

    Returns:
        output: activated tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``activation`` is not one of the names above.
            ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers still catch it.
    """
    if activation == "relu":
        return F.relu_(x)

    if activation == "leaky_relu":
        return F.leaky_relu_(x, negative_slope=0.01)

    if activation == "swish":
        return x * torch.sigmoid(x)

    raise ValueError(f"Incorrect activation! Got {activation!r}.")


class Base:
    r"""Mixin that derives magnitude / phase features from ``self.stft``.

    ``Base`` does not define ``stft`` itself: the class it is combined with
    must provide ``self.stft(x)`` returning a ``(real, imag)`` pair of
    tensors of equal shape.
    """

    def __init__(self):
        r"""Base function for extracting spectrogram, cos, and sin, etc."""
        pass

    def spectrogram(self, input: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
        r"""Calculate the magnitude spectrogram.

        Args:
            input: (n, segment_samples) waveforms, passed to ``self.stft``.
            eps: float, lower bound applied to the power
                ``real ** 2 + imag ** 2`` before the square root, so every
                magnitude is at least ``sqrt(eps)``.

        Returns:
            spectrogram: magnitude with the same shape as the real/imag
                tensors returned by ``self.stft`` (``wav_to_spectrogram_phase``
                expects this to be (n, 1, time_steps, freq_bins)).
        """
        real, imag = self.stft(input)
        return torch.clamp(real ** 2 + imag ** 2, min=eps) ** 0.5

    def spectrogram_phase(
        self, input: torch.Tensor, eps: float = 0.0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Calculate the magnitude, cos, and sin of the STFT of input.

        Args:
            input: (n, segment_samples) waveforms, passed to ``self.stft``.
            eps: float, lower bound on the power before the square root.
                NOTE: with the default ``eps=0.0`` a bin whose real and imag
                parts are both 0 has magnitude 0, so its cos/sin are 0/0 = NaN.
                Pass a positive ``eps`` to avoid this.

        Returns:
            mag: magnitude, same shape as the real/imag parts from ``self.stft``
            cos: real / mag, same shape
            sin: imag / mag, same shape
        """
        real, imag = self.stft(input)
        mag = torch.clamp(real ** 2 + imag ** 2, min=eps) ** 0.5
        cos = real / mag
        sin = imag / mag
        return mag, cos, sin

    def wav_to_spectrogram_phase(
        self, input: torch.Tensor, eps: float = 1e-10
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""Convert multichannel waveforms to magnitude, cos, and sin of STFT.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: float, forwarded to ``spectrogram_phase``; the positive
                default keeps cos/sin finite for silent bins.

        Returns:
            mag: (batch_size, channels_num, time_steps, freq_bins)
            cos: (batch_size, channels_num, time_steps, freq_bins)
            sin: (batch_size, channels_num, time_steps, freq_bins)
        """
        batch_size, channels_num, segment_samples = input.shape

        # Fold channels into the batch axis: ``self.stft`` is called on
        # 2-D (n, segment_samples) input.
        x = input.reshape(batch_size * channels_num, segment_samples)

        mag, cos, sin = self.spectrogram_phase(x, eps=eps)
        # mag, cos, sin: (batch_size * channels_num, 1, time_steps, freq_bins)

        _, _, time_steps, freq_bins = mag.shape
        out_shape = (batch_size, channels_num, time_steps, freq_bins)
        mag = mag.reshape(out_shape)
        cos = cos.reshape(out_shape)
        sin = sin.reshape(out_shape)

        return mag, cos, sin

    def wav_to_spectrogram(
        self, input: torch.Tensor, eps: float = 1e-10
    ) -> torch.Tensor:
        r"""Convert multichannel waveforms to a magnitude spectrogram.

        Args:
            input: (batch_size, channels_num, segment_samples)
            eps: float, forwarded to ``wav_to_spectrogram_phase``.

        Returns:
            mag: (batch_size, channels_num, time_steps, freq_bins)
        """
        mag, _, _ = self.wav_to_spectrogram_phase(input, eps)
        return mag


class Subband:
    def __init__(self, subbands_num: int):
        r"""Split / merge a spectrogram into frequency subbands.

        Warning!! This class is not used!!

        Splitting subbands in the frequency domain, as done here, does not
        work as well as splitting them in the time domain as in [1]. Please
        refer to [1] for the formal implementation.

        [1] Liu, Haohe, et al. "Channel-wise subband input for better voice and
        accompaniment separation on high resolution music." arXiv preprint arXiv:2008.05216 (2020).

        Args:
            subbands_num: int, number of equal-width frequency bands, e.g., 4
        """
        self.subbands_num = subbands_num

    def analysis(self, x: torch.Tensor) -> torch.Tensor:
        r"""Cut the frequency axis into equal bands and stack them as channels.

        Output channel ``c * subbands_num + s`` holds band ``s`` of input
        channel ``c``.

        Args:
            x: (batch_size, channels_num, time_steps, freq_bins)

        Returns:
            output: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)
        """
        n_sub = self.subbands_num
        batch, chans, frames, bins = x.shape
        width = bins // n_sub

        # (B, C, T, F) -> (B, C, T, S, W): split frequency into S bands of W bins.
        banded = x.reshape(batch, chans, frames, n_sub, width)

        # (B, C, T, S, W) -> (B, C, S, T, W): bring the band axis next to channels.
        banded = banded.permute(0, 1, 3, 2, 4)

        # Merge channel and band axes, with the band index varying fastest.
        return banded.reshape(batch, chans * n_sub, frames, width)

    def synthesis(self, x: torch.Tensor) -> torch.Tensor:
        r"""Inverse of ``analysis``: reassemble stacked bands into full spectra.

        Args:
            x: (batch_size, channels_num * subbands_num, time_steps, freq_bins // subbands_num)

        Returns:
            output: (batch_size, channels_num, time_steps, freq_bins)
        """
        n_sub = self.subbands_num
        batch, sub_chans, frames, width = x.shape
        chans = sub_chans // n_sub

        # (B, C * S, T, W) -> (B, C, S, T, W): separate channel and band axes.
        banded = x.reshape(batch, chans, n_sub, frames, width)

        # (B, C, S, T, W) -> (B, C, T, S, W): put bands back beside frequency.
        banded = banded.permute(0, 1, 3, 2, 4)

        # Concatenate the bands along frequency: (B, C, T, S * W).
        return banded.reshape(batch, chans, frames, width * n_sub)