# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Dataloader for TartanAir
# --------------------------------------------------------
import os.path as osp
from glob import glob
import itertools
import numpy as np
import re
import cv2
import os

from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
from dust3r.utils.image import imread_cv2


def readPFM(file):
    """Read a PFM (Portable Float Map) image from disk.

    The header consists of three text lines: the magic (``PF`` for
    3-channel, ``Pf`` for single-channel), ``"<width> <height>"`` and a
    scale factor whose sign encodes the byte order (negative means
    little-endian). The remainder of the file is raw 32-bit floats.

    Args:
        file: Path of the PFM file to read.

    Returns:
        A numpy array of 32-bit floats with shape ``(height, width, 3)`` for
        a ``PF`` file or ``(height, width)`` for a ``Pf`` file. Rows are
        flipped vertically with respect to their order in the file.

    Raises:
        ValueError: If the magic is neither ``PF`` nor ``Pf``, if the
            dimension line is malformed, or if the payload size does not
            match the declared dimensions (raised by the reshape).
    """
    # The context manager guarantees the handle is released on every path,
    # including the header-validation errors below.
    with open(file, 'rb') as stream:
        header = stream.readline().rstrip()
        if header == b'PF':
            color = True
        elif header == b'Pf':
            color = False
        else:
            raise ValueError('Not a PFM file.')

        dim_match = re.match(rb'^(\d+)\s(\d+)\s$', stream.readline())
        if not dim_match:
            raise ValueError('Malformed PFM header.')
        width, height = map(int, dim_match.groups())

        # Only the sign of the scale is used: it selects the byte order.
        scale = float(stream.readline().rstrip())
        endian = '<' if scale < 0 else '>'

        data = np.fromfile(stream, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    # Flip vertically so the last row stored in the file becomes row 0.
    return np.flipud(np.reshape(data, shape))


class TartanairDatasets(BaseStereoViewDataset):
    """Two-view dataset built from a directory of pre-processed scenes.

    Every entry of ``ROOT`` is treated as one scene directory. Inside a
    scene, frames are the files matching ``*_rgb.jpg``; each frame is
    expected to have sibling files ``*_depth.pfm``, ``*_mask.png``,
    ``*_metadata.npz`` and ``*_rgb_pred_depth_<depth_prior_name>.npz``.
    A sample is a pair of frames from the same scene whose positions in the
    sorted frame list differ by at most ``MAX_FRAME_GAP``.
    """

    # Largest distance, in sorted-frame positions, between two paired frames.
    MAX_FRAME_GAP = 10

    def __init__(self, *args, split, ROOT, **kwargs):
        """Index all frame pairs found under ``ROOT``.

        Args:
            *args: Forwarded unchanged to the base class constructor.
            split: ``'train'`` or ``'test'``; selects which scenes are used.
                Any other value selects no scene.
            ROOT: Directory whose entries are the scene directories.
            **kwargs: Forwarded unchanged to the base class constructor.
        """
        self.ROOT = ROOT
        # NOTE(review): `split` is consumed here and not forwarded to the
        # base class -- confirm that this is intended.
        super().__init__(*args, **kwargs)

        self.dataset_label = 'Tartanair'
        # No scene is held out, so split='test' yields an empty dataset.
        test_scenes = []

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> pair mapping reproducible across runs and machines.
        scene_list = []
        for scene in sorted(os.listdir(ROOT)):
            if 'Hard' in scene:
                continue
            held_out = scene in test_scenes
            if (split == 'train' and not held_out) or (split == 'test' and held_out):
                scene_list.append(osp.join(ROOT, scene))

        # Maps a running pair index to [first image path, second image path].
        self.pair_dict = {}
        pair_num = 0
        for scene in scene_list:
            imgs = sorted(glob(osp.join(scene, '*_rgb.jpg')))
            num_imgs = len(imgs)
            # Enumerate only the (i, j) with i < j <= i + MAX_FRAME_GAP, in
            # the same lexicographic order as filtering all combinations,
            # without visiting the O(n^2) pairs that would be discarded.
            for i in range(num_imgs):
                last = min(i + self.MAX_FRAME_GAP, num_imgs - 1)
                for j in range(i + 1, last + 1):
                    self.pair_dict[pair_num] = [imgs[i], imgs[j]]
                    pair_num += 1

    def __len__(self):
        """Return the number of indexed frame pairs."""
        return len(self.pair_dict)

    def _get_views(self, idx, resolution, rng):
        """Load the two views of pair ``idx``.

        Args:
            idx: Key into ``self.pair_dict``.
            resolution: Forwarded to ``self._crop_resize_if_necessary``.
            rng: Forwarded to ``self._crop_resize_if_necessary``.

        Returns:
            A list with one dict per image of the pair, holding the keys
            ``img``, ``depthmap``, ``camera_pose``, ``camera_intrinsics``,
            ``dataset``, ``label``, ``instance`` and ``pred_depth``.
        """
        views = []
        for img_path in self.pair_dict[idx]:
            rgb_image = imread_cv2(img_path)

            depthmap_path = img_path.replace('_rgb.jpg', '_depth.pfm')
            mask_path = img_path.replace('_rgb.jpg', '_mask.png')
            metadata_path = img_path.replace('_rgb.jpg', '_metadata.npz')
            # `self.depth_prior_name` is not set in this class; presumably
            # the base class provides it -- TODO confirm.
            prior_path = img_path.replace('.jpg', '_pred_depth_' + self.depth_prior_name + '.npz')

            # Close the archive once the arrays are read so that file
            # descriptors do not accumulate across samples.
            with np.load(prior_path) as prior:
                focal_length_px = prior['focallength_px']
                pred_depth = prior['depth']
            # `pixel_to_pointcloud` is defined outside this file.
            pred_depth = self.pixel_to_pointcloud(pred_depth, focal_length_px)

            depthmap = readPFM(depthmap_path)

            # Zero the depth wherever the normalised mask is <= 0.1.
            maskmap = imread_cv2(mask_path, cv2.IMREAD_UNCHANGED).astype(np.float32)
            maskmap = (maskmap / 255.0) > 0.1
            depthmap *= maskmap

            with np.load(metadata_path) as metadata:
                intrinsics = np.float32(metadata['camera_intrinsics'])
                camera_pose = np.float32(metadata['camera_pose'])

            rgb_image, depthmap, pred_depth, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, pred_depth, intrinsics, resolution, rng=rng, info=img_path)

            views.append(dict(
                img=rgb_image,
                depthmap=depthmap,
                camera_pose=camera_pose,
                camera_intrinsics=intrinsics,
                dataset=self.dataset_label,
                label=img_path,
                instance=img_path,
                pred_depth=pred_depth
            ))
        return views


