Spaces:
Running
on
Zero
Running
on
Zero
import os
import json
import logging
import random
from PIL import Image
import torch
from typing import List, Optional, Tuple, Union
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms as T
from onediffusion.dataset.utils import *
# Imports below the star import are kept after it so they take precedence over
# any same-named export from onediffusion.dataset.utils.
import glob
from onediffusion.dataset.raydiff_utils import cameras_to_rays, first_camera_transform, normalize_cameras
from onediffusion.dataset.transforms import CenterCropResizeImage
from pytorch3d.renderer import PerspectiveCameras
import numpy as np
def _cameras_from_opencv_projection(
    R: torch.Tensor,
    tvec: torch.Tensor,
    camera_matrix: torch.Tensor,
    image_size: torch.Tensor,
    do_normalize_cameras,
    normalize_scale,
) -> PerspectiveCameras:
    """Build PyTorch3D ``PerspectiveCameras`` from OpenCV-style extrinsics and intrinsics.

    Follows PyTorch3D's OpenCV-to-PyTorch3D conversion, optionally normalizes
    the resulting cameras, and finally re-expresses them relative to the first
    camera with ``first_camera_transform``.

    Args:
        R: batch of rotation matrices (3D tensor, one matrix per camera).
        tvec: batch of translation vectors (2D tensor, one row per camera).
        camera_matrix: batch of intrinsic matrices; ``[:, 0, 0]`` and
            ``[:, 1, 1]`` are read as focal lengths and ``[:, :2, 2]`` as the
            principal point, all in pixels.
        image_size: per-camera ``(height, width)``, one row per camera.
        do_normalize_cameras: whether to run ``normalize_cameras`` first.
        normalize_scale: ``scale`` argument forwarded to ``normalize_cameras``.

    Returns:
        The converted ``PerspectiveCameras`` living on ``R.device``.
    """
    # Pixel-space intrinsics pulled out of K.
    focal_px = torch.stack((camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]), dim=-1)
    principal_px = camera_matrix[:, :2, 2]

    # image_size rows are (h, w): cast to R's dtype/device and flip to (w, h).
    size_wh = image_size.to(R).flip(dims=(1,))

    # Screen -> NDC: the shorter image side spans [-1, 1] and the longer one
    # [-u, u] with u > 1, the convention shared by the PyTorch3D renderer and
    # `get_ndc_to_screen_transform`.
    half_short_side = (size_wh.min(dim=1, keepdim=True)[0] / 2.0).expand(-1, 2)
    center_px = size_wh / 2.0
    focal_ndc = focal_px / half_short_side
    principal_ndc = -(principal_px - center_px) / half_short_side

    # OpenCV multiplies points from the left, so R is transposed; its screen
    # x/y axes also point the opposite way, so those components are negated.
    R_p3d = R.transpose(1, 2).clone()
    T_p3d = tvec.clone()
    R_p3d[..., :2].neg_()
    T_p3d[:, :2].neg_()

    cameras = PerspectiveCameras(
        focal_length=focal_ndc,
        principal_point=principal_ndc,
        R=R_p3d,
        T=T_p3d,
        image_size=image_size,
        device=R.device,
    )
    if do_normalize_cameras:
        cameras, _unused = normalize_cameras(cameras, scale=normalize_scale)
    # Applied whether or not normalization ran.
    return first_camera_transform(cameras, rotation_only=False)
