# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""Streaming images and labels from datasets created with dataset_tool.py."""

import cv2
import os
import numpy as np
import zipfile
import PIL.Image
import json
import torch
import dnnlib
from torchvision import transforms

from pdb import set_trace as st

from .shapenet import LMDBDataset_MV_Compressed, decompress_array

try:
    import pyspng
except ImportError:
    pyspng = None

#----------------------------------------------------------------------------


# Copied from eg3d/train.py.
def init_dataset_kwargs(data,
                        class_name='datasets.eg3d_dataset.ImageFolderDataset',
                        reso_gt=128):
    """Build the keyword arguments used to construct a dataset class.

    A probe instance of ``class_name`` is constructed once so that the
    resolution, label availability and dataset size can be recorded
    explicitly in the returned kwargs. The probe is closed before returning
    so that any file handle it opened (e.g. a zip archive) is released
    deterministically instead of relying on ``__del__``.

    Args:
        data: Dataset location, passed to the class as ``path``.
        class_name: Fully-qualified name of the dataset class to construct.
        reso_gt: Ground-truth resolution forwarded to the dataset class.

    Returns:
        ``(dataset_kwargs, name)``. ``dataset_kwargs`` is a
        ``dnnlib.EasyDict`` with ``class_name``, ``reso_gt``, ``path``,
        ``use_labels``, ``max_size``, ``xflip`` and ``resolution`` set.
        ``name`` is the probe dataset's ``name``.
    """
    dataset_kwargs = dnnlib.EasyDict(class_name=class_name,
                                     reso_gt=reso_gt,
                                     path=data,
                                     use_labels=True,
                                     max_size=None,
                                     xflip=False)
    # Probe instance, presumably a subclass of Dataset in this module.
    dataset_obj = dnnlib.util.construct_class_by_name(**dataset_kwargs)
    try:
        dataset_kwargs.resolution = dataset_obj.resolution  # Be explicit about resolution.
        dataset_kwargs.use_labels = dataset_obj.has_labels  # Be explicit about labels.
        dataset_kwargs.max_size = len(dataset_obj)  # Be explicit about dataset size.
        name = dataset_obj.name
    finally:
        # Release resources held by the probe (e.g. an open zip file).
        close = getattr(dataset_obj, 'close', None)
        if close is not None:
            close()
    return dataset_kwargs, name