if __name__ == "__main__":
    # Manual visual check: iterate over the pairs in random order and show
    # both point clouds and cameras of each pair in a 3D viewer.
    from dust3r.datasets.base.base_stereo_view_dataset import view_name
    from dust3r.viz import SceneViz, auto_cam_size
    from dust3r.utils.image import rgb

    # The dataset class defined in this file is TartanairDatasets.
    # NOTE(review): this path still points at a Spring directory -- set ROOT
    # to the pre-processed TartanAir root before running.
    dataset = TartanairDatasets(split='train', ROOT="/media/8TB/tyhuang/video_depth/spring_proc/train", resolution=512, aug_crop=16)

    for idx in np.random.permutation(len(dataset)):
        views = dataset[idx]
        assert len(views) == 2
        print(view_name(views[0]), view_name(views[1]))
        viz = SceneViz()
        poses = [views[view_idx]['camera_pose'] for view_idx in [0, 1]]
        cam_size = max(auto_cam_size(poses), 0.001)
        for view_idx in [0, 1]:
            pts3d = views[view_idx]['pts3d']
            valid_mask = views[view_idx]['valid_mask']
            colors = rgb(views[view_idx]['img'])
            viz.add_pointcloud(pts3d, colors, valid_mask)
            # Colour by view (0 -> green, 1 -> red); using the dataset index
            # here would produce out-of-range colour values.
            viz.add_camera(pose_c2w=views[view_idx]['camera_pose'],
                           focal=views[view_idx]['camera_intrinsics'][0, 0],
                           color=(view_idx * 255, (1 - view_idx) * 255, 0),
                           image=colors,
                           cam_size=cam_size)
        viz.show()