def calculate_rays(Ks, sizes, Rs, Ts, target_size, use_plucker=True, do_normalize_cameras=False, normalize_scale=1.0):
    """Compute ray embeddings on a square patch grid for OpenCV-convention cameras.

    Args:
        Ks: batch of intrinsic matrices (pixels).
        sizes: per-camera ``(height, width)``.
        Rs: batch of world-to-camera rotations.
        Ts: batch of world-to-camera translations.
        target_size: number of patches along both x and y.
        use_plucker: forwarded to ``cameras_to_rays``.
        do_normalize_cameras: whether cameras are normalized before ray casting.
        normalize_scale: scale used when normalizing cameras.

    Returns:
        The ``rays`` attribute of the ``cameras_to_rays`` result.
    """
    p3d_cameras = _cameras_from_opencv_projection(
        camera_matrix=Ks,
        image_size=sizes,
        R=Rs,
        tvec=Ts,
        normalize_scale=normalize_scale,
        do_normalize_cameras=do_normalize_cameras,
    )
    ray_bundle = cameras_to_rays(
        cameras=p3d_cameras,
        crop_parameters=None,
        num_patches_x=target_size,
        num_patches_y=target_size,
        use_plucker=use_plucker,
    )
    return ray_bundle.rays
def convert_rgba_to_rgb_white_bg(image):
    """Flatten ``image`` onto an opaque white background and return it in RGB mode.

    Any image carrying transparency (modes ``RGBA``, ``LA``, ``PA``, or a palette
    image with a ``transparency`` entry in ``image.info``) is converted to RGBA
    and alpha-composited over white, so transparent regions become white rather
    than whatever colour is stored underneath the alpha channel. Images without
    transparency are simply converted to ``RGB``.
    """
    has_alpha = image.mode in ("RGBA", "LA", "PA") or (
        image.mode == "P" and "transparency" in image.info
    )
    if not has_alpha:
        return image.convert("RGB")
    rgba = image if image.mode == "RGBA" else image.convert("RGBA")
    background = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(background, rgba).convert("RGB")
