import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import importlib
import PIL.Image as Image

import models
import datasets
from evaluator.ssim import SSIM, MSSSIM
import lpips
from models.loss import GANHingeLoss
from utils import set_logger, magic_image_handler

# Number of leading test batches (those with batch_idx < NUM_TEST_SAVE_IMAGE) whose
# logged images FontLightningModule.test_step writes out as PNG files (relative
# paths, i.e. into the current working directory).
NUM_TEST_SAVE_IMAGE = 10


class FontLightningModule(pl.LightningModule):
    """Font-generation GAN trained with PyTorch Lightning using manual optimization.

    Networks are built from ``args.models``: keys starting with ``g`` become a
    ``models.Generator`` and keys starting with ``d`` become a
    ``models.PatchGANDiscriminator``. The optional discriminators ``d_content`` and
    ``d_style`` see the content / style images concatenated (along dim 1, presumably
    channels -- TODO confirm) with either the ground-truth or the generated image.

    The generator loss is ``lambda_L1 * L1(generated, gt)`` plus the ``GANHingeLoss``
    terms of whichever discriminators exist; each discriminator is trained on
    ground-truth vs. detached generated images.
    """

    def __init__(self, args):
        """
        :param args: experiment config. Read in this class: ``args.models`` (per-network
            hyper-parameters, including ``optim``), ``args.datasets`` and ``args.logging``
            (``lambda_L1``, ``freq['train']``). It is accessed both by attribute and by key,
            so presumably an attribute-dict / OmegaConf-like object -- TODO confirm.
        """
        super().__init__()
        self.args = args

        # Placeholders so these attributes already exist while the networks are built.
        self.losses = {}
        self.metrics = {}
        self.networks = nn.ModuleDict(self.build_models())
        # Lower-cased network names, e.g. 'g', 'd_content', 'd_style'.
        self.module_keys = list(self.networks.keys())

        # build_losses() reads module_keys, so it must run after the networks exist.
        self.losses = self.build_losses()
        self.metrics = self.build_metrics()

        # Network name -> index into self.optimizers(); filled in by configure_optimizers().
        self.opt_tag = {key: None for key in self.networks.keys()}
        # NOTE(review): sched_tag / sched_use are not read anywhere in this class.
        self.sched_tag = {key: None for key in self.networks.keys()}
        self.sched_use = False
        # Manual optimization is switched on by the `automatic_optimization` property below.

        # Whether each discriminator is updated in training_step.
        self.train_d_content = True
        self.train_d_style = True

    def build_models(self):
        """Instantiate one network per entry of ``args.models``.

        :return: dict mapping the lower-cased model key to its ``nn.Module``.
        :raises ValueError: if a key starts with neither ``g`` nor ``d`` (or is empty).
        """
        networks = {}
        for key, hp_model in self.args.models.items():
            key_ = key.lower()
            # startswith() instead of key_[0] so an empty key reaches the ValueError below.
            if key_.startswith('g'):
                model_ = models.Generator(hp_model)
            elif key_.startswith('d'):
                model_ = models.PatchGANDiscriminator(hp_model)  # TODO: add option for selecting discriminator
            else:
                raise ValueError(f"No key such as {key}")

            networks[key_] = model_
        return networks

    def build_losses(self):
        """Create the L1 loss plus one ``GANHingeLoss`` per discriminator present.

        :return: dict with key ``'L1'`` and, when present, ``'GANLoss_content'`` /
            ``'GANLoss_style'``.
        """
        losses_dict = {}
        losses_dict['L1'] = torch.nn.L1Loss()

        if 'd_content' in self.module_keys:
            losses_dict['GANLoss_content'] = GANHingeLoss()
        if 'd_style' in self.module_keys:
            losses_dict['GANLoss_style'] = GANHingeLoss()

        return losses_dict

    def build_metrics(self):
        """Create the evaluation metrics (SSIM, MS-SSIM, LPIPS-VGG) as a ``ModuleDict``.

        A ``ModuleDict`` is used so the metric modules are registered as sub-modules.
        """
        metrics_dict = nn.ModuleDict()
        metrics_dict['ssim'] = SSIM(val_range=1)  # img value is in [0, 1]
        metrics_dict['msssim'] = MSSSIM(weights=[0.45, 0.3, 0.25], val_range=1)  # since imsize=64, len(weight)<=3
        metrics_dict['lpips'] = lpips.LPIPS(net='vgg')
        return metrics_dict

    def configure_optimizers(self):
        """Build one optimizer per network whose config has a non-``None`` ``optim``.

        ``optim['class']`` is a dotted path (``"package.module.ClassName"``) that is
        imported dynamically and constructed with ``lr`` and ``betas`` over the trainable
        parameters of that network. As a side effect, ``self.opt_tag[name]`` is set to the
        optimizer's index in the returned list.

        :return: list of optimizers, ordered like ``self.module_keys``.
        """
        optims = {}
        for key, args_model in self.args.models.items():
            key = key.lower()
            if args_model['optim'] is not None:
                args_optim = args_model['optim']
                module, cls = args_optim['class'].rsplit(".", 1)
                optim_cls = getattr(importlib.import_module(module, package=None), cls)
                optims[key] = optim_cls(
                    [p for p in self.networks[key].parameters() if p.requires_grad],
                    lr=args_optim.lr, betas=args_optim.betas)

        optim_list = []
        for key in self.module_keys:
            if key in optims:
                self.opt_tag[key] = len(optim_list)
                optim_list.append(optims[key])
        return optim_list

    def forward(self, content_images, style_images):
        """Generate images from content and style images (the generator takes a tuple)."""
        return self.networks['g']((content_images, style_images))

    def common_forward(self, batch, batch_idx):
        """Run the generator and compute every loss term for one batch.

        :param batch: dict with at least ``'content_images'``, ``'style_images'`` and
            ``'gt_images'``.
        :param batch_idx: index of the batch (not used here).
        :return: ``(loss, logs)`` where ``loss`` maps names such as ``'g_L1'``,
            ``'g_backward'``, ``'dcontent_backward'``, ``'dstyle_backward'`` to tensors, and
            ``logs`` holds the input, ground-truth and (detached) generated images.
        """
        loss = {}

        content_images = batch['content_images']
        style_images = batch['style_images']
        gt_images = batch['gt_images']

        generated_images = self(content_images, style_images)

        # Generator: weighted L1 reconstruction ...
        loss['g_L1'] = self.losses['L1'](generated_images, gt_images)
        loss['g_backward'] = loss['g_L1'] * self.args.logging.lambda_L1

        # ... plus adversarial terms from whichever discriminators exist.
        if 'd_content' in self.module_keys:
            loss = self.d_content_loss_for_G(content_images, generated_images, loss)
        if 'd_style' in self.module_keys:
            loss = self.d_style_loss_for_G(style_images, generated_images, loss)

        # Discriminator losses use a detached copy so their backward pass never reaches G.
        generated_images = generated_images.detach()

        if 'd_content' in self.module_keys and self.train_d_content:
            loss = self.d_content_loss_for_D(content_images, generated_images, gt_images, loss)
        if 'd_style' in self.module_keys and self.train_d_style:
            loss = self.d_style_loss_for_D(style_images, generated_images, gt_images, loss)

        logs = {
            'content_images': content_images,
            'style_images': style_images,
            'gt_images': gt_images,
            'generated_images': generated_images,
        }
        return loss, logs

    @property
    def automatic_optimization(self):
        """Always ``False``: optimizers are stepped manually in ``training_step``."""
        return False

    def training_step(self, batch, batch_idx):
        """One manual-optimization step for the generator and active discriminators."""
        loss, logs = self.common_forward(batch, batch_idx)

        # Decide once whether this step is logged: with manual optimization Lightning may
        # advance global_step on every optimizer.step(), so re-checking after stepping
        # could disagree with the check used to compute the metrics.
        log_now = self.global_step % self.args.logging.freq['train'] == 0

        metrics = {}
        if log_now:
            with torch.no_grad():
                metrics.update(self.calc_metrics(logs['gt_images'], logs['generated_images']))

        opts = self.optimizers()
        # optimizers() returns a bare optimizer (not a list) when only one is configured.
        if not isinstance(opts, (list, tuple)):
            opts = [opts]

        # (network key, loss key) for every discriminator updated this step.
        active_discriminators = [
            (net_key, loss_key)
            for net_key, loss_key, enabled in (
                ('d_content', 'dcontent_backward', self.train_d_content),
                ('d_style', 'dstyle_backward', self.train_d_style),
            )
            if net_key in self.module_keys and enabled
        ]

        opt_g = opts[self.opt_tag['g']]
        opt_g.zero_grad()
        self.manual_backward(loss['g_backward'])

        # The generator's adversarial loss also back-propagates into the discriminators'
        # parameters, so each D's gradients are cleared only AFTER G's backward pass and
        # right before its own backward. Do not reorder.
        for net_key, loss_key in active_discriminators:
            opts[self.opt_tag[net_key]].zero_grad()
            self.manual_backward(loss[loss_key])

        opt_g.step()
        for net_key, _ in active_discriminators:
            opts[self.opt_tag[net_key]].step()

        if log_now:
            self.custom_log(loss, metrics, logs, mode='train')

    def validation_step(self, batch, batch_idx):
        """Compute losses and log them (no metrics) together with images under ``eval``."""
        metrics = {}
        loss, logs = self.common_forward(batch, batch_idx)
        self.custom_log(loss, metrics, logs, mode='eval')

    def test_step(self, batch, batch_idx):
        """Compute losses and metrics; save images of the first test batches as PNG files.

        :return: ``(loss, logs, metrics)`` -- consumed by ``test_epoch_end``.
        """
        loss, logs = self.common_forward(batch, batch_idx)
        metrics = self.calc_metrics(logs['gt_images'], logs['generated_images'])

        if batch_idx < NUM_TEST_SAVE_IMAGE:
            for key, value in logs.items():
                if 'image' in key:
                    # Keeps only index 0 of the last axis of magic_image_handler's output
                    # (presumably an HWC array in [0, 1] -- TODO confirm).
                    sample_images = (magic_image_handler(value) * 255)[..., 0].astype(np.uint8)
                    Image.fromarray(sample_images).save(f"{batch_idx:02d}_{key}.png")

        return loss, logs, metrics

    def test_epoch_end(self, test_step_outputs):
        """Print SSIM and MS-SSIM averaged over test batches.

        Each batch's metric is weighted equally, regardless of the batch's size.

        :param test_step_outputs: list of ``test_step`` results ``(loss, logs, metrics)``.
        """
        ssim_list = [output[2]['SSIM'].detach().cpu().numpy() for output in test_step_outputs]
        msssim_list = [output[2]['MSSSIM'].detach().cpu().numpy() for output in test_step_outputs]

        print(f"SSIM: {np.mean(ssim_list)}")
        print(f"MSSSIM: {np.mean(msssim_list)}")

    def common_dataloader(self, mode='train', batch_size=None, drop_last=None):
        """Build a ``DataLoader`` for ``mode`` from ``args.datasets``.

        :param mode: ``'train'`` or ``'eval'``; selects ``args.datasets.<mode>``.
        :param batch_size: overrides the config's ``batch_size`` when not ``None``.
        :param drop_last: drop the trailing incomplete batch. Defaults to ``True`` only in
            ``'train'`` mode, so evaluation/testing sees every sample.
        """
        dataset_cls = getattr(datasets, self.args.datasets.type)
        dataset_config = getattr(self.args.datasets, mode)
        dataset = dataset_cls(dataset_config, mode=mode)
        _batch_size = batch_size if batch_size is not None else dataset_config.batch_size
        _drop_last = drop_last if drop_last is not None else (mode == 'train')
        return DataLoader(dataset,
                          shuffle=dataset_config.shuffle,
                          batch_size=_batch_size,
                          num_workers=dataset_config.num_workers,
                          drop_last=_drop_last)

    def train_dataloader(self):
        return self.common_dataloader(mode='train')

    def val_dataloader(self):
        return self.common_dataloader(mode='eval')

    def test_dataloader(self):
        return self.common_dataloader(mode='eval')

    def calc_metrics(self, gt_images, generated_images):
        """Compute SSIM, MS-SSIM and LPIPS between ground-truth and generated images.

        Both inputs are clamped to [0, 1] first; LPIPS receives them rescaled to [-1, 1].
        A NaN MS-SSIM is replaced by 0.

        :param gt_images: ground-truth image batch.
        :param generated_images: generated image batch, same shape as ``gt_images``.
        :return: dict with keys ``'SSIM'``, ``'MSSSIM'`` and ``'LPIPS'``.
        """
        metrics = {}
        # torch.clamp returns a new tensor, so the inputs are not modified.
        _gt = torch.clamp(gt_images, 0, 1)
        _gen = torch.clamp(generated_images, 0, 1)
        metrics['SSIM'] = self.metrics['ssim'](_gt, _gen)
        msssim_value = self.metrics['msssim'](_gt, _gen)
        metrics['MSSSIM'] = msssim_value if not torch.isnan(msssim_value) else torch.tensor(0.).type_as(_gt)
        metrics['LPIPS'] = self.metrics['lpips'](_gt * 2 - 1, _gen * 2 - 1).squeeze().mean()
        return metrics

    # region step
    def d_content_loss_for_G(self, content_images, generated_images, loss):
        """Add the content discriminator's adversarial loss for G to ``loss['g_backward']``."""
        pred_generated = self.networks['d_content'](torch.cat([content_images, generated_images], dim=1))
        loss['g_gan_content'] = self.losses['GANLoss_content'](pred_generated, True, for_discriminator=False)

        loss['g_backward'] += loss['g_gan_content']
        return loss

    def d_content_loss_for_D(self, content_images, generated_images, gt_images, loss):
        """Compute the content discriminator's loss (real = gt, fake = generated).

        Only called by ``common_forward`` when ``'d_content'`` exists and
        ``train_d_content`` is set. ``generated_images`` is expected to be detached.
        """
        pred_gt_images = self.networks['d_content'](torch.cat([content_images, gt_images], dim=1))
        pred_generated_images = self.networks['d_content'](torch.cat([content_images, generated_images], dim=1))

        loss['dcontent_gt'] = self.losses['GANLoss_content'](pred_gt_images, True, for_discriminator=True)
        loss['dcontent_gen'] = self.losses['GANLoss_content'](pred_generated_images, False, for_discriminator=True)
        loss['dcontent_backward'] = loss['dcontent_gt'] + loss['dcontent_gen']
        return loss

    def d_style_loss_for_G(self, style_images, generated_images, loss):
        """Add the style discriminator's adversarial loss for G to ``loss['g_backward']``."""
        pred_generated = self.networks['d_style'](torch.cat([style_images, generated_images], dim=1))
        loss['g_gan_style'] = self.losses['GANLoss_style'](pred_generated, True, for_discriminator=False)

        # NOTE(review): unlike the content branch, this requires the style discriminator to be
        # trained whenever it contributes to G's loss (and is skipped under `python -O`).
        assert self.train_d_style
        loss['g_backward'] += loss['g_gan_style']
        return loss

    def d_style_loss_for_D(self, style_images, generated_images, gt_images, loss):
        """Compute the style discriminator's loss (real = gt, fake = generated).

        ``generated_images`` is expected to be detached.
        """
        pred_gt_images = self.networks['d_style'](torch.cat([style_images, gt_images], dim=1))
        pred_generated_images = self.networks['d_style'](torch.cat([style_images, generated_images], dim=1))

        loss['dstyle_gt'] = self.losses['GANLoss_style'](pred_gt_images, True, for_discriminator=True)
        loss['dstyle_gen'] = self.losses['GANLoss_style'](pred_generated_images, False, for_discriminator=True)
        loss['dstyle_backward'] = loss['dstyle_gt'] + loss['dstyle_gen']
        return loss

    def custom_log(self, loss, metrics, logs, mode):
        """Log scalars via ``self.log`` and images/histograms via the logger's experiment.

        Scalar keys ``"<model>_<name>"`` are logged as ``"<model>/<mode>_<name>"``.
        ``logs`` entries go to ``"<mode>/<key>"``; ``self.logger.experiment`` is assumed to
        be a TensorBoard ``SummaryWriter`` (``add_image`` / ``add_histogram`` are used).

        :raises RuntimeError: for a ``logs`` key containing neither ``image`` nor ``param``.
        """
        for scalars in (loss, metrics):
            for full_key, value in scalars.items():
                model_type, _, value_type = full_key.partition('_')
                self.log(f'{model_type}/{mode}_{value_type}', value)

        tensorboard = self.logger.experiment
        for key, value in logs.items():
            if 'image' in key:
                sample_images = magic_image_handler(value)
                tensorboard.add_image(f"{mode}/{key}", sample_images, self.global_step, dataformats='HWC')
            elif 'param' in key:
                # Same "<mode>/<key>" tag layout as images (the "/" was previously missing).
                tensorboard.add_histogram(f"{mode}/{key}", value, self.global_step)
            else:
                raise RuntimeError(f"Only logging with one of keywords: image, param | current input: {key}")