import random

import numpy as np
import PIL
import torch
import torchvision

from src.mast3r_src.dust3r.dust3r.datasets.utils.transforms import ImgNorm
from src.mast3r_src.dust3r.dust3r.utils.geometry import depthmap_to_absolute_camera_coordinates, geotrf
from src.mast3r_src.dust3r.dust3r.utils.misc import invalid_to_zeros
import src.mast3r_src.dust3r.dust3r.datasets.utils.cropping as cropping

def crop_resize_if_necessary(image, depthmap, intrinsics, resolution):
    """Adapted from DUST3R's Co3D dataset implementation."""

    if not isinstance(image, PIL.Image.Image):
        image = PIL.Image.fromarray(image)

    # Crop centered on the principal point: the new window is a rectangle of
    # size (2 * min_margin_x, 2 * min_margin_y) centered on (cx, cy)
    W, H = image.size
    cx, cy = intrinsics[:2, 2].round().astype(int)
    min_margin_x = min(cx, W - cx)
    min_margin_y = min(cy, H - cy)
    assert min_margin_x > W / 5, f"Principal point x={cx} too close to the image border"
    assert min_margin_y > H / 5, f"Principal point y={cy} too close to the image border"
    l, t = cx - min_margin_x, cy - min_margin_y
    r, b = cx + min_margin_x, cy + min_margin_y
    crop_bbox = (l, t, r, b)
    image, depthmap, intrinsics = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)

    # High-quality Lanczos down-scaling so that image.size matches the target resolution
    target_resolution = np.array(resolution)
    image, depthmap, intrinsics = cropping.rescale_image_depthmap(image, depthmap, intrinsics, target_resolution)

    # Actual cropping (if necessary) with bilinear interpolation
    intrinsics2 = cropping.camera_matrix_of_crop(intrinsics, image.size, resolution, offset_factor=0.5)
    crop_bbox = cropping.bbox_from_intrinsics_in_out(intrinsics, intrinsics2, resolution)
    image, depthmap, intrinsics2 = cropping.crop_image_depthmap(image, depthmap, intrinsics, crop_bbox)

    return image, depthmap, intrinsics2
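
# A minimal usage sketch for crop_resize_if_necessary. All values below are
# illustrative assumptions (dummy image, depthmap, and intrinsics), not data
# from this repo; the function only needs an HxWx3 image (array or PIL image),
# an HxW depthmap, a 3x3 intrinsics matrix, and a (width, height) resolution.
def _example_crop_resize():
    rgb = np.zeros((480, 640, 3), dtype=np.uint8)    # dummy H x W x 3 image
    depth = np.ones((480, 640), dtype=np.float32)    # dummy H x W depthmap
    K = np.array([
        [500.0, 0.0, 320.0],
        [0.0, 500.0, 240.0],
        [0.0, 0.0, 1.0],
    ])                                               # principal point near the image center
    return crop_resize_if_necessary(rgb, depth, K, (512, 384))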

class DUST3RSplattingDataset(torch.utils.data.Dataset):

    def __init__(self, data, coverage, resolution, num_epochs_per_epoch=1, alpha=0.3, beta=0.3):
        super().__init__()

        self.data = data
        self.coverage = coverage
        self.num_context_views = 2
        self.num_target_views = 3
        self.resolution = resolution
        self.transform = ImgNorm
        self.org_transform = torchvision.transforms.ToTensor()
        self.num_epochs_per_epoch = num_epochs_per_epoch

        # Overlap thresholds used when sampling the context and target views, respectively
        self.alpha = alpha
        self.beta = beta
    def __getitem__(self, idx):

        sequence = self.data.sequences[idx // self.num_epochs_per_epoch]
        sequence_length = len(self.data.color_paths[sequence])
        context_views, target_views = self.sample(sequence, self.num_target_views, self.alpha, self.beta)

        views = {"context": [], "target": [], "scene": sequence}

        # Fetch the context views
        for c_view in context_views:

            assert c_view < sequence_length, f"Invalid view index: {c_view}, sequence length: {sequence_length}, context views: {context_views}"
            view = self.data.get_view(sequence, c_view, self.resolution)

            # Transform the input
            view['img'] = self.transform(view['original_img'])
            view['original_img'] = self.org_transform(view['original_img'])

            # Create the point cloud and the validity mask
            pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
            view['pts3d'] = pts3d
            view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
            assert view['valid_mask'].any(), f"No valid pixels for sequence: {sequence}, view: {c_view}"

            views['context'].append(view)

        # Fetch the target views, which only need the original image
        for t_view in target_views:

            view = self.data.get_view(sequence, t_view, self.resolution)
            view['original_img'] = self.org_transform(view['original_img'])
            views['target'].append(view)

        return views

    def __len__(self):

        return len(self.data.sequences) * self.num_epochs_per_epoch
    def sample(self, sequence, num_target_views, context_overlap_threshold=0.5, target_overlap_threshold=0.6):
        """Sample two context views and num_target_views target views from the sequence,
        using the precomputed pairwise coverage as the measure of overlap."""

        sequence_length = len(self.data.color_paths[sequence])
        first_context_view = random.randint(0, sequence_length - 1)

        # Pick a second context view that has sufficient overlap with the first context view
        valid_second_context_views = []
        for frame in range(sequence_length):
            if frame == first_context_view:
                continue
            overlap = self.coverage[sequence][first_context_view][frame]
            if overlap > context_overlap_threshold:
                valid_second_context_views.append(frame)
        if len(valid_second_context_views) > 0:
            second_context_view = random.choice(valid_second_context_views)

        # If there are no valid second context views, fall back to the one with the highest overlap
        else:
            best_view = None
            best_overlap = None
            for frame in range(sequence_length):
                if frame == first_context_view:
                    continue
                overlap = self.coverage[sequence][first_context_view][frame]
                if best_view is None or overlap > best_overlap:
                    best_view = frame
                    best_overlap = overlap
            second_context_view = best_view

        # Pick target views that sufficiently overlap at least one of the context views
        valid_target_views = []
        for frame in range(sequence_length):
            if frame == first_context_view or frame == second_context_view:
                continue
            overlap_max = max(
                self.coverage[sequence][first_context_view][frame],
                self.coverage[sequence][second_context_view][frame]
            )
            if overlap_max > target_overlap_threshold:
                valid_target_views.append(frame)
        if len(valid_target_views) >= num_target_views:
            target_views = random.sample(valid_target_views, num_target_views)

        # If there are not enough valid target views, pick the best ones
        else:
            overlaps = []
            for frame in range(sequence_length):
                if frame == first_context_view or frame == second_context_view:
                    continue
                overlap = max(
                    self.coverage[sequence][first_context_view][frame],
                    self.coverage[sequence][second_context_view][frame]
                )
                overlaps.append((frame, overlap))
            overlaps.sort(key=lambda x: x[1], reverse=True)
            target_views = [frame for frame, _ in overlaps[:num_target_views]]

        return [first_context_view, second_context_view], target_views
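
# A minimal training-loader sketch, assuming `data` exposes `sequences`,
# `color_paths`, and `get_view(sequence, index, resolution)` as used above, and
# that `coverage[sequence][i][j]` gives the pairwise overlap. The resolution is
# illustrative, and relying on the default DataLoader collate for the nested
# dict-of-lists output is an assumption; a custom collate_fn may be preferable.
def _example_train_loader(data, coverage):
    dataset = DUST3RSplattingDataset(data, coverage, resolution=(512, 384))
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
    views = next(iter(loader))  # {"context": [...], "target": [...], "scene": ...}
    return views["context"], views["target"]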

class DUST3RSplattingTestDataset(torch.utils.data.Dataset):

    def __init__(self, data, samples, resolution):
        super().__init__()

        self.data = data
        self.samples = samples
        self.resolution = resolution
        self.transform = ImgNorm
        self.org_transform = torchvision.transforms.ToTensor()

    def get_view(self, sequence, c_view):

        view = self.data.get_view(sequence, c_view, self.resolution)

        # Transform the input
        view['img'] = self.transform(view['original_img'])
        view['original_img'] = self.org_transform(view['original_img'])

        # Create the point cloud and the validity mask
        pts3d, valid_mask = depthmap_to_absolute_camera_coordinates(**view)
        view['pts3d'] = pts3d
        view['valid_mask'] = valid_mask & np.isfinite(pts3d).all(axis=-1)
        assert view['valid_mask'].any(), f"No valid pixels for sequence: {sequence}, view: {c_view}"

        return view

    def __getitem__(self, idx):

        sequence, c_view_1, c_view_2, target_view = self.samples[idx]
        c_view_1, c_view_2, target_view = int(c_view_1), int(c_view_2), int(target_view)

        fetched_c_view_1 = self.get_view(sequence, c_view_1)
        fetched_c_view_2 = self.get_view(sequence, c_view_2)
        fetched_target_view = self.get_view(sequence, target_view)

        views = {"context": [fetched_c_view_1, fetched_c_view_2], "target": [fetched_target_view], "scene": sequence}
        return views

    def __len__(self):

        return len(self.samples)
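
# A minimal evaluation sketch for DUST3RSplattingTestDataset, assuming `samples`
# holds (sequence, context_view_1, context_view_2, target_view) tuples, as
# unpacked in __getitem__ above; the resolution is an illustrative assumption.
def _example_test_views(data, samples):
    dataset = DUST3RSplattingTestDataset(data, samples, resolution=(512, 384))
    views = dataset[0]
    (c_view_1, c_view_2), (target_view,) = views["context"], views["target"]
    return c_view_1, c_view_2, target_view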