class Dataset(torch.utils.data.Dataset):
    """Base dataset yielding images, segmentation mattes and labels.

    Subclasses supply the raw data by implementing ``_load_raw_image``,
    ``_load_raw_matte`` and ``_load_raw_labels`` (and may override
    ``close``). This class handles ``max_size`` sub-sampling, ``xflip``
    duplication, lazy label loading and the per-sample preprocessing done
    in ``__getitem__``.
    """

    def __init__(
            self,
            name,  # Name of the dataset.
            raw_shape,  # Shape of the raw image data (NCHW).
            reso_gt=128,  # Side length the ground-truth image and matte are resized to.
            max_size=None,  # Artificially limit the size of the dataset. None = no limit. Applied before xflip.
            use_labels=False,  # Enable conditioning labels? False = label dimension is zero.
            xflip=False,  # Artificially double the size of the dataset via x-flips. Applied after max_size.
            random_seed=0,  # Random seed to use when applying max_size.
    ):
        self._name = name
        self._raw_shape = list(raw_shape)
        self._use_labels = use_labels
        self._raw_labels = None
        self._raw_labels_std = None  # computed together with _raw_labels
        self._label_shape = None

        self.reso_gt = reso_gt
        self.reso_encoder = 224  # side length of the encoder input

        # Apply max_size: a reproducible random subset, kept in sorted order.
        self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64)
        if (max_size is not None) and (self._raw_idx.size > max_size):
            np.random.RandomState(random_seed).shuffle(self._raw_idx)
            self._raw_idx = np.sort(self._raw_idx[:max_size])

        # Apply xflip: the second half of the index list are mirrored copies.
        self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
        if xflip:
            self._raw_idx = np.tile(self._raw_idx, 2)
            self._xflip = np.concatenate(
                [self._xflip, np.ones_like(self._xflip)])

        # (DINO) encoder input: [0, 1] HWC float -> [-1, 1] CHW tensor.
        self.normalize_for_encoder_input = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            transforms.Resize(size=(self.reso_encoder, self.reso_encoder),
                              antialias=True),  # type: ignore
        ])

        # Same normalization, resized to the ground-truth resolution.
        self.normalize_for_gt = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            transforms.Resize(size=(self.reso_gt, self.reso_gt),
                              antialias=True),  # type: ignore
        ])

    def _get_raw_labels(self):
        """Return the cached raw label array, loading it on first use.

        Falls back to an empty ``[N, 0]`` float32 array when labels are
        disabled or the subclass provides none. Also caches the
        per-dimension standard deviation returned by ``get_label_std``.
        """
        if self._raw_labels is None:
            self._raw_labels = self._load_raw_labels(
            ) if self._use_labels else None
            if self._raw_labels is None:
                self._raw_labels = np.zeros([self._raw_shape[0], 0],
                                            dtype=np.float32)
            assert isinstance(self._raw_labels, np.ndarray)
            # The label count is deliberately not checked against the
            # image count.
            assert self._raw_labels.dtype in [np.float32, np.int64]
            if self._raw_labels.dtype == np.int64:
                assert self._raw_labels.ndim == 1
                assert np.all(self._raw_labels >= 0)
            self._raw_labels_std = self._raw_labels.std(0)
        return self._raw_labels

    def close(self):  # to be overridden by subclass
        pass

    def _load_raw_image(self, raw_idx):  # to be overridden by subclass
        raise NotImplementedError

    def _load_raw_labels(self):  # to be overridden by subclass
        raise NotImplementedError

    def __getstate__(self):
        # Labels are reloaded lazily after unpickling.
        return dict(self.__dict__, _raw_labels=None)

    def __del__(self):
        # Best-effort cleanup; errors during teardown are ignored.
        try:
            self.close()
        except Exception:
            pass

    def __len__(self):
        return self._raw_idx.size

    def __getitem__(self, idx):
        """Return one preprocessed sample as a dict.

        Keys: ``c`` (label), ``img_to_encoder`` (encoder-normalized image),
        ``img_sr`` (full-resolution image in [-1, 1], CHW), ``img``
        (image resized to ``reso_gt`` in [-1, 1], CHW), and ``depth`` /
        ``depth_mask`` (the same matte, resized to ``reso_gt``).
        """
        raw_idx = self._raw_idx[idx]
        flip = self._xflip[idx]

        matte = self._load_raw_matte(raw_idx)
        assert isinstance(matte, np.ndarray)
        assert list(matte.shape)[1:] == self.image_shape[1:]
        if flip:
            assert matte.ndim == 3  # CHW
            matte = matte[:, :, ::-1]
        matte = np.transpose(matte, (1, 2, 0)).astype(np.float32)  # HWC
        # Nearest-neighbour interpolation keeps the matte values unchanged.
        matte = cv2.resize(matte, (self.reso_gt, self.reso_gt),
                           interpolation=cv2.INTER_NEAREST)
        assert matte.min() >= 0 and matte.max(
        ) <= 1, f'{matte.min(), matte.max()}'
        # cv2.resize returns 2-D output for single-channel input; for
        # multi-channel mattes keep only the first channel -> HW.
        if matte.ndim == 3:
            matte = matte[..., 0]

        image = self._load_raw_image(raw_idx)
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if flip:
            assert image.ndim == 3  # CHW
            image = image[:, :, ::-1]

        # H W C float in [0, 1] for cv2 / torchvision processing.
        image = np.transpose(image, (1, 2, 0)).astype(np.float32) / 255

        image_sr = torch.from_numpy(image)[..., :3].permute(
            2, 0, 1) * 2 - 1  # normalize to [-1,1]
        image_to_encoder = self.normalize_for_encoder_input(image)

        image_gt = cv2.resize(image, (self.reso_gt, self.reso_gt),
                              interpolation=cv2.INTER_AREA)
        image_gt = torch.from_numpy(image_gt)[..., :3].permute(
            2, 0, 1) * 2 - 1  # normalize to [-1,1]

        return dict(
            c=self.get_label(idx),
            img_to_encoder=image_to_encoder,  # reso_encoder
            img_sr=image_sr,  # source resolution
            img=image_gt,  # [-1,1] range, reso_gt
            depth=matte,
            depth_mask=matte,
        )

    def get_label(self, idx):
        """Return the label of sample ``idx``, one-hot encoded if integral."""
        label = self._get_raw_labels()[self._raw_idx[idx]]
        if label.dtype == np.int64:
            onehot = np.zeros(self.label_shape, dtype=np.float32)
            onehot[label] = 1
            label = onehot
        return label.copy()

    def get_details(self, idx):
        """Return raw index, flip flag and raw label of sample ``idx``."""
        d = dnnlib.EasyDict()
        d.raw_idx = int(self._raw_idx[idx])
        d.xflip = (int(self._xflip[idx]) != 0)
        d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
        return d

    def get_label_std(self):
        """Return the per-dimension standard deviation of the raw labels.

        Labels are loaded on demand, so this no longer depends on a prior
        call to ``get_label`` or ``label_shape``.
        """
        raw_labels = self._get_raw_labels()
        if getattr(self, '_raw_labels_std', None) is None:
            self._raw_labels_std = raw_labels.std(0)
        return self._raw_labels_std

    @property
    def name(self):
        return self._name

    @property
    def image_shape(self):
        return list(self._raw_shape[1:])

    @property
    def num_channels(self):
        assert len(self.image_shape) == 3  # CHW
        return self.image_shape[0]

    @property
    def resolution(self):
        assert len(self.image_shape) == 3  # CHW
        assert self.image_shape[1] == self.image_shape[2]
        return self.image_shape[1]

    @property
    def label_shape(self):
        if self._label_shape is None:
            raw_labels = self._get_raw_labels()
            if raw_labels.dtype == np.int64:
                self._label_shape = [int(np.max(raw_labels)) + 1]
            else:
                self._label_shape = raw_labels.shape[1:]
        return list(self._label_shape)

    @property
    def label_dim(self):
        assert len(self.label_shape) == 1
        return self.label_shape[0]

    @property
    def has_labels(self):
        return any(x != 0 for x in self.label_shape)

    @property
    def has_onehot_labels(self):
        return self._get_raw_labels().dtype == np.int64


