import os
import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader, DistributedSampler
from tasks.vocoder.dataset_utils import VocoderDataset, EndlessDistributedSampler
from utils.audio.io import save_wav
from utils.commons.base_task import BaseTask
from utils.commons.dataset_utils import data_loader
from utils.commons.hparams import hparams
from utils.commons.tensor_utils import tensors_to_scalars


class VocoderBaseTask(BaseTask):
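    """Base task for training and evaluating GAN vocoders.

    Subclasses are expected to provide the generator as `self.model_gen`
    and the discriminator as `self.model_disc`.
    """
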
    def __init__(self):
        super().__init__()
        self.max_sentences = hparams['max_sentences']
        self.max_valid_sentences = hparams['max_valid_sentences']
        if self.max_valid_sentences == -1:
            # -1 means "fall back to the training batch size for validation".
            hparams['max_valid_sentences'] = self.max_valid_sentences = self.max_sentences
        self.dataset_cls = VocoderDataset
        self.dataset_cls = VocoderDataset

    @data_loader
    def train_dataloader(self):
        train_dataset = self.dataset_cls('train', shuffle=True)
        return self.build_dataloader(train_dataset, True, self.max_sentences, hparams['endless_ds'])

    @data_loader
    def val_dataloader(self):
        valid_dataset = self.dataset_cls('test', shuffle=False)
        return self.build_dataloader(valid_dataset, False, self.max_valid_sentences)

    @data_loader
    def test_dataloader(self):
        test_dataset = self.dataset_cls('test', shuffle=False)
        return self.build_dataloader(test_dataset, False, self.max_valid_sentences)

    def build_dataloader(self, dataset, shuffle, max_sentences, endless=False):
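        """Build a DataLoader whose sampling is sharded across distributed ranks.

        `endless=True` (driven by hparams['endless_ds'] for training) selects
        EndlessDistributedSampler instead of the standard DistributedSampler.
        """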
        world_size = 1
        rank = 0
        if dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
        sampler_cls = EndlessDistributedSampler if endless else DistributedSampler
        sampler = sampler_cls(
            dataset=dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
        )
        return DataLoader(
            dataset=dataset,
            shuffle=False,  # shuffling is handled by the sampler
            collate_fn=dataset.collater,
            batch_size=max_sentences,
            num_workers=dataset.num_workers,
            sampler=sampler,
            pin_memory=True,
        )

    def build_optimizer(self, model):
        # `model` is unused here: separate optimizers are built for the
        # generator and the discriminator.
        optimizer_gen = torch.optim.AdamW(self.model_gen.parameters(), lr=hparams['lr'],
                                          betas=(hparams['adam_b1'], hparams['adam_b2']))
        optimizer_disc = torch.optim.AdamW(self.model_disc.parameters(), lr=hparams['lr'],
                                           betas=(hparams['adam_b1'], hparams['adam_b2']))
        return [optimizer_gen, optimizer_disc]

    def build_scheduler(self, optimizer):
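        # `optimizer` is the [gen, disc] pair returned by build_optimizer.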
        return {
            "gen": torch.optim.lr_scheduler.StepLR(
                optimizer=optimizer[0],
                **hparams["generator_scheduler_params"]),
            "disc": torch.optim.lr_scheduler.StepLR(
                optimizer=optimizer[1],
                **hparams["discriminator_scheduler_params"]),
        }

    def validation_step(self, sample, batch_idx):
        outputs = {}
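        # Reuse the training step with opt_idx 0 (the generator pass) to compute validation losses.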
        total_loss, loss_output = self._training_step(sample, batch_idx, 0)
        outputs['losses'] = tensors_to_scalars(loss_output)
        outputs['total_loss'] = tensors_to_scalars(total_loss)

        if self.global_step % hparams['valid_infer_interval'] == 0 and \
                batch_idx < 10:
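            # Every valid_infer_interval steps, synthesize the first 10 validation batches and log them as audio.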
            mels = sample['mels']
            y = sample['wavs']
            f0 = sample['f0']
            y_ = self.model_gen(mels, f0)
            for idx, (wav_pred, wav_gt, item_name) in enumerate(zip(y_, y, sample["item_name"])):
                wav_pred = wav_pred / wav_pred.abs().max()
                if self.global_step == 0:
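                    # Log the peak-normalized ground truth only once, at the very first step.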
                    wav_gt = wav_gt / wav_gt.abs().max()
                    self.logger.add_audio(f'wav_{batch_idx}_{idx}_gt', wav_gt, self.global_step,
                                          hparams['audio_sample_rate'])
                self.logger.add_audio(f'wav_{batch_idx}_{idx}_pred', wav_pred, self.global_step,
                                      hparams['audio_sample_rate'])
        return outputs

    def test_start(self):
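        """Create (and remember) the directory that generated waveforms are written into."""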
        self.gen_dir = os.path.join(hparams['work_dir'],
                                    f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')
        os.makedirs(self.gen_dir, exist_ok=True)

    def test_step(self, sample, batch_idx):
        mels = sample['mels']
        y = sample['wavs']
        f0 = sample['f0']
        loss_output = {}
        y_ = self.model_gen(mels, f0)
        gen_dir = self.gen_dir  # created in test_start()
        for idx, (wav_pred, wav_gt, item_name) in enumerate(zip(y_, y, sample["item_name"])):
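            # Clamp to [-1, 1] so the waveforms are valid for writing to disk.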
            wav_gt = wav_gt.clamp(-1, 1)
            wav_pred = wav_pred.clamp(-1, 1)
            save_wav(
                wav_gt.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_gt.wav',
                hparams['audio_sample_rate'])
            save_wav(
                wav_pred.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_pred.wav',
                hparams['audio_sample_rate'])
        return loss_output

    def test_end(self, outputs):
        return {}

    def on_before_optimization(self, opt_idx):
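        # Clip gradients just before each optimizer step: opt_idx 0 is the generator, 1 the discriminator.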
        if opt_idx == 0:
            nn.utils.clip_grad_norm_(self.model_gen.parameters(), hparams['generator_grad_norm'])
        else:
            nn.utils.clip_grad_norm_(self.model_disc.parameters(), hparams["discriminator_grad_norm"])

    def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx):
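        # Step the matching LR scheduler; dividing by accumulate_grad_batches maps global steps to optimizer steps.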
        if optimizer_idx == 0:
            self.scheduler['gen'].step(self.global_step // hparams['accumulate_grad_batches'])
        else:
            self.scheduler['disc'].step(self.global_step // hparams['accumulate_grad_batches'])