import torch
import auraloss
import pytorch_lightning as pl
from typing import Tuple
from argparse import ArgumentParser


import deepafx_st.utils as utils
from deepafx_st.data.proxy import DSPProxyDataset
from deepafx_st.processors.proxy.tcn import ConditionalTCN
from deepafx_st.processors.spsa.channel import SPSAChannel
from deepafx_st.processors.dsp.peq import ParametricEQ
from deepafx_st.processors.dsp.compressor import Compressor


class ProxySystem(pl.LightningModule):
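    """Train a neural network proxy that mimics a DSP audio effect.

    A ConditionalTCN is trained to match the output of the true processor
    (parametric EQ, compressor, or the full channel via SPSA) given the same
    input audio and control parameters.

    Minimal usage sketch (inference only; an ``input_dir`` with the dataset
    layout expected by DSPProxyDataset is needed for training, not for this
    example):

        import torch
        system = ProxySystem(processor="comp")
        x = torch.randn(1, 1, 65536)  # (batch, channels, samples)
        p = torch.rand(1, system.processor.num_control_params)
        y = system(x, p)                    # neural proxy output
        y_dsp = system(x, p, use_dsp=True)  # true DSP output
    """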
    def __init__(
        self,
        causal=True,
        nblocks=4,
        dilation_growth=8,
        kernel_size=13,
        channel_width=64,
        input_dir=None,
        processor="channel",
        batch_size=32,
        lr=3e-4,
        lr_patience=20,
        patience=10,
        preload=False,
        sample_rate=24000,
        shuffle=True,
        train_length=65536,
        train_examples_per_epoch=10000,
        val_length=131072,
        val_examples_per_epoch=1000,
        num_workers=16,
        output_gain=False,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()

        # construct both the true DSP...
        if self.hparams.processor == "peq":
            self.processor = ParametricEQ(self.hparams.sample_rate)
        elif self.hparams.processor == "comp":
            self.processor = Compressor(self.hparams.sample_rate)
        elif self.hparams.processor == "channel":
            self.processor = SPSAChannel(self.hparams.sample_rate)

        # and the neural network proxy
        self.proxy = ConditionalTCN(
            self.hparams.sample_rate,
            num_control_params=self.processor.num_control_params,
            causal=self.hparams.causal,
            nblocks=self.hparams.nblocks,
            channel_width=self.hparams.channel_width,
            kernel_size=self.hparams.kernel_size,
            dilation_growth=self.hparams.dilation_growth,
        )

        self.receptive_field = self.proxy.compute_receptive_field()

        self.recon_losses = {}
        self.recon_loss_weights = {}

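        # spectral loss over six resolutions, using linear and log magnitude terms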
        self.recon_losses["mrstft"] = auraloss.freq.MultiResolutionSTFTLoss(
            fft_sizes=[32, 128, 512, 2048, 8192, 32768],
            hop_sizes=[16, 64, 256, 1024, 4096, 16384],
            win_lengths=[32, 128, 512, 2048, 8192, 32768],
            w_sc=0.0,
            w_phs=0.0,
            w_lin_mag=1.0,
            w_log_mag=1.0,
        )
        self.recon_loss_weights["mrstft"] = 1.0

        self.recon_losses["l1"] = torch.nn.L1Loss()
        self.recon_loss_weights["l1"] = 100.0

    def forward(self, x, p, use_dsp=False, sample_rate=24000, **kwargs):
        """Process audio with the neural proxy or the true DSP effect.

        Args:
            x (Tensor): Input audio of shape (batch, channels, samples).
            p (Tensor): Normalized control parameters of shape (batch, num_control_params).
            use_dsp (bool, optional): Apply the true DSP instead of the proxy. Default: False
            sample_rate (int, optional): Audio sample rate. Default: 24000
        """
        bs, chs, samp = x.size()
        if not use_dsp:
            y = self.proxy(x, p)
            # manually apply the makeup gain parameter
            if self.hparams.output_gain and self.hparams.processor != "peq":
                # map the normalized [0, 1] parameter to a gain in [-48, +48] dB
                gain_db = (p[..., -1] * 96) - 48
                gain_ln = 10 ** (gain_db / 20.0)
                y *= gain_ln.view(bs, chs, 1)
        else:
            with torch.no_grad():
                if self.hparams.output_gain and self.hparams.processor != "peq":
                    # bypass the DSP makeup gain and apply it manually below
                    gain_db = (p[..., -1] * 96) - 48
                    gain_ln = 10 ** (gain_db / 20.0)
                    p[..., -1] = 0.5  # 0.5 corresponds to 0 dB of makeup gain

                if self.hparams.processor == "channel":
                    y_temp = self.processor(x.cpu(), p.cpu())
                    y_temp = y_temp.view(bs, chs, samp).type_as(x)
                else:
                    y_temp = self.processor(
                        x.cpu().numpy(),
                        p.cpu().numpy(),
                        sample_rate,
                    )
                    y_temp = torch.tensor(y_temp).view(bs, chs, samp).type_as(x)

                y = y_temp.view(bs, 1, -1)

                if self.hparams.output_gain and self.hparams.processor != "peq":
                    y *= gain_ln.view(bs, chs, 1)

        return y

    def common_step(
        self,
        batch: Tuple,
        batch_idx: int,
        optimizer_idx: int = 0,
        train: bool = True,
    ):
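        """Shared logic for training and validation steps.

        The batch supplies input audio ``x``, the DSP-processed target ``y``,
        and the control parameters ``p`` that produced it; the proxy output
        is scored against the target with each reconstruction loss.
        """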
        loss = 0
        x, y, p = batch

        y_hat = self(x, p)

        # compute loss
        for loss_name, loss_fn in self.recon_losses.items():
            tmp_loss = loss_fn(y_hat.float(), y.float())
            loss += self.recon_loss_weights[loss_name] * tmp_loss

            self.log(
                f"train_loss/{loss_name}" if train else f"val_loss/{loss_name}",
                tmp_loss,
                on_step=True,
                on_epoch=True,
                prog_bar=False,
                logger=True,
                sync_dist=True,
            )

        if not train:
            # store audio data
            data_dict = {
                "x": x.float().cpu(),
                "y": y.float().cpu(),
                "p": p.float().cpu(),
                "y_hat": y_hat.float().cpu(),
            }
        else:
            data_dict = {}

        self.log(
            "train_loss" if train else "val_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            sync_dist=True,
        )

        return loss, data_dict

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        loss, _ = self.common_step(batch, batch_idx)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, data_dict = self.common_step(batch, batch_idx, train=False)

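        # return audio tensors from the first batch only (e.g., for logging callbacks)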
        if batch_idx == 0:
            return data_dict

    def configure_optimizers(self):
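        """Optimize only the proxy network; reduce the LR when val_loss plateaus."""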
        optimizer = torch.optim.Adam(
            self.proxy.parameters(),
            lr=self.hparams.lr,
            betas=(0.9, 0.999),
        )

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=self.hparams.lr_patience,
            verbose=True,
        )

        return [optimizer], [{"scheduler": scheduler, "monitor": "val_loss"}]

    def train_dataloader(self):
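        """Construct the training DataLoader with a fixed-seed generator for reproducible shuffling."""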

        train_dataset = DSPProxyDataset(
            self.hparams.input_dir,
            self.processor,
            self.hparams.processor,  # name
            subset="train",
            length=self.hparams.train_length,
            num_examples_per_epoch=self.hparams.train_examples_per_epoch,
            half=self.hparams.precision == 16,
            buffer_size_gb=self.hparams.buffer_size_gb,
            buffer_reload_rate=self.hparams.buffer_reload_rate,
        )

        g = torch.Generator()
        g.manual_seed(0)

        return torch.utils.data.DataLoader(
            train_dataset,
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
            worker_init_fn=utils.seed_worker,
            generator=g,
            pin_memory=True,
        )

    def val_dataloader(self):

        val_dataset = DSPProxyDataset(
            self.hparams.input_dir,
            self.processor,
            self.hparams.processor,  # name
            subset="val",
            length=self.hparams.val_length,
            num_examples_per_epoch=self.hparams.val_examples_per_epoch,
            half=self.hparams.precision == 16,
            buffer_size_gb=self.hparams.buffer_size_gb,
            buffer_reload_rate=self.hparams.buffer_reload_rate,
        )

        g = torch.Generator()
        g.manual_seed(0)

        return torch.utils.data.DataLoader(
            val_dataset,
            num_workers=self.hparams.num_workers,
            batch_size=self.hparams.batch_size,
            worker_init_fn=utils.seed_worker,
            generator=g,
            pin_memory=True,
        )

    @staticmethod
    def count_control_params(plugin_config):
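        """Count the plugin ports that are flagged for optimization."""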
        num_control_params = 0

        for plugin in plugin_config["plugins"]:
            for port in plugin["ports"]:
                if port["optim"]:
                    num_control_params += 1

        return num_control_params

    # add any model hyperparameters here
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        # --- Model  ---
        parser.add_argument("--causal", action="store_true")
        parser.add_argument("--output_gain", action="store_true")
        parser.add_argument("--dilation_growth", type=int, default=8)
        parser.add_argument("--nblocks", type=int, default=4)
        parser.add_argument("--kernel_size", type=int, default=13)
        parser.add_argument("--channel_width", type=int, default=13)
        # --- Training  ---
        parser.add_argument("--input_dir", type=str)
        parser.add_argument("--processor", type=str)
        parser.add_argument("--batch_size", type=int, default=32)
        parser.add_argument("--lr", type=float, default=3e-4)
        parser.add_argument("--lr_patience", type=int, default=20)
        parser.add_argument("--patience", type=int, default=10)
        parser.add_argument("--preload", action="store_true")
        parser.add_argument("--sample_rate", type=int, default=24000)
        parser.add_argument("--shuffle", type=bool, default=True)
        parser.add_argument("--train_length", type=int, default=65536)
        parser.add_argument("--train_examples_per_epoch", type=int, default=10000)
        parser.add_argument("--val_length", type=int, default=131072)
        parser.add_argument("--val_examples_per_epoch", type=int, default=1000)
        parser.add_argument("--num_workers", type=int, default=8)
        parser.add_argument("--buffer_reload_rate", type=int, default=1000)
        parser.add_argument("--buffer_size_gb", type=float, default=1.0)

        return parser