File size: 12,250 Bytes
07d760c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import os
import json
import random
from PIL import Image
import torch
from typing import List, Tuple, Union
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms as T
from onediffusion.dataset.utils import *
import glob

from onediffusion.dataset.raydiff_utils import cameras_to_rays, first_camera_transform, normalize_cameras
from onediffusion.dataset.transforms import CenterCropResizeImage
from pytorch3d.renderer import PerspectiveCameras

import numpy as np

def _cameras_from_opencv_projection(
    R: torch.Tensor,
    tvec: torch.Tensor,
    camera_matrix: torch.Tensor,
    image_size: torch.Tensor,
    do_normalize_cameras,
    normalize_scale,
) -> PerspectiveCameras:
    """Convert batched OpenCV-convention cameras into PyTorch3D ``PerspectiveCameras``.

    The intrinsics are mapped to PyTorch3D NDC. The extrinsics get their x/y
    axes flipped and R transposed. The cameras are then optionally passed
    through ``normalize_cameras`` and finally re-expressed with
    ``first_camera_transform``.

    Args:
        R: (N, 3, 3) rotation matrices (OpenCV convention).
        tvec: (N, 3) translation vectors (OpenCV convention).
        camera_matrix: (N, 3, 3) intrinsic matrices K.
        image_size: (N, 2) image sizes ordered as (height, width).
        do_normalize_cameras: when truthy, apply ``normalize_cameras``.
        normalize_scale: ``scale`` forwarded to ``normalize_cameras``.

    Returns:
        The converted cameras after ``first_camera_transform(..., rotation_only=False)``.
    """
    fx_fy = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
    cx_cy = camera_matrix[:, :2, 2]

    # image_size arrives as (height, width); the NDC math below needs (width, height)
    # in R's dtype/device.
    wh = image_size.to(R).flip(dims=(1,))

    # Screen -> NDC: the shorter image side spans [-1, 1] and the longer side
    # [-u, u] with u > 1, the same convention as the PyTorch3D renderer and
    # `get_ndc_to_screen_transform`.
    half_short_side = (wh.min(dim=1, keepdim=True)[0] / 2.0).expand(-1, 2)
    image_center = wh / 2.0

    ndc_focal = fx_fy / half_short_side
    ndc_principal = (image_center - cx_cy) / half_short_side

    # OpenCV multiplies points from the left, so R is transposed; OpenCV screen
    # axes are also oriented opposite to PyTorch3D's, so the x/y axes are negated.
    R_p3d = R.clone().transpose(1, 2)
    T_p3d = tvec.clone()
    R_p3d[:, :, :2] *= -1
    T_p3d[:, :2] *= -1

    cameras = PerspectiveCameras(
        R=R_p3d,
        T=T_p3d,
        focal_length=ndc_focal,
        principal_point=ndc_principal,
        image_size=image_size,
        device=R.device,
    )

    if do_normalize_cameras:
        cameras, _ = normalize_cameras(cameras, scale=normalize_scale)

    return first_camera_transform(cameras, rotation_only=False)

def calculate_rays(Ks, sizes, Rs, Ts, target_size, use_plucker=True, do_normalize_cameras=False, normalize_scale=1.0):
    """Build ray embeddings for a batch of OpenCV-convention cameras.

    Args:
        Ks: intrinsic matrices, forwarded as ``camera_matrix``.
        sizes: image sizes as (height, width), forwarded as ``image_size``.
        Rs: rotation matrices.
        Ts: translation vectors.
        target_size: patch-grid resolution used for both x and y.
        use_plucker: forwarded unchanged to ``cameras_to_rays``.
        do_normalize_cameras: whether the cameras are normalized first.
        normalize_scale: scale used when normalizing the cameras.

    Returns:
        The ``.rays`` attribute of the ``cameras_to_rays`` result.
    """
    cams = _cameras_from_opencv_projection(
        R=Rs,
        tvec=Ts,
        camera_matrix=Ks,
        image_size=sizes,
        do_normalize_cameras=do_normalize_cameras,
        normalize_scale=normalize_scale,
    )
    ray_bundle = cameras_to_rays(
        cameras=cams,
        num_patches_x=target_size,
        num_patches_y=target_size,
        crop_parameters=None,
        use_plucker=use_plucker,
    )
    return ray_bundle.rays

