import os
import json
import random
from PIL import Image
import torch
from typing import List, Tuple, Union
from torch.utils.data import Dataset
import torchvision.transforms as T
from onediffusion.dataset.utils import *
import glob

from onediffusion.dataset.raydiff_utils import cameras_to_rays, first_camera_transform, normalize_cameras
from onediffusion.dataset.transforms import CenterCropResizeImage
from pytorch3d.renderer import PerspectiveCameras

import numpy as np

def _cameras_from_opencv_projection(
    R: torch.Tensor,
    tvec: torch.Tensor,
    camera_matrix: torch.Tensor,
    image_size: torch.Tensor,
    do_normalize_cameras: bool,
    normalize_scale: float,
) -> PerspectiveCameras:
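    """Build PyTorch3D ``PerspectiveCameras`` from OpenCV-convention parameters.

    ``R``, ``tvec``, ``camera_matrix``, and ``image_size`` are batched (one entry per
    view); ``image_size`` rows are (height, width). The cameras are optionally
    rescaled with ``normalize_cameras`` and then re-expressed relative to the first
    camera via ``first_camera_transform``.
    """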
    focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
    principal_point = camera_matrix[:, :2, 2]

    # Retype the image_size correctly and flip to width, height.
    image_size_wh = image_size.to(R).flip(dims=(1,))

    # Screen to NDC conversion:
    # For non square images, we scale the points such that smallest side
    # has range [-1, 1] and the largest side has range [-u, u], with u > 1.
    # This convention is consistent with the PyTorch3D renderer, as well as
    # the transformation function `get_ndc_to_screen_transform`.
    scale = image_size_wh.min(dim=1, keepdim=True)[0] / 2.0
    scale = scale.expand(-1, 2)
    c0 = image_size_wh / 2.0

    # Get the PyTorch3D focal length and principal point.
    focal_pytorch3d = focal_length / scale
    p0_pytorch3d = -(principal_point - c0) / scale

    # For R, T we flip x, y axes (opencv screen space has an opposite
    # orientation of screen axes).
    # We also transpose R (opencv multiplies points from the opposite=left side).
    R_pytorch3d = R.clone().permute(0, 2, 1)
    T_pytorch3d = tvec.clone()
    R_pytorch3d[:, :, :2] *= -1
    T_pytorch3d[:, :2] *= -1

    cams = PerspectiveCameras(
        R=R_pytorch3d,
        T=T_pytorch3d,
        focal_length=focal_pytorch3d,
        principal_point=p0_pytorch3d,
        image_size=image_size,
        device=R.device,
    )
    
    if do_normalize_cameras:
        cams, _ = normalize_cameras(cams, scale=normalize_scale)
    
    cams = first_camera_transform(cams, rotation_only=False)
    return cams

def calculate_rays(Ks, sizes, Rs, Ts, target_size, use_plucker=True, do_normalize_cameras=False, normalize_scale=1.0):
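    """Return ray embeddings (Plucker coordinates when ``use_plucker=True``) for
    OpenCV-convention cameras, sampled on a ``target_size`` x ``target_size`` patch grid."""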
    cameras = _cameras_from_opencv_projection(
        R=Rs,
        tvec=Ts,
        camera_matrix=Ks,
        image_size=sizes,
        do_normalize_cameras=do_normalize_cameras,
        normalize_scale=normalize_scale
    )
        
    rays_embedding = cameras_to_rays(
        cameras=cameras,
        num_patches_x=target_size,
        num_patches_y=target_size,
        crop_parameters=None,
        use_plucker=use_plucker
    )
        
    return rays_embedding.rays

def convert_rgba_to_rgb_white_bg(image):
    """Convert RGBA image to RGB with white background"""
    if image.mode == 'RGBA':
        # Create a white background
        background = Image.new('RGBA', image.size, (255, 255, 255, 255))
        # Composite the image onto the white background
        return Image.alpha_composite(background, image).convert('RGB')
    return image.convert('RGB')

class MultiviewDataset(Dataset):
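    """Dataset that samples multiple views per scene together with per-view ray embeddings.

    Each scene folder must contain an image directory (``images``, ``images_4``, or
    ``images_8``), a ``transforms.json`` (Nerfstudio / instant-ngp style) with intrinsics
    and per-frame camera-to-world matrices, and optionally a ``caption.json``.
    ``__getitem__`` samples a random number of views from a scene and returns the
    stacked images, 16-channel ray embeddings, and a ``[[multiview]]``-prefixed caption.
    """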
    def __init__(
        self, 
        scene_folders: str, 
        samples_per_set: Union[int, Tuple[int, int]],  # number of views per scene, or an inclusive (min, max) range
        transform=None, 
        caption_keys: Union[str, List] = "caption",
        multiscale=False, 
        aspect_ratio_type=ASPECT_RATIO_512,
        c2w_scaling=1.7,
        default_max_distance=1,  # default max distance over all cameras of a scene
        do_normalize=True,  # whether to normalize c2w translations by max_distance
        swap_xz=False,  # whether to swap the x and z axes of the 3D scenes
        valid_paths: str = "",
        frame_sliding_windows: int = None  # restrict sampled frames to a window of this many frames so camera poses stay close
    ):
        if not isinstance(samples_per_set, (tuple, list)):
            samples_per_set = (samples_per_set, samples_per_set)
        self.samples_range = samples_per_set  # (min_samples, max_samples)
        self.transform = transform
        self.caption_keys = caption_keys if isinstance(caption_keys, list) else [caption_keys]
        self.aspect_ratio = aspect_ratio_type
        self.scene_folders = sorted(glob.glob(scene_folders))
        # filter out scene folders that do not have transforms.json
        self.scene_folders = list(filter(lambda x: os.path.exists(os.path.join(x, "transforms.json")), self.scene_folders))

        # if valid_paths.txt exists, only use paths in that file
        if os.path.exists(valid_paths):
            with open(valid_paths, 'r') as f:
                valid_scene_folders = f.read().splitlines()
            self.scene_folders = sorted(valid_scene_folders)
            
        self.c2w_scaling = c2w_scaling
        self.do_normalize = do_normalize
        self.default_max_distance = default_max_distance
        self.swap_xz = swap_xz
        self.frame_sliding_windows = frame_sliding_windows
        
        if multiscale:
            assert self.aspect_ratio in [ASPECT_RATIO_512, ASPECT_RATIO_1024, ASPECT_RATIO_2048, ASPECT_RATIO_2880]
            # default interpolation mode (assumed bicubic); use Lanczos for the high-resolution buckets
            self.interpolate_model = T.InterpolationMode.BICUBIC
            if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]:
                self.interpolate_model = T.InterpolationMode.LANCZOS
            self.ratio_index = {}
            self.ratio_nums = {}
            for k in self.aspect_ratio:
                self.ratio_index[float(k)] = []    # used for self.getitem
                self.ratio_nums[float(k)] = 0      # used for batch-sampler

    def __len__(self):
        return len(self.scene_folders)

    def __getitem__(self, idx):
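        """Sample a random set of views from scene ``idx``.

        Returns a dict with ``pixel_values`` (N, 3, H, W), ``rays`` (N, H // 8, W // 8, 16),
        ``aspect_ratio``, ``caption``, and the (downscaled) source image ``height``/``width``.
        On any failure a random other scene is returned instead.
        """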
        try:
            scene_path = self.scene_folders[idx]

            if os.path.exists(os.path.join(scene_path, "images")):
                image_folder = os.path.join(scene_path, "images")
                downscale_factor = 1
            elif os.path.exists(os.path.join(scene_path, "images_4")):
                image_folder = os.path.join(scene_path, "images_4")
                downscale_factor = 1 / 4
            elif os.path.exists(os.path.join(scene_path, "images_8")):
                image_folder = os.path.join(scene_path, "images_8")
                downscale_factor = 1 / 8
            else:
                raise NotImplementedError
            
            json_path = os.path.join(scene_path, "transforms.json")
            caption_path = os.path.join(scene_path, "caption.json")
            
            with open(json_path, 'r') as f:
                json_data = json.load(f)
                height, width = json_data['h'], json_data['w']
                
                dh, dw = int(height * downscale_factor), int(width * downscale_factor)
                fl_x, fl_y = json_data['fl_x'] * downscale_factor, json_data['fl_y'] * downscale_factor
                cx = dw // 2
                cy = dh // 2
                
                frame_list = json_data['frames']
            
            # Randomly choose how many views to sample for this scene
            samples_per_set = random.randint(self.samples_range[0], self.samples_range[1])

            # sample uniformly over all frames of the scene
            if self.frame_sliding_windows is None:
                selected_indices = random.sample(range(len(frame_list)), min(samples_per_set, len(frame_list)))
            # restrict the sampled views to a sliding window (to avoid drastically different camera angles)
            else:
                # Determine the starting index of the sliding window
                if len(frame_list) <= self.frame_sliding_windows:
                    # If the frame list fits within the sliding window, use the entire list
                    window_start = 0
                    window_end = len(frame_list)
                else:
                    # Randomly select a starting point for the window
                    window_start = random.randint(0, len(frame_list) - self.frame_sliding_windows)
                    window_end = window_start + self.frame_sliding_windows

                # Get the indices within the sliding window
                window_indices = list(range(window_start, window_end))

                # Randomly sample indices from the window
                selected_indices = random.sample(window_indices, min(samples_per_set, len(window_indices)))
            
            image_files = [os.path.basename(frame_list[i]['file_path']) for i in selected_indices]
            image_paths = [os.path.join(image_folder, file) for file in image_files]
            
            # Load images and convert RGBA to RGB with white background
            images = [convert_rgba_to_rgb_white_bg(Image.open(image_path)) for image_path in image_paths]
            
            # The square bucket always defines the target size; the ray-embedding
            # resolution below is derived from closest_size regardless of self.transform
            closest_size, closest_ratio = self.aspect_ratio['1.0'], 1.0
            closest_size = tuple(map(int, closest_size))

            if self.transform:
                images = [self.transform(image) for image in images]
            else:
                transform = T.Compose([
                    T.ToTensor(),
                    CenterCropResizeImage(closest_size),
                    T.Normalize([.5], [.5]),
                ])
                images = [transform(image) for image in images]
            images = torch.stack(images)
            
            c2ws = [frame_list[i]['transform_matrix'] for i in selected_indices]
            c2ws = torch.tensor(c2ws).reshape(-1, 4, 4)
            # max_distance = json_data.get('max_distance', self.default_max_distance)
            # if 'max_distance' not in json_data.keys():
                # print(f"not found `max_distance` in json path: {json_path}")

            if self.swap_xz:
                swap_xz = torch.tensor([[[0., 0., 1., 0.],
                                         [0., 1., 0., 0.],
                                         [-1., 0., 0., 0.],
                                         [0., 0., 0., 1.]]])
                c2ws = swap_xz @ c2ws
            
            # OpenGL -> OpenCV: flip the y and z axes of each camera frame,
            # then swap the world x/y axes and flip world z
            c2ws[:, 0:3, 1:3] *= -1
            c2ws = c2ws[:, [1, 0, 2, 3], :]
            c2ws[:, 2, :] *= -1

            w2cs = torch.inverse(c2ws)
            K = torch.tensor([[[fl_x, 0, cx], [0, fl_y, cy], [0, 0, 1]]]).repeat(len(c2ws), 1, 1)
            Rs = w2cs[:, :3, :3]
            Ts = w2cs[:, :3, 3]
            sizes = torch.tensor([[dh, dw]]).repeat(len(c2ws), 1)
            
            # get ray embeddings and pad the last dimension to 16 (number of VAE latent channels)
            # rays_od = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, use_plucker=False, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling)
            rays = calculate_rays(K, sizes, Rs, Ts, closest_size[0] // 8, do_normalize_cameras=self.do_normalize, normalize_scale=self.c2w_scaling)
            rays = rays.reshape(len(selected_indices), closest_size[0] // 8, closest_size[1] // 8, 6)
            # padding = (0, 10)  # pad the last dimension to 16
            # rays = torch.nn.functional.pad(rays, padding, "constant", 0)
            # tile the 6-channel Plucker rays to 16 channels (6 + 6 + 4) and apply a fixed scale
            rays = torch.cat([rays, rays, rays[..., :4]], dim=-1) * 1.658
            
            if os.path.exists(caption_path):
                with open(caption_path, 'r') as f:
                    caption_key = random.choice(self.caption_keys)
                    caption = json.load(f).get(caption_key, "")
            else:
                caption = ""
            
            caption = "[[multiview]] " + caption if caption else "[[multiview]]"
            
            return {
                'pixel_values': images,
                'rays': rays,
                'aspect_ratio': closest_ratio,
                'caption': caption,
                'height': dh,
                'width': dw,
                # 'origins': rays_od[..., :3],
                # 'dirs': rays_od[..., 3:6]
            }
        except Exception:
            # on any failure (missing files, malformed JSON, etc.) fall back to a random other scene
            return self.__getitem__(random.randint(0, len(self.scene_folders) - 1))
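

# Minimal usage sketch (not part of the original file): the glob pattern, sample range,
# and DataLoader settings below are illustrative placeholders.
#
#   from torch.utils.data import DataLoader
#
#   dataset = MultiviewDataset(
#       scene_folders="data/scenes/*",      # hypothetical path with one sub-folder per scene
#       samples_per_set=(2, 4),             # sample 2-4 views per scene
#       aspect_ratio_type=ASPECT_RATIO_512,
#   )
#   loader = DataLoader(dataset, batch_size=1, num_workers=4)
#   batch = next(iter(loader))
#   # batch['pixel_values']: (1, N, 3, 512, 512); batch['rays']: (1, N, 64, 64, 16)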