from distutils.version import LooseVersion
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import logging
import torch
from torch.nn import functional as F
from torch_complex import functional as FC
from torch_complex.tensor import ComplexTensor

from espnet.nets.pytorch_backend.frontends.beamformer import apply_beamforming_vector
from espnet.nets.pytorch_backend.frontends.beamformer import (
    get_power_spectral_density_matrix,  # noqa: H301
)
from espnet2.enh.layers.beamformer import get_covariances
from espnet2.enh.layers.beamformer import get_mvdr_vector
from espnet2.enh.layers.beamformer import get_mvdr_vector_with_rtf
from espnet2.enh.layers.beamformer import get_WPD_filter_v2
from espnet2.enh.layers.beamformer import get_WPD_filter_with_rtf
from espnet2.enh.layers.beamformer import perform_WPD_filtering
from espnet2.enh.layers.mask_estimator import MaskEstimator

is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2.0")
is_torch_1_3_plus = LooseVersion(torch.__version__) >= LooseVersion("1.3.0")


BEAMFORMER_TYPES = (
    # Minimum Variance Distortionless Response beamformer
    "mvdr",  # RTF-based formula
    "mvdr_souden",  # Souden's solution
    # Minimum Power Distortionless Response beamformer
    "mpdr",  # RTF-based formula
    "mpdr_souden",  # Souden's solution
    # Weighted Minimum Power Distortionless Response (wMPDR) beamformer
    "wmpdr",  # RTF-based formula
    "wmpdr_souden",  # Souden's solution
    # Weighted Power minimization Distortionless response (WPD) beamformer
    "wpd",  # RTF-based formula
    "wpd_souden",  # Souden's solution
)
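# Types with the "_souden" suffix solve for the beamforming filter directly from
# the speech and noise covariance matrices (Souden's solution); the other types
# first estimate a relative transfer function (RTF) via the power method (see
# rtf_iterations below).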


class DNN_Beamformer(torch.nn.Module):
    """DNN mask based Beamformer.

    Citation:
        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
        http://proceedings.mlr.press/v70/ochiai17a/ochiai17a.pdf
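
    Example:
        A minimal usage sketch with random inputs, assuming the default
        single-speaker "mvdr_souden" configuration; shapes follow the
        ``forward()`` docstring::

            >>> import torch
            >>> from torch_complex.tensor import ComplexTensor
            >>> beamformer = DNN_Beamformer(bidim=257)
            >>> real = torch.randn(2, 100, 4, 257)  # (B, T, C, F)
            >>> spec = ComplexTensor(real, torch.randn_like(real))
            >>> ilens = torch.LongTensor([100, 80])
            >>> enhanced, ilens, masks = beamformer(spec, ilens)
            >>> tuple(enhanced.size())  # (B, T, F)
            (2, 100, 257)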

    """

    def __init__(
        self,
        bidim,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        num_spk: int = 1,
        use_noise_mask: bool = True,
        nonlinear: str = "sigmoid",
        dropout_rate: float = 0.0,
        badim: int = 320,
        ref_channel: int = -1,
        beamformer_type: str = "mvdr_souden",
        rtf_iterations: int = 2,
        eps: float = 1e-6,
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        mask_flooring: bool = False,
        flooring_thres: float = 1e-6,
        use_torch_solver: bool = True,
        # only for WPD beamformer
        btaps: int = 5,
        bdelay: int = 3,
    ):
        super().__init__()
        bnmask = num_spk + 1 if use_noise_mask else num_spk
        self.mask = MaskEstimator(
            btype,
            bidim,
            blayers,
            bunits,
            bprojs,
            dropout_rate,
            nmask=bnmask,
            nonlinear=nonlinear,
        )
        self.ref = AttentionReference(bidim, badim) if ref_channel < 0 else None
        self.ref_channel = ref_channel

        self.use_noise_mask = use_noise_mask
        assert num_spk >= 1, num_spk
        self.num_spk = num_spk
        self.nmask = bnmask

        if beamformer_type not in BEAMFORMER_TYPES:
            raise ValueError("Not supporting beamformer_type=%s" % beamformer_type)
        if (
            beamformer_type == "mvdr_souden" or not beamformer_type.endswith("_souden")
        ) and not use_noise_mask:
            if num_spk == 1:
                logging.warning(
                    "Initializing %s beamformer without noise mask "
                    "estimator (single-speaker case)" % beamformer_type.upper()
                )
                logging.warning(
                    "(1 - speech_mask) will be used for estimating noise "
                    "PSD in %s beamformer!" % beamformer_type.upper()
                )
            else:
                logging.warning(
                    "Initializing %s beamformer without noise mask "
                    "estimator (multi-speaker case)" % beamformer_type.upper()
                )
                logging.warning(
                    "Interference speech masks will be used for estimating "
                    "noise PSD in %s beamformer!" % beamformer_type.upper()
                )

        self.beamformer_type = beamformer_type
        if not beamformer_type.endswith("_souden"):
            assert rtf_iterations >= 2, rtf_iterations
        # number of iterations in power method for estimating the RTF
        self.rtf_iterations = rtf_iterations

        assert btaps >= 0 and bdelay >= 0, (btaps, bdelay)
        self.btaps = btaps
        self.bdelay = bdelay if self.btaps > 0 else 1
        self.eps = eps
        self.diagonal_loading = diagonal_loading
        self.diag_eps = diag_eps
        self.mask_flooring = mask_flooring
        self.flooring_thres = flooring_thres
        self.use_torch_solver = use_torch_solver

    def forward(
        self,
        data: ComplexTensor,
        ilens: torch.LongTensor,
        powers: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[
        Union[ComplexTensor, List[ComplexTensor]],
        torch.LongTensor,
        List[torch.Tensor],
    ]:
        """DNN_Beamformer forward function.

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
            powers (List[torch.Tensor] or None): speech power spectra used for
                wMPDR or WPD (B, F, T)
        Returns:
            enhanced (ComplexTensor or List[ComplexTensor]): (B, T, F)
            ilens (torch.Tensor): (B,)
            masks (List[torch.Tensor]): (B, T, C, F)
        """

        def apply_beamforming(data, ilens, psd_n, psd_speech, psd_distortion=None):
            """Beamforming with the provided statistics.

            Args:
                data (ComplexTensor): (B, F, C, T)
                ilens (torch.Tensor): (B,)
                psd_n (ComplexTensor):
                    Noise covariance matrix for MVDR (B, F, C, C)
                    Observation covariance matrix for MPDR/wMPDR (B, F, C, C)
                    Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C)
                psd_speech (ComplexTensor): Speech covariance matrix (B, F, C, C)
                psd_distortion (ComplexTensor): Covariance matrix of the
                    distortion (noise + interference), used for RTF estimation
                    (B, F, C, C)
            Returns:
                enhanced (ComplexTensor): (B, F, T)
                ws (ComplexTensor): (B, F, C) or (B, F, (btaps+1)*C)
            """
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens)
                u = u.double()
            else:
                if self.beamformer_type.endswith("_souden"):
                    # (optional) Create onehot vector for fixed reference microphone
                    u = torch.zeros(
                        *(data.size()[:-3] + (data.size(-2),)),
                        device=data.device,
                        dtype=torch.double
                    )
                    u[..., self.ref_channel].fill_(1)
                else:
                    # for simplifying computation in RTF-based beamforming
                    u = self.ref_channel

            if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"):
                ws = get_mvdr_vector_with_rtf(
                    psd_n.double(),
                    psd_speech.double(),
                    psd_distortion.double(),
                    iterations=self.rtf_iterations,
                    reference_vector=u,
                    normalize_ref_channel=self.ref_channel,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = apply_beamforming_vector(ws, data.double())
            elif self.beamformer_type in ("mpdr_souden", "mvdr_souden", "wmpdr_souden"):
                ws = get_mvdr_vector(
                    psd_speech.double(),
                    psd_n.double(),
                    u,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = apply_beamforming_vector(ws, data.double())
            elif self.beamformer_type == "wpd":
                ws = get_WPD_filter_with_rtf(
                    psd_n.double(),
                    psd_speech.double(),
                    psd_distortion.double(),
                    iterations=self.rtf_iterations,
                    reference_vector=u,
                    normalize_ref_channel=self.ref_channel,
                    use_torch_solver=self.use_torch_solver,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = perform_WPD_filtering(
                    ws, data.double(), self.bdelay, self.btaps
                )
            elif self.beamformer_type == "wpd_souden":
                ws = get_WPD_filter_v2(
                    psd_speech.double(),
                    psd_n.double(),
                    u,
                    diagonal_loading=self.diagonal_loading,
                    diag_eps=self.diag_eps,
                )
                enhanced = perform_WPD_filtering(
                    ws, data.double(), self.bdelay, self.btaps
                )
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )

            return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype)

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)
        data_d = data.double()

        # mask: [(B, F, C, T)]
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks), len(masks)
        # floor masks to increase numerical stability
        if self.mask_flooring:
            masks = [torch.clamp(m, min=self.flooring_thres) for m in masks]

        if self.num_spk == 1:  # single-speaker case
            if self.use_noise_mask:
                # (mask_speech, mask_noise)
                mask_speech, mask_noise = masks
            else:
                # (mask_speech,)
                mask_speech = masks[0]
                mask_noise = 1 - mask_speech

            if self.beamformer_type.startswith(("wmpdr", "wpd")):
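                # wMPDR/WPD weight the observation covariance by the inverse of
                # the speech power, estimated from the masked spectrogram unless
                # precomputed powers are provided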
                if powers is None:
                    power_input = data_d.real ** 2 + data_d.imag ** 2
                    # Averaging along the channel axis: (..., C, T) -> (..., T)
                    powers = (power_input * mask_speech.double()).mean(dim=-2)
                else:
                    assert len(powers) == 1, len(powers)
                    powers = powers[0]
                inverse_power = 1 / torch.clamp(powers, min=self.eps)

            psd_speech = get_power_spectral_density_matrix(data_d, mask_speech.double())
            if mask_noise is not None and (
                self.beamformer_type == "mvdr_souden"
                or not self.beamformer_type.endswith("_souden")
            ):
                # MVDR or other RTF-based formulas
                psd_noise = get_power_spectral_density_matrix(
                    data_d, mask_noise.double()
                )
            if self.beamformer_type == "mvdr":
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_noise, psd_speech, psd_distortion=psd_noise
                )
            elif self.beamformer_type == "mvdr_souden":
                enhanced, ws = apply_beamforming(data, ilens, psd_noise, psd_speech)
            elif self.beamformer_type == "mpdr":
                psd_observed = FC.einsum("...ct,...et->...ce", [data_d, data_d.conj()])
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_observed, psd_speech, psd_distortion=psd_noise
                )
            elif self.beamformer_type == "mpdr_souden":
                psd_observed = FC.einsum("...ct,...et->...ce", [data_d, data_d.conj()])
                enhanced, ws = apply_beamforming(data, ilens, psd_observed, psd_speech)
            elif self.beamformer_type == "wmpdr":
                psd_observed = FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inverse_power[..., None, :], data_d.conj()],
                )
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_observed, psd_speech, psd_distortion=psd_noise
                )
            elif self.beamformer_type == "wmpdr_souden":
                psd_observed = FC.einsum(
                    "...ct,...et->...ce",
                    [data_d * inverse_power[..., None, :], data_d.conj()],
                )
                enhanced, ws = apply_beamforming(data, ilens, psd_observed, psd_speech)
            elif self.beamformer_type == "wpd":
                psd_observed_bar = get_covariances(
                    data_d, inverse_power, self.bdelay, self.btaps, get_vector=False
                )
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_observed_bar, psd_speech, psd_distortion=psd_noise
                )
            elif self.beamformer_type == "wpd_souden":
                psd_observed_bar = get_covariances(
                    data_d, inverse_power, self.bdelay, self.btaps, get_vector=False
                )
                enhanced, ws = apply_beamforming(
                    data, ilens, psd_observed_bar, psd_speech
                )
            else:
                raise ValueError(
                    "Not supporting beamformer_type={}".format(self.beamformer_type)
                )

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
        else:  # multi-speaker case
            if self.use_noise_mask:
                # (mask_speech1, ..., mask_noise)
                mask_speech = list(masks[:-1])
                mask_noise = masks[-1]
            else:
                # (mask_speech1, ..., mask_speechX)
                mask_speech = list(masks)
                mask_noise = None

            if self.beamformer_type.startswith(("wmpdr", "wpd")):
                if powers is None:
                    power_input = data_d.real ** 2 + data_d.imag ** 2
                    # Averaging along the channel axis: (..., C, T) -> (..., T)
                    powers = [
                        (power_input * m.double()).mean(dim=-2) for m in mask_speech
                    ]
                else:
                    assert len(powers) == self.num_spk, len(powers)
                inverse_power = [1 / torch.clamp(p, min=self.eps) for p in powers]

            psd_speeches = [
                get_power_spectral_density_matrix(data_d, mask.double())
                for mask in mask_speech
            ]
            if mask_noise is not None and (
                self.beamformer_type == "mvdr_souden"
                or not self.beamformer_type.endswith("_souden")
            ):
                # MVDR or other RTF-based formulas
                psd_noise = get_power_spectral_density_matrix(
                    data_d, mask_noise.double()
                )
            if self.beamformer_type in ("mpdr", "mpdr_souden"):
                psd_observed = FC.einsum("...ct,...et->...ce", [data_d, data_d.conj()])
            elif self.beamformer_type in ("wmpdr", "wmpdr_souden"):
                psd_observed = [
                    FC.einsum(
                        "...ct,...et->...ce",
                        [data_d * inv_p[..., None, :], data_d.conj()],
                    )
                    for inv_p in inverse_power
                ]
            elif self.beamformer_type in ("wpd", "wpd_souden"):
                psd_observed_bar = [
                    get_covariances(
                        data_d, inv_p, self.bdelay, self.btaps, get_vector=False
                    )
                    for inv_p in inverse_power
                ]

            enhanced, ws = [], []
            for i in range(self.num_spk):
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noise
                if (
                    self.beamformer_type == "mvdr_souden"
                    or not self.beamformer_type.endswith("_souden")
                ):
                    psd_noise_i = (
                        psd_noise + sum(psd_speeches)
                        if mask_noise is not None
                        else sum(psd_speeches)
                    )
                if self.beamformer_type == "mvdr":
                    enh, w = apply_beamforming(
                        data, ilens, psd_noise_i, psd_speech, psd_distortion=psd_noise_i
                    )
                elif self.beamformer_type == "mvdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_noise_i, psd_speech)
                elif self.beamformer_type == "mpdr":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed,
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "mpdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_observed, psd_speech)
                elif self.beamformer_type == "wmpdr":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed[i],
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "wmpdr_souden":
                    enh, w = apply_beamforming(data, ilens, psd_observed[i], psd_speech)
                elif self.beamformer_type == "wpd":
                    enh, w = apply_beamforming(
                        data,
                        ilens,
                        psd_observed_bar[i],
                        psd_speech,
                        psd_distortion=psd_noise_i,
                    )
                elif self.beamformer_type == "wpd_souden":
                    enh, w = apply_beamforming(
                        data, ilens, psd_observed_bar[i], psd_speech
                    )
                else:
                    raise ValueError(
                        "Not supporting beamformer_type={}".format(self.beamformer_type)
                    )
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                enhanced.append(enh)
                ws.append(w)

        # (..., F, C, T) -> (..., T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return enhanced, ilens, masks

    def predict_mask(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[List[torch.Tensor], torch.LongTensor]:
        """Predict masks for beamforming.

        Args:
            data (ComplexTensor): (B, T, C, F), double precision
            ilens (torch.Tensor): (B,)
        Returns:
            masks (List[torch.Tensor]): (B, T, C, F)
            ilens (torch.Tensor): (B,)
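
        Example:
            A minimal sketch, reusing ``beamformer``, ``spec`` and ``ilens``
            from the class-level example above::

                >>> masks, ilens = beamformer.predict_mask(spec, ilens)
                >>> len(masks)  # num_spk + 1 (a noise mask by default)
                2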
        """
        masks, _ = self.mask(data.permute(0, 3, 2, 1).float(), ilens)
        # (B, F, C, T) -> (B, T, C, F)
        masks = [m.transpose(-1, -3) for m in masks]
        return masks, ilens


class AttentionReference(torch.nn.Module):
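    """Attention-based reference selection.

    Estimates a soft reference vector ``u`` (B, C) over channels from the
    speech PSD matrix; used by ``DNN_Beamformer`` in place of a fixed
    reference microphone when ``ref_channel < 0``.
    """
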
    def __init__(self, bidim, att_dim):
        super().__init__()
        self.mlp_psd = torch.nn.Linear(bidim, att_dim)
        self.gvec = torch.nn.Linear(att_dim, 1)

    def forward(
        self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Attention-based reference forward function.

        Args:
            psd_in (ComplexTensor): (B, F, C, C)
            ilens (torch.Tensor): (B,)
            scaling (float): scaling factor applied before the softmax
                (sharpens the channel weights)
        Returns:
            u (torch.Tensor): (B, C)
            ilens (torch.Tensor): (B,)
        """
        B, _, C = psd_in.size()[:3]
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()
        # psd_in: (B, F, C, C)
        # Zero the diagonal (self-channel) entries of the PSD matrix;
        # masked_fill expects a bool mask on newer torch versions, uint8 on older ones
        datatype = torch.bool if is_torch_1_3_plus else torch.uint8
        datatype2 = torch.bool if is_torch_1_2_plus else torch.uint8
        psd = psd_in.masked_fill(
            torch.eye(C, dtype=datatype, device=psd_in.device).type(datatype2), 0
        )
        # Average the off-diagonal entries: (B, F, C, C) -> (B, F, C) -> (B, C, F)
        psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2)

        # Calculate amplitude
        psd_feat = (psd.real ** 2 + psd.imag ** 2) ** 0.5

        # (B, C, F) -> (B, C, F2)
        mlp_psd = self.mlp_psd(psd_feat)
        # (B, C, F2) -> (B, C, 1) -> (B, C)
        e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1)
        u = F.softmax(scaling * e, dim=-1)
        return u, ilens