import collections
import functools
import os
import re

import yaml

class AttrDict(dict):
    """Dict as attribute trick.

    Keys are readable and writable as attributes (``d.key``) because the
    instance ``__dict__`` is the dict itself. On construction, nested dicts
    are converted to ``AttrDict``. Lists/tuples whose first element is a
    dict are converted to a list of ``AttrDict``. Empty sequences are kept
    unchanged.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Make attribute access and item access share the same storage.
        self.__dict__ = self
        # Snapshot the items so values can be replaced safely while looping.
        for key, value in list(self.items()):
            if isinstance(value, dict):
                self.__dict__[key] = AttrDict(value)
            elif isinstance(value, (list, tuple)):
                # ``value and ...`` guards empty sequences, where value[0]
                # would raise IndexError.
                if value and isinstance(value[0], dict):
                    self.__dict__[key] = [AttrDict(item) for item in value]

    def yaml(self):
        """Convert object to a plain nested dict (e.g. for yaml dumping).

        Returns:
            dict: A new dict. Nested ``AttrDict`` values are converted
            recursively, as are lists whose first element is an
            ``AttrDict``. All other values are copied by reference.
        """
        yaml_dict = {}
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                yaml_dict[key] = value.yaml()
            elif (isinstance(value, list) and value
                  and isinstance(value[0], AttrDict)):
                yaml_dict[key] = [item.yaml() for item in value]
            else:
                yaml_dict[key] = value
        return yaml_dict

    def __repr__(self):
        """Print all variables.

        One ``key: value`` line per entry. Nested ``AttrDict`` values (and
        lists of them) are printed under a ``key:`` header, indented by four
        spaces.
        """
        ret_str = []
        for key, value in self.__dict__.items():
            if isinstance(value, AttrDict):
                ret_str.append('{}:'.format(key))
                for line in repr(value).split('\n'):
                    ret_str.append('    ' + line)
            elif (isinstance(value, list) and value
                  and isinstance(value[0], AttrDict)):
                ret_str.append('{}:'.format(key))
                for item in value:
                    # Treat each element like a nested AttrDict above.
                    for line in repr(item).split('\n'):
                        ret_str.append('    ' + line)
            else:
                ret_str.append('{}: {}'.format(key, value))
        return '\n'.join(ret_str)


class Config(AttrDict):
    r"""Configuration class. This should include every human-specifiable
    hyperparameter value for your training.

    Defaults are assigned in ``__init__``. If ``filename`` is given, the
    YAML file is loaded and merged over the defaults with
    ``recursive_update``.

    Args:
        filename (str or None): Path to a YAML config file. When ``None``,
            only the built-in defaults are used.
        verbose (bool): If True, print the resulting configuration.
        is_train (bool): Sets ``phase`` to ``'train'`` (True) or ``'test'``.
    """

    def __init__(self, filename=None, verbose=False, is_train=True):
        super(Config, self).__init__()
        # Set default parameters.
        # Logging.

        # Large sentinel default for the *_iter / *_epoch settings below;
        # presumably means "never" for interval-style options -- confirm
        # with the trainer code.
        large_number = 1000000000
        self.snapshot_save_iter = large_number
        self.snapshot_save_epoch = large_number
        self.snapshot_save_start_iter = 0
        self.snapshot_save_start_epoch = 0
        self.image_save_iter = large_number
        self.eval_epoch = large_number
        self.start_eval_epoch = large_number
        self.max_epoch = large_number
        self.max_iter = large_number
        self.logging_iter = 100
        self.image_to_tensorboard = False
        self.which_iter = None
        self.resume = True

        self.checkpoints_dir = 'NTED'
        self.name = 'nted_checkpoint.pt'
        self.phase = 'train' if is_train else 'test'

        # Networks.
        self.gen = AttrDict(type='generators.dummy')
        self.dis = AttrDict(type='discriminators.dummy')

        # Optimizers.
        self.gen_optimizer = AttrDict(type='adam',
                                      lr=0.0001,
                                      adam_beta1=0.0,
                                      adam_beta2=0.999,
                                      eps=1e-8,
                                      lr_policy=AttrDict(iteration_mode=False,
                                                         type='step',
                                                         step_size=large_number,
                                                         gamma=1))
        self.dis_optimizer = AttrDict(type='adam',
                                      lr=0.0001,
                                      adam_beta1=0.0,
                                      adam_beta2=0.999,
                                      eps=1e-8,
                                      lr_policy=AttrDict(iteration_mode=False,
                                                         type='step',
                                                         step_size=large_number,
                                                         gamma=1))
        # Data.
        self.data = AttrDict(name='dummy',
                             type='datasets.images',
                             num_workers=0)
        self.test_data = AttrDict(name='dummy',
                                  type='datasets.images',
                                  num_workers=0,
                                  test=AttrDict(is_lmdb=False,
                                                roots='',
                                                batch_size=1))
        self.trainer = AttrDict(
            image_to_tensorboard=False,
            hparam_to_tensorboard=False)

        # Cudnn.
        self.cudnn = AttrDict(deterministic=False,
                              benchmark=True)

        # Others.
        self.pretrained_weight = ''
        self.inference_args = AttrDict()

        # Update with given configurations (skipped when no file is given).
        if filename is not None:
            assert os.path.exists(filename), 'File {} not exist.'.format(filename)
            loader = yaml.SafeLoader
            # Extra float resolver: besides the dotted forms, it also accepts
            # exponent forms without a decimal point (e.g. ``1e-4``, ``1e5``)
            # as floats.
            float_tag = u'tag:yaml.org,2002:float'
            float_regexp = re.compile(u'''^(?:
                 [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
                |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
                |\\.[0-9_]+(?:[eE][-+][0-9]+)?
                |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
                |[-+]?\\.(?:inf|Inf|INF)
                |\\.(?:nan|NaN|NAN))$''', re.X)
            # NOTE: this registers on yaml.SafeLoader itself, so it affects
            # every SafeLoader user in the process (as before). Register it
            # only once instead of appending a duplicate on every Config().
            already_registered = any(
                tag == float_tag and regexp.pattern == float_regexp.pattern
                for tag, regexp in loader.yaml_implicit_resolvers.get('-', []))
            if not already_registered:
                loader.add_implicit_resolver(
                    float_tag, float_regexp, list(u'-+0123456789.'))
            try:
                with open(filename, 'r') as f:
                    cfg_dict = yaml.load(f, Loader=loader)
            except EnvironmentError:
                print('Please check the file with name of "%s"' % filename)
                # Previously fell through to a NameError on ``cfg_dict``.
                raise
            if cfg_dict is None:
                # An empty YAML document loads as None; treat it as no overrides.
                cfg_dict = {}
            recursive_update(self, cfg_dict)

            # Put common opts in both gen and dis.
            if 'common' in cfg_dict:
                self.common = AttrDict(**cfg_dict['common'])
                self.gen.common = self.common
                self.dis.common = self.common

        if verbose:
            print(' config '.center(80, '-'))
            print(self.__repr__())
            print(''.center(80, '-'))


def rsetattr(obj, attr, val):
    """Assign ``val`` to the (possibly dotted) attribute path ``attr`` of ``obj``.

    Every component but the last is resolved with ``rgetattr``. The last
    component is then set on the object found there.
    """
    parent_path, _, leaf_name = attr.rpartition('.')
    target = obj
    if parent_path:
        target = rgetattr(obj, parent_path)
    return setattr(target, leaf_name, val)


def rgetattr(obj, attr, *args):
    """Follow the dotted attribute path ``attr`` starting at ``obj``.

    Each component is looked up with ``getattr``, forwarding ``*args``. An
    optional default is therefore applied at every step of the path.
    """
    current = obj
    for part in attr.split('.'):
        current = getattr(current, part, *args)
    return current


def recursive_update(d, u):
    """Recursively update AttrDict ``d`` in place with the mapping ``u``.

    - Nested mappings in ``u`` are merged into the matching entry of ``d``.
      If that entry is missing, or is not an ``AttrDict``, a fresh
      ``AttrDict`` is used. A plain mapping is converted first, so its
      contents are kept.
    - Non-empty lists/tuples whose first element is a dict become a list
      of ``AttrDict``.
    - Everything else (including empty lists) overwrites the entry as-is.

    Args:
        d (AttrDict): Destination, modified in place.
        u (Mapping): Source of new values.

    Returns:
        AttrDict: ``d`` itself.
    """
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            existing = d.get(key)
            if not isinstance(existing, AttrDict):
                # Plain dicts have no ``__dict__``, and None/scalars cannot be
                # merged into; both previously raised AttributeError.
                if isinstance(existing, collections.abc.Mapping):
                    existing = AttrDict(existing)
                else:
                    existing = AttrDict()
            d.__dict__[key] = recursive_update(existing, value)
        elif (isinstance(value, (list, tuple)) and value
              and isinstance(value[0], dict)):
            # ``value and ...`` guards empty sequences (value[0] would raise).
            d.__dict__[key] = [AttrDict(item) for item in value]
        else:
            d.__dict__[key] = value
    return d