def convert_rgba_to_rgb_white_bg(image):
    """Convert a PIL image to RGB, compositing any transparency onto white.

    Besides ``RGBA``, images that carry alpha in another form are handled too:
    ``LA``/``PA`` modes, and ``P``/``L``/``RGB`` images with a ``transparency``
    entry in ``image.info``. They are first promoted to ``RGBA`` so that their
    transparent pixels also become white. A plain ``convert('RGB')`` would drop
    the alpha and keep the underlying colour. Images without any alpha
    information are simply converted to RGB.
    """
    has_alpha = image.mode in ('RGBA', 'LA', 'PA') or (
        image.mode in ('P', 'L', 'RGB') and 'transparency' in image.info
    )
    if not has_alpha:
        return image.convert('RGB')

    if image.mode != 'RGBA':
        # alpha_composite requires both inputs to be RGBA.
        image = image.convert('RGBA')
    background = Image.new('RGBA', image.size, (255, 255, 255, 255))
    return Image.alpha_composite(background, image).convert('RGB')

class MultiviewDataset(Dataset):
    """Scene-level multi-view dataset yielding images, per-view ray embeddings and a caption.

    Each item is one scene folder. The folder must contain ``transforms.json``,
    from which this class reads ``h``, ``w``, ``fl_x``, ``fl_y`` and ``frames``
    (each frame providing ``file_path`` and ``transform_matrix``). It must also
    contain one of the image directories ``images``, ``images_4`` or
    ``images_8``. A ``caption.json`` file is optional. Scenes that fail to load
    are replaced by randomly chosen other scenes (see ``__getitem__``).
    """

    # Total number of scenes ``__getitem__`` tries (the requested one plus random
    # replacements) before giving up and re-raising the last loading error.
    max_retries = 100

    def __init__(
        self,
        scene_folders: str,
        samples_per_set: Union[int, Tuple[int, int]],
        transform=None,
        caption_keys: Union[str, List] = "caption",
        multiscale=False,
        aspect_ratio_type=ASPECT_RATIO_512,
        c2w_scaling=1.7,
        default_max_distance=1,
        do_normalize=True,
        swap_xz=False,
        valid_paths: str = "",
        frame_sliding_windows: float = None,
    ):
        """
        Args:
            scene_folders: glob pattern for scene directories. Only directories
                that contain ``transforms.json`` are kept.
            samples_per_set: number of views per item. Either a fixed int, or an
                inclusive ``(min, max)`` range sampled anew for every item.
            transform: optional callable applied to every RGB PIL image. When it
                is ``None``, images go through ``ToTensor``,
                ``CenterCropResizeImage`` at the ``aspect_ratio_type['1.0']``
                size, and ``Normalize([.5], [.5])``.
            caption_keys: key, or list of keys, looked up in ``caption.json``.
                One key is picked at random per item.
            multiscale: when true, validate ``aspect_ratio_type`` and set up the
                per-ratio bookkeeping (``ratio_index`` / ``ratio_nums``).
            aspect_ratio_type: mapping from aspect-ratio string to size. Its
                ``'1.0'`` entry sets the target image size and the ray grid
                (size // 8).
            c2w_scaling: scale forwarded to camera normalization.
            default_max_distance: stored on the instance; not read by this class.
            do_normalize: whether cameras are normalized before computing rays.
            swap_xz: when true, left-multiply the camera-to-world matrices by a
                fixed matrix that maps new x = old z and new z = -old x.
            valid_paths: optional text file listing scene folders, one per line.
                If the file exists, its entries replace the globbed list.
            frame_sliding_windows: when set, all views of an item are drawn from a
                random window of this many consecutive frames, so camera poses
                stay close. It is coerced to ``int``.
        """
        if not isinstance(samples_per_set, (tuple, list)):
            samples_per_set = (samples_per_set, samples_per_set)
        self.samples_range = samples_per_set  # (min_samples, max_samples), inclusive
        self.transform = transform
        self.caption_keys = caption_keys if isinstance(caption_keys, list) else [caption_keys]
        self.aspect_ratio = aspect_ratio_type

        # Keep only scene folders that carry camera metadata.
        self.scene_folders = [
            folder
            for folder in sorted(glob.glob(scene_folders))
            if os.path.exists(os.path.join(folder, "transforms.json"))
        ]

        # An explicit list of valid scenes overrides the globbed list.
        if os.path.exists(valid_paths):
            with open(valid_paths, 'r') as f:
                # Blank lines are not paths; keeping them would create bogus items.
                valid_scene_folders = [line for line in f.read().splitlines() if line.strip()]
            self.scene_folders = sorted(valid_scene_folders)

        self.c2w_scaling = c2w_scaling
        self.do_normalize = do_normalize
        self.default_max_distance = default_max_distance
        self.swap_xz = swap_xz
        # range() and random.randint() require ints, so a float here would make
        # every item fail; coerce integral values instead.
        self.frame_sliding_windows = (
            None if frame_sliding_windows is None else int(frame_sliding_windows)
        )

        if multiscale:
            assert self.aspect_ratio in [ASPECT_RATIO_512, ASPECT_RATIO_1024, ASPECT_RATIO_2048, ASPECT_RATIO_2880]
            if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]:
                self.interpolate_model = T.InterpolationMode.LANCZOS
            self.ratio_index = {}
            self.ratio_nums = {}
            for k in self.aspect_ratio:
                self.ratio_index[float(k)] = []     # used for self.getitem
                self.ratio_nums[float(k)] = 0      # used for batch-sampler

    def __len__(self):
        return len(self.scene_folders)

    def __getitem__(self, idx):
        """Return the sample for scene ``idx``, falling back to random scenes on failure.

        Any exception raised while loading a scene (missing files, too few frames,
        malformed metadata, ...) is logged at DEBUG level and a uniformly random
        other scene is tried instead. After ``max_retries`` failed attempts, a
        ``RuntimeError`` chained to the last error is raised.
        """
        if not self.scene_folders:
            raise IndexError("MultiviewDataset has no scene folders")

        last_error = None
        for _ in range(self.max_retries):
            try:
                return self._load_scene(self.scene_folders[idx])
            except Exception as e:  # deliberate best-effort: skip broken scenes
                import logging

                logging.getLogger(__name__).debug(
                    "Failed to load scene index %s (%r); trying a random scene instead", idx, e
                )
                last_error = e
                idx = random.randint(0, len(self.scene_folders) - 1)

        raise RuntimeError(
            f"Failed to load any scene after {self.max_retries} attempts"
        ) from last_error

    @staticmethod
    def _find_image_folder(scene_path):
        """Return ``(image_folder, downscale_factor)`` for the first existing image directory."""
        for name, factor in (("images", 1), ("images_4", 1 / 4), ("images_8", 1 / 8)):
            folder = os.path.join(scene_path, name)
            if os.path.exists(folder):
                return folder, factor
        raise FileNotFoundError(f"No images/, images_4/ or images_8/ folder in {scene_path}")

    def _sample_frame_indices(self, num_frames, num_samples):
        """Pick ``num_samples`` distinct frame indices in ``[0, num_frames)``.

        Without ``frame_sliding_windows``, indices come from all frames. With it,
        they come from a random window of ``frame_sliding_windows`` consecutive
        frames, or from all frames when there are no more than that. Raises
        ``ValueError`` when the candidate pool is smaller than ``num_samples``,
        so the scene is skipped before any image is loaded.
        """
        window = self.frame_sliding_windows
        if window is None or num_frames <= window:
            candidates = range(num_frames)
        else:
            start = random.randint(0, num_frames - window)
            candidates = range(start, start + window)

        if len(candidates) < num_samples:
            raise ValueError(
                f"need {num_samples} frames but only {len(candidates)} candidates are available"
            )
        return random.sample(candidates, num_samples)

    def _load_scene(self, scene_path):
        """Build the sample dict for one scene folder; raises on any loading problem."""
        image_folder, downscale_factor = self._find_image_folder(scene_path)
        json_path = os.path.join(scene_path, "transforms.json")
        caption_path = os.path.join(scene_path, "caption.json")

        with open(json_path, 'r') as f:
            json_data = json.load(f)
        height, width = json_data['h'], json_data['w']
        # Rescale intrinsics to the resolution of the chosen image folder; the
        # principal point is taken as the image center.
        dh, dw = int(height * downscale_factor), int(width * downscale_factor)
        fl_x, fl_y = json_data['fl_x'] * downscale_factor, json_data['fl_y'] * downscale_factor
        cx, cy = dw // 2, dh // 2
        frame_list = json_data['frames']

        samples_per_set = random.randint(self.samples_range[0], self.samples_range[1])
        selected_indices = self._sample_frame_indices(len(frame_list), samples_per_set)

        image_paths = [
            os.path.join(image_folder, os.path.basename(frame_list[i]['file_path']))
            for i in selected_indices
        ]
        images = []
        for image_path in image_paths:
            # The RGB conversion produces an independent image, so the file can be closed.
            with Image.open(image_path) as img:
                images.append(convert_rgba_to_rgb_white_bg(img))

        # The square '1.0' bucket defines the target resolution and the ray grid,
        # so it is needed whether or not a custom transform is supplied.
        closest_size, closest_ratio = self.aspect_ratio['1.0'], 1.0
        closest_size = tuple(map(int, closest_size))

        if self.transform:
            images = [self.transform(image) for image in images]
        else:
            transform = T.Compose([
                T.ToTensor(),
                CenterCropResizeImage(closest_size),
                T.Normalize([.5], [.5]),
            ])
            images = [transform(image) for image in images]
        images = torch.stack(images)

        c2ws = [frame_list[i]['transform_matrix'] for i in selected_indices]
        # Explicit float dtype: an all-integer matrix would otherwise give int64,
        # which torch.inverse rejects.
        c2ws = torch.tensor(c2ws, dtype=torch.float32).reshape(-1, 4, 4)

        if self.swap_xz:
            # new x = old z, y unchanged, new z = -old x.
            axis_swap = torch.tensor([[[0, 0, 1., 0],
                                       [0, 1., 0, 0],
                                       [-1., 0, 0, 0],
                                       [0, 0, 0, 1.]]])
            c2ws = axis_swap @ c2ws

        # OpenGL -> OpenCV: negate the camera y/z columns, then swap world rows
        # 0/1 and negate world row 2.
        c2ws[:, 0:3, 1:3] *= -1
        c2ws = c2ws[:, [1, 0, 2, 3], :]
        c2ws[:, 2, :] *= -1

        w2cs = torch.inverse(c2ws)
        K = torch.tensor([[[fl_x, 0, cx], [0, fl_y, cy], [0, 0, 1]]]).repeat(len(c2ws), 1, 1)
        Rs = w2cs[:, :3, :3]
        Ts = w2cs[:, :3, 3]
        sizes = torch.tensor([[dh, dw]]).repeat(len(c2ws), 1)

        grid_h, grid_w = closest_size[0] // 8, closest_size[1] // 8
        rays = calculate_rays(
            K, sizes, Rs, Ts, grid_h,
            do_normalize_cameras=self.do_normalize,
            normalize_scale=self.c2w_scaling,
        )
        rays = rays.reshape(samples_per_set, grid_h, grid_w, 6)
        # Tile the 6 ray channels up to 16 (6 + 6 + 4), the number of VAE latent
        # channels. NOTE(review): 1.658 is an undocumented scale factor kept
        # as-is; its origin should be documented.
        rays = torch.cat([rays, rays, rays[..., :4]], dim=-1) * 1.658

        caption = ""
        if os.path.exists(caption_path):
            with open(caption_path, 'r') as f:
                caption_key = random.choice(self.caption_keys)
                caption = json.load(f).get(caption_key, "")
        caption = "[[multiview]] " + caption if caption else "[[multiview]]"

        return {
            'pixel_values': images,
            'rays': rays,
            'aspect_ratio': closest_ratio,
            'caption': caption,
            # Resolution of the source images in the chosen folder, not the output size.
            'height': dh,
            'width': dw,
        }