#----------------------------------------------------------------------------


class ImageFolderDataset(Dataset):
    """Images, segmentation mattes and labels read from a directory or zip.

    Mattes are read from the directory obtained by replacing
    ``'unzipped_ffhq_512'`` with ``'ffhq_512_seg'`` in ``path``, using the
    same relative file names as the images. Mattes are only supported for
    directory datasets.
    """

    def __init__(
            self,
            path,  # Path to directory or zip.
            resolution=None,  # Ensure specific resolution, None = highest available.
            reso_gt=128,  # Ground-truth resolution, forwarded to Dataset.
            **super_kwargs,  # Additional arguments for the Dataset base class.
    ):
        self._path = path
        self._matte_path = path.replace('unzipped_ffhq_512',
                                        'ffhq_512_seg')
        self._zipfile = None

        if os.path.isdir(self._path):
            self._type = 'dir'
            self._all_fnames = {
                os.path.relpath(os.path.join(root, fname), start=self._path)
                for root, _dirs, files in os.walk(self._path)
                for fname in files
            }
        elif self._file_ext(self._path) == '.zip':
            self._type = 'zip'
            self._all_fnames = set(self._get_zipfile().namelist())
        else:
            raise IOError('Path must point to a directory or zip')

        PIL.Image.init()
        self._image_fnames = sorted(
            fname for fname in self._all_fnames
            if self._file_ext(fname) in PIL.Image.EXTENSION)
        if len(self._image_fnames) == 0:
            raise IOError('No image files found in the specified path')

        name = os.path.splitext(os.path.basename(self._path))[0]
        # The first image determines the (C, H, W) shape of every sample.
        raw_shape = [len(self._image_fnames)] + list(
            self._load_raw_image(0).shape)
        if resolution is not None and (raw_shape[2] != resolution
                                       or raw_shape[3] != resolution):
            raise IOError('Image files do not match the specified resolution')
        super().__init__(name=name,
                         raw_shape=raw_shape,
                         reso_gt=reso_gt,
                         **super_kwargs)

    @staticmethod
    def _file_ext(fname):
        """Return the lower-cased extension of ``fname`` (e.g. ``'.png'``)."""
        return os.path.splitext(fname)[1].lower()

    def _get_zipfile(self):
        """Return the lazily opened zip archive (zip datasets only)."""
        assert self._type == 'zip'
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self._path)
        return self._zipfile

    def _open_file(self, fname):
        """Open ``fname`` (relative to the dataset root) for binary reading."""
        if self._type == 'dir':
            return open(os.path.join(self._path, fname), 'rb')
        if self._type == 'zip':
            return self._get_zipfile().open(fname, 'r')
        return None

    def _open_matte_file(self, fname):
        """Open the matte file matching image ``fname`` for binary reading.

        Raises:
            IOError: if the dataset is not a directory (mattes are not
                looked up inside zip archives).
        """
        if self._type == 'dir':
            return open(os.path.join(self._matte_path, fname), 'rb')
        raise IOError(
            f'Matte files are only supported for directory datasets, '
            f'got dataset type {self._type!r}')

    def close(self):
        try:
            if self._zipfile is not None:
                self._zipfile.close()
        finally:
            self._zipfile = None

    def __getstate__(self):
        # Open zip handles cannot be pickled; they are reopened lazily.
        return dict(super().__getstate__(), _zipfile=None)

    def _decode_image(self, f, fname):
        """Decode the already-open image file ``f`` into a numpy array.

        Must be called while ``f`` is open: ``PIL.Image.open`` reads lazily.
        """
        if pyspng is not None and self._file_ext(fname) == '.png':
            return pyspng.load(f.read())
        return np.array(PIL.Image.open(f))

    @staticmethod
    def _to_chw(image):
        """Convert an HW or HWC array to CHW (HW gets a singleton channel)."""
        if image.ndim == 2:
            image = image[:, :, np.newaxis]  # HW => HWC
        return image.transpose(2, 0, 1)  # HWC => CHW

    def _load_raw_image(self, raw_idx):
        """Load image ``raw_idx`` as a CHW array."""
        fname = self._image_fnames[raw_idx]
        with self._open_file(fname) as f:
            image = self._decode_image(f, fname)
        return self._to_chw(image)

    def _load_raw_matte(self, raw_idx):
        """Load the segmentation matte of ``raw_idx`` as a {0, 1} float32 CHW array."""
        fname = self._image_fnames[raw_idx]
        with self._open_matte_file(fname) as f:
            matte = self._decode_image(f, fname)
        # Any non-zero segmentation label counts as foreground.
        matte = (matte > 0).astype(np.float32)
        return self._to_chw(matte)

    def _load_raw_matte_orig(self, raw_idx):
        """Load the matte of ``raw_idx`` as a CHW array without binarizing it."""
        fname = self._image_fnames[raw_idx]
        with self._open_matte_file(fname) as f:
            matte = self._decode_image(f, fname)
        return self._to_chw(matte)

    def _load_raw_labels(self):
        """Load labels from ``dataset.json``; return ``None`` if absent.

        1-D label arrays become int64 class ids, 2-D arrays become float32.
        """
        fname = 'dataset.json'
        if fname not in self._all_fnames:
            return None
        with self._open_file(fname) as f:
            labels = json.load(f)['labels']
        if labels is None:
            return None
        # Labels are taken in dataset.json order (deduplicated by file name),
        # not re-keyed by self._image_fnames.
        # NOTE(review): this assumes dataset.json lists entries in the same
        # order as the sorted image files -- confirm.
        labels = np.array(list(dict(labels).values()))
        labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
        self._raw_labels = labels
        return labels


