import torch
import torchaudio
import scipy.signal
import numpy as np
import pyloudnorm as pyln
import matplotlib.pyplot as plt
from deepafx_st.processors.dsp.compressor import compressor

from tqdm import tqdm


class BaselineEQ(torch.nn.Module):
    """Spectral-matching EQ baseline.

    ``forward`` estimates the average STFT magnitude spectra of an input and a
    reference, smooths both with a Savitzky-Golay filter, designs an FIR filter
    with ``scipy.signal.firwin2`` whose gain is the reference/input ratio, and
    applies it to the input with ``scipy.signal.lfilter``.

    Args:
        ntaps: Number of FIR taps passed to ``firwin2``.
        n_fft: STFT size used to estimate the average spectra.
        sample_rate: Audio sample rate in Hz.
        debug_plots: If True, ``forward`` writes ``inverse.png`` and
            ``outSpec.png`` diagnostic plots to the working directory.
    """

    def __init__(
        self,
        ntaps: int = 63,
        n_fft: int = 65536,
        sample_rate: float = 44100,
        debug_plots: bool = False,
    ):
        super().__init__()
        self.ntaps = ntaps
        self.n_fft = n_fft
        self.sample_rate = sample_rate
        self.debug_plots = debug_plots
        # NOTE: a fixed target spectrum can be computed from a dataset with
        # ``analyze_speech_dataset(filepaths)``; ``forward`` does not use one
        # and matches against the reference signal it is given instead.

    def forward(self, x, y):
        """Filter ``x`` so its average spectrum matches that of ``y``.

        Args:
            x: Input audio tensor of shape (batch, channels, samples).
            y: Reference audio tensor whose last dim is time; all leading dims
                are folded into rows. Its length may differ from ``x``.

        Returns:
            float32 tensor of shape (batch * channels, samples) on ``x``'s
            device. A single filter, designed from spectra averaged over all
            rows, is applied to every row.
        """
        bs, ch, _ = x.size()
        x = x.reshape(bs * ch, -1)
        y = y.reshape(-1, y.shape[-1])

        in_spec = self.get_average_spectrum(x)
        ref_spec = self.get_average_spectrum(y)

        # Savitzky-Golay smoothing fits local polynomials, so the result is not
        # constrained to be positive (and a silent input is all zeros). Floor
        # both magnitudes so the gain ratio below is finite and non-negative.
        eps = 1e-8
        sm_in_spec = np.maximum(self.smooth_spectrum(in_spec), eps)
        sm_ref_spec = np.maximum(self.smooth_spectrum(ref_spec), eps)

        # design FIR filter whose gain at each bin is reference / input
        freqs = np.linspace(0, 1.0, num=(self.n_fft // 2) + 1)
        response = sm_ref_spec / sm_in_spec
        response[-1] = 0.0  # zero gain at nyquist

        b = scipy.signal.firwin2(
            self.ntaps,
            freqs * (self.sample_rate / 2),
            response,
            fs=self.sample_rate,
        )

        # apply the filter; scipy needs a CPU numpy array without grad
        x_np = x.detach().cpu().numpy()
        x_filt = scipy.signal.lfilter(b, [1.0], x_np)
        x_filt = torch.from_numpy(x_filt.astype("float32")).to(x.device)

        if self.debug_plots:
            self._plot_debug(b, response, x_filt, y, sm_in_spec, sm_ref_spec)

        return x_filt

    def _plot_debug(self, b, response, x_filt, y, sm_in_spec, sm_ref_spec):
        """Save the designed filter response and before/after spectra as PNGs."""
        w, h = scipy.signal.freqz(b, fs=self.sample_rate, worN=response.shape[-1])

        fig, ax1 = plt.subplots()
        ax1.set_title("Digital filter frequency response")
        ax1.plot(w, 20 * np.log10(abs(h + 1e-8)))
        ax1.plot(w, 20 * np.log10(abs(response + 1e-8)))
        ax1.set_xscale("log")
        ax1.set_ylim([-12, 12])
        plt.grid(c="lightgray")
        plt.savefig("inverse.png")
        plt.close(fig)

        sm_x_filt_avg_spec = self.smooth_spectrum(
            self.get_average_spectrum(x_filt.detach().cpu())
        )
        sm_y_avg_spec = self.smooth_spectrum(
            self.get_average_spectrum(y.detach().cpu())
        )
        compare = torch.stack(
            [
                torch.tensor(sm_in_spec),
                torch.tensor(sm_x_filt_avg_spec),
                torch.tensor(sm_ref_spec),
                torch.tensor(sm_y_avg_spec),
            ]
        )
        self.plot_multi_spectrum(
            compare,
            legend=["in", "out", "target curve", "actual target"],
            filename="outSpec",
        )

    def analyze_speech_dataset(self, filepaths, peak=-3.0):
        """Average the spectra of a set of audio files.

        Each file is resampled to ``self.sample_rate`` if needed and peak
        normalized to ``peak`` dBFS (silent files are left at zero).

        Args:
            filepaths: Iterable of paths loadable by ``torchaudio.load``.
            peak: Peak level in dBFS applied before analysis.

        Returns:
            Tuple ``(avg_spec, sm_avg_spec)`` of numpy arrays: the mean
            spectrum over files and its smoothed version.
        """
        specs = []
        for filepath in tqdm(filepaths, ncols=80):
            x, sr = torchaudio.load(filepath)
            # the frequency axis assumes self.sample_rate; convert if needed
            if sr != self.sample_rate:
                x = torchaudio.functional.resample(x, sr, int(self.sample_rate))
            x_peak = x.abs().max()
            if x_peak > 0:
                x = x / x_peak * 10 ** (peak / 20.0)
            specs.append(self.get_average_spectrum(x))

        avg_spec = torch.stack(specs).mean(dim=0).numpy()
        sm_avg_spec = self.smooth_spectrum(avg_spec)

        return avg_spec, sm_avg_spec

    def smooth_spectrum(self, H, window_length=1025, polyorder=2):
        """Smooth spectrum ``H`` along its last axis with a Savitzky-Golay filter.

        ``window_length`` is clamped to the largest odd value that fits the
        spectrum; if that is not larger than ``polyorder`` the spectrum is
        returned unsmoothed (as a numpy array).
        """
        n = np.shape(H)[-1]
        if window_length > n:
            window_length = n if n % 2 == 1 else n - 1
        if window_length <= polyorder:
            return np.asarray(H)
        return scipy.signal.savgol_filter(H, window_length, polyorder)

    def get_average_spectrum(self, x):
        """Return the mean STFT magnitude spectrum of ``x``.

        Args:
            x: Audio of shape (samples,) or (rows, samples). With the default
                centered STFT, inputs presumably need more than ``n_fft // 2``
                samples (reflect padding) — TODO confirm for your torch version.

        Returns:
            1-D tensor of ``n_fft // 2 + 1`` magnitudes averaged over STFT frames
            and, for 2-D input, over rows.
        """
        # explicit rectangular window: identical to torch.stft's window=None
        window = torch.ones(self.n_fft, dtype=x.dtype, device=x.device)
        X = torch.stft(
            x, self.n_fft, window=window, return_complex=True, normalized=True
        )
        X = X.abs().mean(dim=-1)  # magnitude, averaged across frames
        if X.dim() > 1:
            X = X.mean(dim=0)  # average across rows (batch items / channels)
        return X

    @staticmethod
    def next_power_of_2(x):
        """Return the smallest power of two >= ``x`` (1 for ``x == 0``)."""
        return 1 if x == 0 else int(2 ** np.ceil(np.log2(x)))

    def _freq_axis(self):
        """Return ``n_fft // 2 + 1`` evenly spaced frequencies from 0 to fs/2 (Hz)."""
        return np.linspace(0, self.sample_rate / 2, (self.n_fft // 2) + 1)

    def plot_multi_spectrum(self, Hs, legend=None, filename=None):
        """Plot several magnitude spectra in dB; save to ``{filename}.png`` if given."""
        freqs = self._freq_axis()

        fig, ax1 = plt.subplots()
        for H in Hs:
            ax1.plot(freqs, 20 * np.log10(abs(H) + 1e-8))

        plt.legend([] if legend is None else legend)

        ax1.set_xscale("log")
        ax1.set_ylim([-80, 0])
        plt.grid(c="lightgray")

        if filename is not None:
            plt.savefig(f"{filename}.png")
            plt.close(fig)

    def plot_spectrum_stats(self, H_mean, H_std, filename=None):
        """Plot a mean spectrum with dashed bounds; save to ``{filename}.png`` if given."""
        freqs = self._freq_axis()
        mean_db = 20 * np.log10(H_mean)
        std_db = 20 * np.log10(H_std)

        fig, ax1 = plt.subplots()
        ax1.plot(freqs, mean_db)
        # NOTE(review): mean_db +/- std_db equals 20*log10(mean*std) and
        # 20*log10(mean/std), not mean +/- std in dB — confirm intent.
        ax1.plot(freqs, mean_db + std_db, linestyle="--", color="k")
        ax1.plot(freqs, mean_db - std_db, linestyle="--", color="k")

        ax1.set_xscale("log")
        ax1.set_ylim([-80, 0])
        plt.grid(c="lightgray")

        if filename is not None:
            plt.savefig(f"{filename}.png")
            plt.close(fig)

    def plot_spectrum(self, H, legend=None, filename=None):
        """Plot one magnitude spectrum in dB; save to ``{filename}.png`` if given."""
        freqs = self._freq_axis()

        fig, ax1 = plt.subplots()
        ax1.plot(freqs, 20 * np.log10(H))
        ax1.set_xscale("log")
        ax1.set_ylim([-80, 0])
        plt.grid(c="lightgray")

        plt.legend([] if legend is None else legend)

        if filename is not None:
            plt.savefig(f"{filename}.png")
            plt.close(fig)


class BaslineComp(torch.nn.Module):
    """Compressor baseline that raises the input's loudness toward a reference.

    ``forward`` repeatedly runs ``compressor`` on the original input with a
    threshold lowered in 0.5 dB steps from 0 dB (ratio 3, 1 ms attack, 50 ms
    release, 6 dB knee, no makeup gain). Each result is peak normalized to
    -12 dBFS. The search stops once the reference is no more than 0.5 louder
    (pyloudnorm integrated loudness) than the result, or once the threshold
    reaches -80 dB. (Class name spelling kept for compatibility.)

    Args:
        sample_rate: Audio sample rate in Hz.
    """

    def __init__(
        self,
        sample_rate: float = 44100,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.meter = pyln.Meter(sample_rate)

    def forward(self, x, y):
        """Compress ``x`` to approach the integrated loudness of ``y``.

        Both inputs are flattened to a single signal, so presumably they are
        mono with batch size 1 — TODO confirm with callers. Tensors must be on
        the CPU and not require grad (they are converted with ``.numpy()``).

        Returns:
            Tensor of shape (1, 1, samples) with ``x``'s dtype. It is the result
            from the step *before* the last compressor setting tried, or ``x``
            itself when at most one setting was tried.
        """
        # the compressor always starts from the uncompressed input
        x_np = x.reshape(-1).numpy()
        x_lufs = self.meter.integrated_loudness(x_np)
        y_lufs = self.meter.integrated_loudness(y.reshape(-1).numpy())

        delta_lufs = y_lufs - x_lufs

        threshold = 0.0
        x_comp = x
        x_comp_new = x
        while delta_lufs > 0.5 and threshold > -80.0:
            x_comp = x_comp_new  # use the last setting
            x_comp_new = torch.tensor(
                compressor(
                    x_np,
                    self.sample_rate,
                    threshold=threshold,
                    ratio=3,
                    attack_time=0.001,
                    release_time=0.05,
                    knee_dB=6.0,
                    makeup_gain_dB=0.0,
                )
            )
            # peak normalize to -12 dBFS; skip for an all-zero result
            peak = x_comp_new.abs().max()
            if peak > 0:
                x_comp_new = x_comp_new / peak * 10 ** (-12.0 / 20)
            x_lufs = self.meter.integrated_loudness(x_comp_new.reshape(-1).numpy())
            delta_lufs = y_lufs - x_lufs
            threshold -= 0.5

        # the compressor's output dtype may differ from the input's; keep x's
        return x_comp.reshape(1, 1, -1).to(x.dtype)


class BaselineEQAndComp(torch.nn.Module):
    """Baseline chaining ``BaselineEQ`` spectral matching and ``BaslineComp``.

    Args:
        ntaps: FIR taps for ``BaselineEQ``.
        n_fft: STFT size for ``BaselineEQ``.
        sample_rate: Sample rate in Hz for both stages.
        block_size: Unused here; presumably accepted so the constructor matches
            other processors' signatures — verify against callers.
        plugin_config: Unused here (see ``block_size``).
    """

    def __init__(
        self,
        ntaps=63,
        n_fft=65536,
        sample_rate=44100,
        block_size=1024,
        plugin_config=None,
    ):
        super().__init__()
        self.eq = BaselineEQ(ntaps, n_fft, sample_rate)
        self.comp = BaslineComp(sample_rate)

    @staticmethod
    def _peak_normalize(t, peak_db=-12.0):
        """Return a copy of ``t`` scaled so its absolute peak is ``peak_db`` dBFS.

        ``t`` is not modified. An all-zero ``t`` is returned unchanged instead
        of dividing by zero.
        """
        peak = t.abs().max()
        if peak == 0:
            return t
        return t / peak * 10 ** (peak_db / 20)

    def forward(self, x, y):
        """Equalize then compress ``x`` toward reference ``y``.

        Inputs are peak normalized to -12 dBFS before each stage and the output
        is normalized the same way. The caller's ``x`` and ``y`` are not
        modified (all scaling is out of place).

        Returns:
            Processed tensor in the shape returned by ``BaslineComp``.
        """
        with torch.inference_mode():
            x = self._peak_normalize(x)
            # y keeps the same peak for both stages, so normalize it once
            y = self._peak_normalize(y)

            x = self.eq(x, y)
            x = self._peak_normalize(x)

            x = self.comp(x, y)
            x = self._peak_normalize(x)

        return x