class MultiviewDataset(Dataset):
    """Scene-level multi-view dataset yielding image sets plus camera ray embeddings.

    Every item corresponds to one scene folder containing ``transforms.json``
    (keys read here: ``h``, ``w``, ``fl_x``, ``fl_y`` and ``frames``, where each
    frame provides ``file_path`` and a 4x4 ``transform_matrix``), one of the
    image folders ``images`` / ``images_4`` / ``images_8``, and optionally
    ``caption.json``.
    """

    # Total number of scenes ``__getitem__`` tries (the requested one plus
    # random replacements) before raising.
    max_load_retries = 100

    def __init__(
        self,
        scene_folders: str,
        samples_per_set: Union[int, Tuple[int, int]],
        transform=None,
        caption_keys: Union[str, List] = "caption",
        multiscale=False,
        aspect_ratio_type=ASPECT_RATIO_512,
        c2w_scaling=1.7,
        default_max_distance=1,
        do_normalize=True,
        swap_xz=False,
        valid_paths: str = "",
        frame_sliding_windows: Optional[int] = None,
    ):
        """
        Args:
            scene_folders: glob pattern matching scene directories.
            samples_per_set: views per item; a fixed int, or an inclusive
                ``(min, max)`` range sampled uniformly for every item.
            transform: optional callable applied to each RGB PIL image. When
                falsy, images are converted to tensors, center-cropped/resized to
                the ``"1.0"`` bucket of ``aspect_ratio_type`` and normalized with
                mean/std 0.5.
            caption_keys: key(s) looked up in ``caption.json``; one is chosen at
                random per item.
            multiscale: if true, ``aspect_ratio_type`` must be one of the known
                bucket tables.
            aspect_ratio_type: aspect-ratio bucket table (ratio string -> size).
            c2w_scaling: ``normalize_scale`` used for camera normalization.
            default_max_distance: stored on the instance; not read in this class.
            do_normalize: whether cameras are normalized before computing rays.
            swap_xz: whether the fixed x/z swap matrix is applied to every pose.
            valid_paths: optional text file of scene folders (one per line); if
                it exists, its non-blank lines replace the globbed list entirely.
            frame_sliding_windows: if set, every item's views are drawn from one
                random window of this many consecutive frames.
        """
        if not isinstance(samples_per_set, (tuple, list)):
            samples_per_set = (samples_per_set, samples_per_set)
        self.samples_range = samples_per_set  # inclusive (min_samples, max_samples)
        self.transform = transform
        self.caption_keys = caption_keys if isinstance(caption_keys, list) else [caption_keys]
        self.aspect_ratio = aspect_ratio_type

        # Keep only globbed folders that contain transforms.json ...
        self.scene_folders = [
            folder
            for folder in sorted(glob.glob(scene_folders))
            if os.path.exists(os.path.join(folder, "transforms.json"))
        ]
        # ... unless an explicit list is provided, which replaces them entirely.
        if os.path.exists(valid_paths):
            with open(valid_paths, "r") as f:
                self.scene_folders = sorted(
                    line for line in f.read().splitlines() if line.strip()
                )

        self.c2w_scaling = c2w_scaling
        self.do_normalize = do_normalize
        self.default_max_distance = default_max_distance
        self.swap_xz = swap_xz
        # range() / random.randint() need an int; accept float config values too.
        self.frame_sliding_windows = (
            None if frame_sliding_windows is None else int(frame_sliding_windows)
        )

        if multiscale:
            assert self.aspect_ratio in [ASPECT_RATIO_512, ASPECT_RATIO_1024, ASPECT_RATIO_2048, ASPECT_RATIO_2880]
            if self.aspect_ratio in [ASPECT_RATIO_2048, ASPECT_RATIO_2880]:
                self.interpolate_model = T.InterpolationMode.LANCZOS
        self.ratio_index = {}
        self.ratio_nums = {}
        for k in self.aspect_ratio:
            self.ratio_index[float(k)] = []  # used for self.getitem
            self.ratio_nums[float(k)] = 0  # used for batch-sampler

    def __len__(self):
        return len(self.scene_folders)

    def __getitem__(self, idx):
        """Return the sample for scene ``idx``, substituting random scenes on failure.

        A scene that fails to load for any reason is logged and replaced by a
        randomly chosen scene, for at most ``max_load_retries`` attempts in total.

        Raises:
            RuntimeError: if every attempt failed (chained to the last error).
        """
        logger = logging.getLogger(__name__)
        last_error = None
        for _attempt in range(self.max_load_retries):
            try:
                return self._load_scene(idx)
            except Exception as err:  # deliberately broad: a broken scene is skipped
                last_error = err
                logger.warning(
                    "Skipping scene index %s (%s: %s); sampling a random replacement.",
                    idx, type(err).__name__, err,
                )
                idx = random.randint(0, len(self.scene_folders) - 1)
        raise RuntimeError(
            f"Failed to load a valid scene after {self.max_load_retries} attempts"
        ) from last_error

    def _load_scene(self, idx):
        """Build the sample dict for scene ``idx``; any error propagates to the caller."""
        scene_path = self.scene_folders[idx]
        image_folder, downscale_factor = self._resolve_image_folder(scene_path)

        with open(os.path.join(scene_path, "transforms.json"), "r") as f:
            json_data = json.load(f)
        dh = int(json_data["h"] * downscale_factor)
        dw = int(json_data["w"] * downscale_factor)
        fl_x = json_data["fl_x"] * downscale_factor
        fl_y = json_data["fl_y"] * downscale_factor
        # The principal point is taken as the centre of the (downscaled) image.
        cx, cy = dw // 2, dh // 2
        frame_list = json_data["frames"]

        frames = [frame_list[i] for i in self._sample_frame_indices(len(frame_list))]

        # Always the square ("1.0") bucket. The ray grid is derived from it even
        # when a custom ``transform`` is supplied, so it must exist on every path.
        closest_size = tuple(map(int, self.aspect_ratio["1.0"]))
        closest_ratio = 1.0

        images = self._load_images(image_folder, frames, closest_size)
        rays = self._compute_rays(frames, fl_x, fl_y, cx, cy, dh, dw, closest_size)
        caption = self._load_caption(scene_path)

        return {
            "pixel_values": images,
            "rays": rays,
            "aspect_ratio": closest_ratio,
            "caption": caption,
            "height": dh,
            "width": dw,
        }

    @staticmethod
    def _resolve_image_folder(scene_path):
        """Return ``(image_folder, downscale_factor)`` for the first image folder present."""
        for name, factor in (("images", 1), ("images_4", 1 / 4), ("images_8", 1 / 8)):
            folder = os.path.join(scene_path, name)
            if os.path.exists(folder):
                return folder, factor
        raise NotImplementedError(f"No images/, images_4/ or images_8/ folder in {scene_path}")

    def _sample_frame_indices(self, num_frames):
        """Randomly pick frame indices for one item, never more than are available."""
        samples_per_set = random.randint(self.samples_range[0], self.samples_range[1])
        if self.frame_sliding_windows is None:
            candidates = range(num_frames)
        else:
            # Restrict views to a window of consecutive frames to avoid
            # catastrophic differences in camera angles.
            window = self.frame_sliding_windows
            if num_frames <= window:
                window_start, window_end = 0, num_frames
            else:
                window_start = random.randint(0, num_frames - window)
                window_end = window_start + window
            candidates = range(window_start, window_end)
        return random.sample(candidates, min(samples_per_set, len(candidates)))

    def _load_images(self, image_folder, frames, closest_size):
        """Load, flatten to RGB on white, transform and stack the frames' images."""
        if self.transform:
            transform = self.transform
        else:
            transform = T.Compose([
                T.ToTensor(),
                CenterCropResizeImage(closest_size),
                T.Normalize([.5], [.5]),
            ])
        images = []
        for frame in frames:
            # Only the basename is used, so the frame resolves inside whichever
            # image folder (full-res or downscaled) was selected.
            path = os.path.join(image_folder, os.path.basename(frame["file_path"]))
            with Image.open(path) as img:
                images.append(transform(convert_rgba_to_rgb_white_bg(img)))
        return torch.stack(images)

    def _compute_rays(self, frames, fl_x, fl_y, cx, cy, dh, dw, closest_size):
        """Return ray embeddings of shape ``(V, closest_size[0] // 8, closest_size[1] // 8, 16)``."""
        c2ws = torch.tensor([frame["transform_matrix"] for frame in frames]).reshape(-1, 4, 4)
        if self.swap_xz:
            swap_xz_matrix = torch.tensor([[[0, 0, 1., 0],
                                            [0, 1., 0, 0],
                                            [-1., 0, 0, 0],
                                            [0, 0, 0, 1.]]])
            c2ws = swap_xz_matrix @ c2ws
        # OpenGL -> OpenCV camera convention.
        c2ws[:, 0:3, 1:3] *= -1
        c2ws = c2ws[:, [1, 0, 2, 3], :]
        c2ws[:, 2, :] *= -1
        w2cs = torch.inverse(c2ws)

        num_views = len(c2ws)
        K = torch.tensor([[[fl_x, 0, cx], [0, fl_y, cy], [0, 0, 1]]]).repeat(num_views, 1, 1)
        sizes = torch.tensor([[dh, dw]]).repeat(num_views, 1)
        grid_h, grid_w = closest_size[0] // 8, closest_size[1] // 8
        rays = calculate_rays(
            K, sizes, w2cs[:, :3, :3], w2cs[:, :3, 3], grid_h,
            do_normalize_cameras=self.do_normalize,
            normalize_scale=self.c2w_scaling,
        )
        # Use the actual number of views (may be fewer than requested).
        rays = rays.reshape(num_views, grid_h, grid_w, 6)
        # Tile 6 channels up to 16 (6 + 6 + 4), the VAE's channel count.
        # NOTE(review): 1.658 is an unexplained constant scale — confirm its origin.
        return torch.cat([rays, rays, rays[..., :4]], dim=-1) * 1.658

    def _load_caption(self, scene_path):
        """Return the ``[[multiview]]``-prefixed caption (tag only if none found)."""
        caption_path = os.path.join(scene_path, "caption.json")
        caption = ""
        if os.path.exists(caption_path):
            with open(caption_path, "r") as f:
                caption_key = random.choice(self.caption_keys)
                caption = json.load(f).get(caption_key, "")
        return "[[multiview]] " + caption if caption else "[[multiview]]"