#----------------------------------------------------------------------------


# class ImageFolderDatasetUnzipped(ImageFolderDataset):

#     def __init__(self, path, resolution=None, **super_kwargs):
#         super().__init__(path, resolution, **super_kwargs)


# class ImageFolderDatasetPose(ImageFolderDataset):

#     def __init__(
#             self,
#             path,  # Path to directory or zip.
#             resolution=None,  # Ensure specific resolution, None = highest available.
#             **super_kwargs,  # Additional arguments for the Dataset base class.
#     ):
#         super().__init__(path, resolution, **super_kwargs)
#         # only return labels

#     def __len__(self):
#         return self._raw_idx.size
#         # return self._get_raw_labels().shape[0]

#     def __getitem__(self, idx):
#         # image = self._load_raw_image(self._raw_idx[idx])
#         # assert isinstance(image, np.ndarray)
#         # assert list(image.shape) == self.image_shape
#         # assert image.dtype == np.uint8
#         # if self._xflip[idx]:
#         # assert image.ndim == 3  # CHW
#         # image = image[:, :, ::-1]
#         return dict(c=self.get_label(idx), )  # return dict here


class ImageFolderDatasetLMDB(ImageFolderDataset):
    """ImageFolderDataset variant whose samples are left unprocessed.

    ``__getitem__`` performs no resizing or normalization: it returns the
    raw uint8 CHW image and the matte at the source resolution (presumably
    for exporting samples into an LMDB, as the class name suggests).
    """

    def __init__(self, path, resolution=None, reso_gt=128, **super_kwargs):
        super().__init__(path, resolution, reso_gt, **super_kwargs)

    def __getitem__(self, idx):
        """Return ``dict(c=label, img=uint8 CHW image, depth_mask=HW matte)``."""
        raw_idx = self._raw_idx[idx]
        flip = self._xflip[idx]

        matte = self._load_raw_matte(raw_idx)
        assert isinstance(matte, np.ndarray)
        assert list(matte.shape)[1:] == self.image_shape[1:]
        if flip:
            assert matte.ndim == 3  # CHW
            matte = matte[:, :, ::-1]
        matte = np.transpose(matte, (1, 2, 0)).astype(np.float32)  # HWC
        # Some FFHQ mattes are dirty and may be all zero, so an empty matte
        # is accepted; only the value range is checked.
        assert matte.min() >= 0 and matte.max(
        ) <= 1, f'{matte.min(), matte.max()}'
        if matte.ndim == 3:  # HWC -> HW, keeping the first channel
            matte = matte[..., 0]

        image = self._load_raw_image(raw_idx)
        assert isinstance(image, np.ndarray)
        assert list(image.shape) == self.image_shape
        assert image.dtype == np.uint8
        if flip:
            assert image.ndim == 3  # CHW
            # Copy so the returned array has positive strides; torch cannot
            # wrap negatively-strided numpy arrays (e.g. in default_collate).
            image = np.ascontiguousarray(image[:, :, ::-1])

        return dict(
            c=self.get_label(idx),
            img=image,  # raw uint8 CHW, source resolution (not normalized)
            depth_mask=matte,  # float32 HW matte, source resolution
        )

