"""
Generate a large batch of image samples from a model and save them as a large
numpy array. This can be used to produce samples for FID evaluation.
"""

import argparse
import json
import sys
import os

sys.path.append('.')

import numpy as np
import torch as th
import torch.distributed as dist

from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    add_dict_to_argparser,
    args_to_dict,
    continuous_diffusion_defaults,
    control_net_defaults,
)

from pathlib import Path

from tqdm import tqdm, trange
import dnnlib
from dnnlib.util import EasyDict, InfiniteSampler
from guided_diffusion.continuous_diffusion import make_diffusion as make_sde_diffusion
import nsr
import nsr.lsgm
from nsr.script_util import create_3DAE_model, encoder_and_nsr_defaults, loss_defaults, AE_with_Diffusion, rendering_options_defaults, eg3d_options_default, dataset_defaults

# NOTE: load_data/load_eval_data are re-imported inside main() to select the
# dataset backend; the imports below back the commented-out eg3d path.
from torch.utils.data import Subset
from datasets.eg3d_dataset import init_dataset_kwargs, LMDBDataset_MV_Compressed_eg3d

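# fixed global seed for reproducible sampling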
SEED = 0


def main(args):

    # args.rendering_kwargs is populated in __main__ via rendering_options_defaults(args)

    dist_util.setup_dist(args)
    logger.configure(dir=args.logdir)

    th.cuda.empty_cache()

    # seed CPU and CUDA RNGs for reproducible sampling
    th.manual_seed(SEED)
    th.cuda.manual_seed_all(SEED)
    np.random.seed(SEED)

    # * set denoise model args
    logger.log("creating model and diffusion...")
    args.img_size = [args.image_size_encoder]
    # ! denoise_in_channels / denoise_out_channels are no longer required for the LDM setup
    args.image_size = args.image_size_encoder  # 224, matching the triplane resolution

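    # args_to_dict filters the flat argparse namespace down to exactly the
    # keys that the model/diffusion factory expects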
    denoise_model, diffusion = create_model_and_diffusion(
        **args_to_dict(args,
                       model_and_diffusion_defaults().keys()))

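    # ControlNet ('cldm') trainers get a (unet, controlnet) tuple back from
    # the factory; unpack it and keep the ControlNet on-device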
    if 'cldm' in args.trainer_name:
        assert isinstance(denoise_model, tuple)
        denoise_model, controlNet = denoise_model

        controlNet.to(dist_util.dev())
        controlNet.train()
    else:
        controlNet = None

    opts = eg3d_options_default()
    if args.sr_training:
        args.sr_kwargs = dnnlib.EasyDict(
            channel_base=opts.cbase,
            channel_max=opts.cmax,
            fused_modconv_default='inference_only',
            use_noise=True
        )  # ? possibly disable noise injection, since eg3d renders with noise_mode='none'

    # denoise_model.load_state_dict(
    #     dist_util.load_state_dict(args.ddpm_model_path, map_location="cpu"))
    denoise_model.to(dist_util.dev())
    if args.use_fp16:
        denoise_model.convert_to_fp16()
    denoise_model.eval()

    # * auto-encoder reconstruction model
    logger.log("creating 3DAE...")
    auto_encoder = create_3DAE_model(
        **args_to_dict(args,
                       encoder_and_nsr_defaults().keys()))

    # logger.log("AE triplane decoder reuses G_ema decoder...")
    # auto_encoder.decoder.register_buffer('w_avg', G_ema.backbone.mapping.w_avg)

    # print(auto_encoder.decoder.w_avg.shape) # [512]

    # auto_encoder.load_state_dict(
    #     dist_util.load_state_dict(args.rec_model_path, map_location="cpu"))

    auto_encoder.to(dist_util.dev())
    auto_encoder.eval()

    # TODO: how should the scale be set?
    logger.log("creating dataset...")


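    # import the dataset backend lazily so only the requested loaders
    # (objaverse vs. shapenet) are pulled in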
    if args.objv_dataset:
        from datasets.g_buffer_objaverse import load_data, load_eval_data, load_memory_data, load_wds_data
    else:  # shapenet
        from datasets.shapenet import load_data, load_eval_data, load_memory_data
    
    eval_data = None

    # if args.cfg in ('afhq', 'ffhq'):
    #     # ! load data
    #     if args.use_lmdb:
    #         logger.log("creating LMDB eg3d data loader...")
    #         training_set = LMDBDataset_MV_Compressed_eg3d(
    #             args.data_dir,
    #             args.image_size,
    #             args.image_size_encoder,
    #         )
    #     else:

    #         logger.log("creating eg3d data loader...")
    #         training_set_kwargs, dataset_name = init_dataset_kwargs(
    #             data=args.data_dir,
    #             class_name='datasets.eg3d_dataset.ImageFolderDataset'
    #         )  # only load pose here
    #         # if args.cond and not training_set_kwargs.use_labels:
    #         # raise Exception('check here')

    #         # training_set_kwargs.use_labels = args.cond
    #         training_set_kwargs.use_labels = True
    #         training_set_kwargs.xflip = True
    #         training_set_kwargs.random_seed = SEED
    #         # desc = f'{args.cfg:s}-{dataset_name:s}-gpus{c.num_gpus:d}-batch{c.batch_size:d}-gamma{c.loss_kwargs.r1_gamma:g}'

    #         # * construct ffhq/afhq dataset
    #         training_set = dnnlib.util.construct_class_by_name(
    #             **training_set_kwargs)  # subclass of training.dataset.Dataset

    #         # training_set_sampler = InfiniteSampler(
    #         #     dataset=training_set,
    #         #     rank=dist_util.get_rank(),
    #         #     num_replicas=dist_util.get_world_size(),
    #         #     seed=SEED)

    #         # data = iter(
    #         #     th.utils.data.DataLoader(dataset=training_set,
    #         #                             sampler=training_set_sampler,
    #         #                             batch_size=args.batch_size,
    #         #                             pin_memory=True,
    #         #                             num_workers=args.num_workers,))
    #         #                             #  prefetch_factor=2))

    # #     training_set_sampler = InfiniteSampler(
    # #         dataset=training_set,
    # #         rank=dist_util.get_rank(),
    # #         num_replicas=dist_util.get_world_size(),
    # #         seed=SEED)

    # #     data = iter(
    # #         th.utils.data.DataLoader(
    # #             dataset=training_set,
    # #             sampler=training_set_sampler,
    # #             batch_size=args.batch_size,
    # #             pin_memory=True,
    # #             num_workers=args.num_workers,
    # #             persistent_workers=args.num_workers > 0,
    # #             # prefetch_factor=max(8//args.batch_size, 2),
    # #         ))

    #     eval_data = th.utils.data.DataLoader(dataset=Subset(
    #         training_set, np.arange(25)),
    #                                          batch_size=args.eval_batch_size,
    #                                          num_workers=1)

    # else:

    #     logger.log("creating data loader...")

    #     # if args.objv_dataset:
    #     #     from datasets.g_buffer_objaverse import load_data, load_eval_data, load_memory_data
    #     # else: # shapenet
    #     #     from datasets.shapenet import load_data, load_eval_data, load_memory_data

    #     # eval_data = load_eval_data(
    #     #     file_path=args.eval_data_dir,
    #     #     batch_size=args.eval_batch_size,
    #     #     reso=args.image_size,
    #     #     reso_encoder=args.image_size_encoder,  # 224 -> 128
    #     #     num_workers=args.num_workers,
    #     #     load_depth=True,  # for evaluation
    #     #     interval=args.interval,
    #     #     use_lmdb=args.use_lmdb,
    #     # )

    #     if args.use_wds:
    #         if args.eval_data_dir == 'NONE':
    #             with open(args.eval_shards_lst) as f:
    #                 eval_shards_lst = [url.strip() for url in f.readlines()]
    #         else:
    #             eval_shards_lst = args.eval_data_dir  # auto expanded

    #         eval_data = load_wds_data(
    #             eval_shards_lst, args.image_size, args.image_size_encoder,
    #             args.eval_batch_size, args.num_workers,
    #             **args_to_dict(args,
    #                            dataset_defaults().keys()))

    #     else:
    #         eval_data = load_eval_data(
    #             file_path=args.eval_data_dir,
    #             batch_size=args.eval_batch_size,
    #             reso=args.image_size,
    #             reso_encoder=args.image_size_encoder,  # 224 -> 128
    #             num_workers=args.num_workers,
    #             # load_depth=True,  # for evaluation
    #             **args_to_dict(args,
    #                            dataset_defaults().keys()))

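    # registry mapping trainer_name to the training-loop class that owns the
    # sampling utilities (eval_ddpm_sample / eval_cldm)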
    TrainLoop = {
        'adm': nsr.TrainLoop3DDiffusion,
        'vpsde_crossattn': nsr.lsgm.TrainLoop3DDiffusionLSGM_crossattn,
    }[args.trainer_name]

    # continuous-time (VPSDE) diffusion for the LSGM-style trainers
    if 'vpsde' in args.trainer_name:
        sde_diffusion = make_sde_diffusion(
            dnnlib.EasyDict(
                args_to_dict(args,
                             continuous_diffusion_defaults().keys())))
        assert args.mixed_prediction, 'mixed_prediction must be enabled for VPSDE trainers'
        logger.log('created VPSDE diffusion.')
    else:
        sde_diffusion = None


    training_loop = TrainLoop(rec_model=auto_encoder,
                              denoise_model=denoise_model,
                              control_model=controlNet,
                              diffusion=diffusion,
                              sde_diffusion=sde_diffusion,
                              loss_class=None,  # no losses needed at sampling time
                              data=None,
                              eval_data=eval_data,
                              **vars(args))

    logger.log("sampling...")
    dist_util.synchronize()


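    # sampling runs on rank 0 only; the other ranks wait at the barrier below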
    if dist_util.get_rank() == 0:

        (Path(logger.get_dir()) / 'FID_Cals').mkdir(exist_ok=True,
                                                    parents=True)

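        # snapshot the full run config next to the samples for reproducibility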
        with open(os.path.join(args.logdir, 'args.json'), 'w') as f:
            json.dump(vars(args), f, indent=2)

        # load the fixed evaluation camera poses for the current config
        if args.cfg == 'ffhq':
            camera = th.load('assets/ffhq_eval_pose.pt',
                             map_location=dist_util.dev())
        elif args.cfg == 'shapenet':
            camera = th.load('assets/shapenet_eval_pose.pt',
                             map_location=dist_util.dev())
        else:
            # other configs fall back to trajectories built inside the eval
            # routines (e.g. when use_train_trajectory is set)
            camera = None

        for sample_idx in trange(args.num_samples):
            # index each batch so outputs are saved to distinct locations
            training_loop.step = sample_idx
            if args.create_controlnet or 'crossattn' in args.trainer_name:
                training_loop.eval_cldm(
                    prompt=args.prompt,
                    unconditional_guidance_scale=args.unconditional_guidance_scale,
                    use_ddim=args.use_ddim,
                    save_img=args.save_img,
                    use_train_trajectory=args.use_train_trajectory,
                    export_mesh=args.export_mesh,
                    camera=camera,
                    overwrite_diff_inp_size=args.overwrite_diff_inp_size,
                )
            else:
                # evaluate the plain LDM sampling path
                training_loop.eval_ddpm_sample(
                    training_loop.rec_model,
                    save_img=args.save_img,
                    use_train_trajectory=args.use_train_trajectory,
                    export_mesh=args.export_mesh,
                    camera=camera,
                )

    dist.barrier()
    logger.log("sampling complete")


def create_argparser():
    defaults = dict(
        image_size_encoder=224,
        triplane_scaling_divider=1.0,  # triplane latents are divided by this value
        diffusion_input_size=-1,
        trainer_name='adm',
        use_amp=False,

        # * sampling flags
        clip_denoised=False,
        num_samples=10,
        use_ddim=False,
        ddpm_model_path="",
        cldm_model_path="",
        rec_model_path="",

        # * eval logging flags
        logdir="/mnt/lustre/yslan/logs/nips23/",
        data_dir="",
        eval_data_dir="",
        eval_batch_size=1,
        num_workers=1,

        # * training flags for loading TrainingLoop class
        overfitting=False,
        image_size=128,
        iterations=150000,
        schedule_sampler="uniform",
        anneal_lr=False,
        lr=5e-5,
        weight_decay=0.0,
        lr_anneal_steps=0,
        batch_size=1,
        microbatch=-1,  # -1 disables microbatches
        ema_rate="0.9999",  # comma-separated list of EMA values
        log_interval=50,
        eval_interval=2500,
        save_interval=10000,
        resume_checkpoint="",
        resume_cldm_checkpoint="",
        resume_checkpoint_EG3D="",
        use_fp16=False,
        fp16_scale_growth=1e-3,
        load_submodule_name='',  # for loading pretrained auto_encoder model
        ignore_resume_opt=False,
        freeze_ae=False,
        denoised_ae=True,
        # inference prompt
        prompt="a red chair",
        interval=1,
        objv_dataset=False,
        use_lmdb=False,
        save_img=False,
        use_train_trajectory=False,  # use the train trajectory to sample images for FID
        unconditional_guidance_scale=1.0,
        cond_key='img_sr',
        use_eos_feature=False,
        export_mesh=False,
        overwrite_diff_inp_size=None,
        allow_tf32=True,
    )

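    # merge in the grouped defaults; later groups override earlier ones on
    # any key collisions (standard dict.update semantics)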
    defaults.update(model_and_diffusion_defaults())
    defaults.update(encoder_and_nsr_defaults())  # type: ignore
    defaults.update(loss_defaults())
    defaults.update(continuous_diffusion_defaults())
    defaults.update(control_net_defaults())
    defaults.update(dataset_defaults())

    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)

    return parser


if __name__ == "__main__":

    # os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
    # os.environ["NCCL_DEBUG"] = "INFO"

    # set to DETAIL for verbose runtime logging of distributed collectives
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

    args = create_argparser().parse_args()

    # LOCAL_RANK is set by torchrun-style launchers; default to 0 otherwise
    args.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    args.gpus = th.cuda.device_count()

    args.rendering_kwargs = rendering_options_defaults(args)

    main(args)