import numpy as np
import torch
from torch import nn
from torchaudio.transforms import MelSpectrogram as TorchMelSpectrogram

from .hparams import HParams


class MelSpectrogram(nn.Module):
    """
    Torch implementation of Resemble's mel extraction.

    Pipeline: optional pre-emphasis -> magnitude mel spectrogram ->
    decibels (floored at ``hp.stft_magnitude_min``) -> linear normalization.

    Note that the values are NOT identical to librosa's implementation
    due to floating point precisions.
    """

    def __init__(self, hp: "HParams"):
        """
        Args:
            hp: hyper-parameters. The fields read here are ``wav_rate``,
                ``n_fft``, ``win_size``, ``hop_size``, ``num_mels``,
                ``stft_magnitude_min`` and ``preemphasis``.
        """
        super().__init__()
        self.hp = hp
        self.melspec = TorchMelSpectrogram(
            hp.wav_rate,
            n_fft=hp.n_fft,
            win_length=hp.win_size,
            hop_length=hp.hop_size,
            f_min=0,
            f_max=hp.wav_rate // 2,
            n_mels=hp.num_mels,
            power=1,  # magnitude (not power) spectrogram
            normalized=False,
            # NOTE: Following librosa's default.
            pad_mode="constant",
            norm="slaney",
            mel_scale="slaney",
        )
        # Kept as a registered buffer under its original name. The methods of
        # this class read the floor from ``hp`` directly, not from the buffer.
        self.register_buffer(
            "stft_magnitude_min",
            torch.tensor([hp.stft_magnitude_min], dtype=torch.float32),
        )
        # The magnitude floor expressed in dB; the lower anchor of _normalize.
        self.min_level_db = 20 * np.log10(hp.stft_magnitude_min)
        self.preemphasis = hp.preemphasis
        self.hop_size = hp.hop_size

    def forward(self, wav, pad=True):
        """
        Compute the normalized dB mel spectrogram of ``wav``.

        Args:
            wav: [B, T]
            pad: when True, assert that the number of output frames equals
                ``1 + T // hop_size``. This flag only gates that sanity
                check; it does not change how the signal is padded.

        Returns:
            Normalized mel spectrogram on the device ``wav`` arrived on.
            For [B, T] input the shape is [B, num_mels, frames].

        Note:
            If ``wav`` lives on an MPS device, it is copied to the CPU and
            this module is moved to the CPU as well. The module stays on the
            CPU after the call; only the result is moved back.
        """
        device = wav.device
        if wav.is_mps:
            # Presumably a workaround for ops unsupported on MPS: compute on
            # the CPU instead. NOTE: this relocates the module itself.
            wav = wav.cpu()
            self.to(wav.device)
        if self.preemphasis > 0:
            # y[t] = x[t] - preemphasis * x[t - 1], with x[-1] taken as 0
            # (the left zero-pad), so the length T is preserved.
            wav = torch.nn.functional.pad(wav, [1, 0], value=0)
            wav = wav[..., 1:] - self.preemphasis * wav[..., :-1]
        mel = self.melspec(wav)
        mel = self._amp_to_db(mel)
        mel_normed = self._normalize(mel)
        if pad:
            # Sanity check on the frame count.
            expected_frames = 1 + wav.shape[-1] // self.hop_size
            assert mel_normed.shape[-1] == expected_frames, (
                f"unexpected number of mel frames: got "
                f"{mel_normed.shape[-1]}, expected {expected_frames}"
            )
        return mel_normed.to(device)

    def _normalize(self, s, headroom_db=15):
        """
        Linearly rescale dB values so that ``min_level_db`` maps to 0 and
        ``headroom_db`` maps to 1. Values are not clipped.
        """
        return (s - self.min_level_db) / (-self.min_level_db + headroom_db)

    def _amp_to_db(self, x):
        """
        Convert amplitudes to decibels (``20 * log10``) after flooring them
        at ``hp.stft_magnitude_min`` to avoid ``log10(0)``.
        """
        return x.clamp_min(self.hp.stft_magnitude_min).log10() * 20