class LMDBDataset_MV_Compressed_eg3d(LMDBDataset_MV_Compressed):
    """EG3D-style samples read from a compressed LMDB.

    Each sample ``idx`` is stored under the keys ``'{idx}-img'``,
    ``'{idx}-depth_mask'`` and ``'{idx}-c'``. The depth mask and label are
    decompressed with ``decompress_array`` using the class-level shapes
    below, which subclasses may override when the LMDB was written at a
    different resolution.
    """

    # Shape the depth mask was stored at in the LMDB (an earlier dump used
    # (512, 512)).
    lmdb_depth_mask_shape = (64, 64)
    # Shape of the stored label vector ``c``.
    lmdb_c_shape = (25, )

    def __init__(self,
                 lmdb_path,
                 reso,
                 reso_encoder,
                 imgnet_normalize=True,
                 **kwargs):
        super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
                         **kwargs)

        # NOTE(review): this replaces whatever encoder transform the parent
        # built, so ``imgnet_normalize`` does not affect this transform --
        # confirm that is intended.
        self.normalize_for_encoder_input = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            transforms.Resize(size=(self.reso_encoder, self.reso_encoder),
                              antialias=True),  # type: ignore
        ])

        self.normalize_for_gt = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            transforms.Resize(size=(self.reso, self.reso),
                              antialias=True),  # type: ignore
        ])

    def __getitem__(self, idx):
        """Read, decompress and normalize sample ``idx``.

        Returns a dict with ``img_to_encoder`` (encoder-normalized image),
        ``img_sr`` (full-resolution image in [-1, 1], CHW), ``img`` (image
        resized to ``reso`` in [-1, 1], CHW), ``c`` (label vector), and
        ``depth`` / ``depth_mask`` (the same mask resized to ``reso``).
        """
        # Decompress inside the read transaction so the buffers returned by
        # txn.get are used while the transaction is still open.
        with self.env.begin(write=False) as txn:
            img_key = f'{idx}-img'.encode('utf-8')
            # presumably returns a uint8 CHW array (it is transposed below)
            image = self.load_image_fn(txn.get(img_key))

            depth_key = f'{idx}-depth_mask'.encode('utf-8')
            depth = decompress_array(txn.get(depth_key),
                                     self.lmdb_depth_mask_shape, np.float32)

            c_key = f'{idx}-c'.encode('utf-8')
            c = decompress_array(txn.get(c_key), self.lmdb_c_shape,
                                 np.float32)

        # Nearest-neighbour upsampling keeps the mask values unchanged.
        depth = cv2.resize(depth, (self.reso, self.reso),
                           interpolation=cv2.INTER_NEAREST)

        # H W C float in [0, 1] for cv2 / torchvision processing.
        image = np.transpose(image, (1, 2, 0)).astype(np.float32) / 255

        image_sr = torch.from_numpy(image)[..., :3].permute(
            2, 0, 1) * 2 - 1  # normalize to [-1,1]
        image_to_encoder = self.normalize_for_encoder_input(image)

        image_gt = cv2.resize(image, (self.reso, self.reso),
                              interpolation=cv2.INTER_AREA)
        image_gt = torch.from_numpy(image_gt)[..., :3].permute(
            2, 0, 1) * 2 - 1  # normalize to [-1,1]

        return {
            'img_to_encoder': image_to_encoder,  # reso_encoder
            'img_sr': image_sr,  # stored resolution
            'img': image_gt,  # [-1,1] range, reso
            'c': c,
            'depth': depth,
            'depth_mask': depth,
        }