# LN3Diff_I23D / datasets / g_buffer_objaverse.py
# Committed by NIRVANALAN in 11e6f7b ("init"); file size 115 kB.
import os
import kiui
from kiui.op import recenter
import collections
import math
import time
import itertools
import pickle
from typing import Any
import lmdb
import cv2
import imageio
import numpy as np
from PIL import Image
import Imath
import OpenEXR
from pdb import set_trace as st
from pathlib import Path
import torchvision
from nsr.camera_utils import generate_input_camera
from einops import rearrange, repeat
from functools import partial
import io
import gzip
import random
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.data.distributed import DistributedSampler
from pathlib import Path
import lz4.frame
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
from utils.general_utils import PILtoTorch, matrix_to_quaternion
from guided_diffusion import logger
import json
import webdataset as wds
from .shapenet import LMDBDataset, LMDBDataset_MV_Compressed, decompress_and_open_image_gzip, decompress_array
from kiui.op import safe_normalize
from utils.gs_utils.graphics_utils import getWorld2View2, getProjectionMatrix, getView2World
def fov2focal(fov, pixels):
    """Convert a field of view (radians) to a focal length in pixel units.

    Pinhole relation: ``focal = pixels / (2 * tan(fov / 2))``.
    """
    half_fov_tan = math.tan(fov / 2)
    return pixels / (2 * half_fov_tan)
def focal2fov(focal, pixels):
    """Convert a focal length (pixel units) to a field of view in radians.

    Pinhole relation: ``fov = 2 * atan(pixels / (2 * focal))``.
    """
    half_angle = math.atan(pixels / (2 * focal))
    return 2 * half_angle
def resize_depth_mask(depth_to_resize, resolution):
    """Resize a depth map to ``resolution x resolution`` and derive its mask.

    Resizing uses Lanczos interpolation. Lanczos can overshoot near depth
    discontinuities, so edge pixels of the mask may differ from the input's.

    Returns:
        ``(depth_resized, depth_resized > 0)``.
    """
    target_size = (resolution, resolution)
    depth_resized = cv2.resize(depth_to_resize,
                               target_size,
                               interpolation=cv2.INTER_LANCZOS4)
    valid_mask = depth_resized > 0
    return depth_resized, valid_mask  # type: ignore
def resize_depth_mask_Tensor(depth_to_resize, resolution):
    """Resize a batch of depth maps to ``resolution x resolution`` (nearest).

    Args:
        depth_to_resize: tensor indexed as ``[B, H, W]``. A channel dim is added
            with ``unsqueeze(1)`` for ``F.interpolate`` and removed afterwards.
        resolution: target side length.

    Returns:
        ``(depth_resized, depth_resized > 0)``. The input tensor itself is
        returned when it is already ``resolution x resolution``.
    """
    # Compare both spatial dims; checking only the width let non-square inputs
    # with a matching width through unresized.
    if tuple(depth_to_resize.shape[-2:]) != (resolution, resolution):
        # Do not pass ``align_corners`` here: PyTorch raises ValueError when it
        # is set for the non-interpolating 'nearest' mode.
        depth_resized = torch.nn.functional.interpolate(
            input=depth_to_resize.unsqueeze(1),
            size=(resolution, resolution),
            mode='nearest',
        ).squeeze(1)
    else:
        depth_resized = depth_to_resize
    return depth_resized, depth_resized > 0  # type: ignore
def load_dataset(
        file_path="",
        reso=64,
        reso_encoder=224,
        batch_size=1,
        num_workers=6,
        load_depth=False,
        preprocess=None,
        imgnet_normalize=True,
        dataset_size=-1,
        trainer_name='input_rec',
        use_lmdb=False,
        use_wds=False,
        use_lmdb_compressed=False,
        infi_sampler=True):
    """Build a single-pass, non-shuffled training ``DataLoader``.

    Dataset selection:
      * ``use_wds``: delegate to ``load_wds_data`` and return its result.
      * ``use_lmdb``: an ``Objv_LMDBDataset_*`` class. It is compressed or not
        per ``use_lmdb_compressed``, and novel-view (``NV``) when ``'nv'``
        occurs in ``trainer_name``, else multi-view (``MV``).
      * otherwise: ``NovelViewObjverseDataset`` / ``MultiViewObjverseDataset``,
        with the same ``'nv'`` rule.

    ``infi_sampler`` is accepted for signature compatibility and unused.
    """
    if use_wds:
        return load_wds_data(file_path, reso, reso_encoder, batch_size,
                             num_workers)

    novel_view = 'nv' in trainer_name
    if not use_lmdb:
        dataset_cls = (NovelViewObjverseDataset
                       if novel_view else MultiViewObjverseDataset)
    else:
        logger.log('using LMDB dataset')
        if use_lmdb_compressed:
            dataset_cls = (Objv_LMDBDataset_NV_Compressed if novel_view else
                           Objv_LMDBDataset_MV_Compressed)
        else:
            dataset_cls = (Objv_LMDBDataset_NV_NoCompressed if novel_view
                           else Objv_LMDBDataset_MV_NoCompressed)

    dataset_kwargs = dict(
        test=False,
        preprocess=preprocess,
        load_depth=load_depth,
        imgnet_normalize=imgnet_normalize,
        dataset_size=dataset_size,
    )
    dataset = dataset_cls(file_path, reso, reso_encoder, **dataset_kwargs)
    # NOTE: the message labels ``trainer_name`` as "dataset_cls".
    logger.log('dataset_cls: {}, dataset size: {}'.format(
        trainer_name, len(dataset)))

    return DataLoader(dataset,
                      batch_size=batch_size,
                      num_workers=num_workers,
                      drop_last=False,
                      pin_memory=True,
                      persistent_workers=num_workers > 0,
                      shuffle=False)
def load_data(
        file_path="",
        reso=64,
        reso_encoder=224,
        batch_size=1,
        num_workers=6,
        load_depth=False,
        preprocess=None,
        imgnet_normalize=True,
        dataset_size=-1,
        trainer_name='input_rec',
        use_lmdb=False,
        use_wds=False,
        use_lmdb_compressed=False,
        plucker_embedding=False,
        infi_sampler=True):
    """Generator of training batches.

    Dataset selection:
      * ``use_lmdb``: an ``Objv_LMDBDataset_*`` class. It is compressed or not
        per ``use_lmdb_compressed``, and novel-view when ``'nv'`` occurs in
        ``trainer_name``.
      * otherwise: ``NovelViewObjverseDataset`` / ``MultiViewObjverseDataset``.

    ``use_wds`` is accepted for signature compatibility and unused.

    With ``infi_sampler`` (the default), batches come from a shuffled
    ``DistributedSampler`` (``drop_last=True``) and the generator never ends.
    The sampler is built without ``num_replicas``/``rank``, so
    ``torch.distributed`` must already be initialized. With
    ``infi_sampler=False``, one unshuffled pass over the dataset is yielded.
    """
    if use_lmdb:
        logger.log('using LMDB dataset')
        if use_lmdb_compressed:
            if 'nv' in trainer_name:
                dataset_cls = Objv_LMDBDataset_NV_Compressed
            else:
                dataset_cls = Objv_LMDBDataset_MV_Compressed
        else:
            if 'nv' in trainer_name:
                dataset_cls = Objv_LMDBDataset_NV_NoCompressed
            else:
                dataset_cls = Objv_LMDBDataset_MV_NoCompressed
    else:
        if 'nv' in trainer_name:
            dataset_cls = NovelViewObjverseDataset  # 1.5-2iter/s
        else:
            dataset_cls = MultiViewObjverseDataset

    dataset = dataset_cls(file_path,
                          reso,
                          reso_encoder,
                          test=False,
                          preprocess=preprocess,
                          load_depth=load_depth,
                          imgnet_normalize=imgnet_normalize,
                          dataset_size=dataset_size,
                          plucker_embedding=plucker_embedding)
    logger.log('dataset_cls: {}, dataset size: {}'.format(
        trainer_name, len(dataset)))

    if not infi_sampler:
        # This path used to yield nothing at all. Yield a single
        # deterministic pass instead.
        loader = DataLoader(dataset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            drop_last=False,
                            pin_memory=True,
                            shuffle=False)
        yield from loader
        return

    train_sampler = DistributedSampler(dataset=dataset,
                                       shuffle=True,
                                       drop_last=True)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        num_workers=num_workers,
                        drop_last=True,
                        pin_memory=True,
                        persistent_workers=num_workers > 0,
                        sampler=train_sampler)
    epoch = 0
    while True:
        # DistributedSampler seeds its shuffle with the epoch. Without
        # set_epoch(), every pass would repeat the same ordering.
        train_sampler.set_epoch(epoch)
        yield from loader
        epoch += 1
def load_eval_data(
    file_path="",
    reso=64,
    reso_encoder=224,
    batch_size=1,
    num_workers=1,
    load_depth=False,
    preprocess=None,
    imgnet_normalize=True,
    interval=1,
    use_lmdb=False,
    plucker_embedding=False,
    load_real=False,
    four_view_for_latent=False,
    shuffle_across_cls=False,
    load_extra_36_view=False,
    gs_cam_format=False,
    single_view_for_i23d=False,
    **kwargs,
):
    """Build a sequential (unshuffled) evaluation ``DataLoader``.

    The dataset is chosen by priority:
      * ``use_lmdb``: ``Objv_LMDBDataset_MV_Compressed``;
      * ``load_real``: ``RealDataset`` over the images in ``file_path``;
      * otherwise: ``MultiViewObjverseDataset``.

    All three are built with ``test=True``. Extra ``**kwargs`` are accepted
    and ignored.
    """
    shared = dict(
        test=True,
        preprocess=preprocess,
        load_depth=load_depth,
        imgnet_normalize=imgnet_normalize,
        interval=interval,
    )

    if use_lmdb:
        logger.log('using LMDB dataset')
        dataset = Objv_LMDBDataset_MV_Compressed(file_path, reso,
                                                 reso_encoder, **shared)
    elif load_real:
        dataset = RealDataset(file_path,
                              reso,
                              reso_encoder,
                              plucker_embedding=plucker_embedding,
                              **shared)
    else:
        dataset = MultiViewObjverseDataset(
            file_path,
            reso,
            reso_encoder,
            plucker_embedding=plucker_embedding,
            four_view_for_latent=four_view_for_latent,
            load_extra_36_view=load_extra_36_view,
            shuffle_across_cls=shuffle_across_cls,
            gs_cam_format=gs_cam_format,
            single_view_for_i23d=single_view_for_i23d,
            **shared,
        )

    print('eval dataset size: {}'.format(len(dataset)))
    eval_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
        shuffle=False,
    )
    return eval_loader
def load_data_for_lmdb(
        file_path="",
        reso=64,
        reso_encoder=224,
        batch_size=1,
        num_workers=6,
        load_depth=False,
        preprocess=None,
        imgnet_normalize=True,
        dataset_size=-1,
        trainer_name='input_rec',
        shuffle_across_cls=False,
        four_view_for_latent=False,
        wds_split=1):
    """Build an unshuffled loader over ``MultiViewObjverseDatasetforLMDB``.

    The dataset is built with ``test=False``; ``wds_split`` selects the
    shard to read.

    Returns:
        ``(loader, dataset.dataset_name, len(dataset))``.
    """
    dataset = MultiViewObjverseDatasetforLMDB(
        file_path,
        reso,
        reso_encoder,
        test=False,
        preprocess=preprocess,
        load_depth=load_depth,
        imgnet_normalize=imgnet_normalize,
        dataset_size=dataset_size,
        shuffle_across_cls=shuffle_across_cls,
        wds_split=wds_split,
        four_view_for_latent=four_view_for_latent,
    )
    n_items = len(dataset)
    logger.log('dataset_cls: {}, dataset size: {}'.format(
        trainer_name, n_items))

    loader = DataLoader(
        dataset,
        shuffle=False,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
        pin_memory=True,
        persistent_workers=num_workers > 0,
    )
    return loader, dataset.dataset_name, len(dataset)
def load_lmdb_for_lmdb(
        file_path="",
        reso=64,
        reso_encoder=224,
        batch_size=1,
        num_workers=6,
        load_depth=False,
        preprocess=None,
        imgnet_normalize=True,
        dataset_size=-1,
        trainer_name='input_rec'):
    """Build an unshuffled loader over ``Objv_LMDBDataset_MV_Compressed_for_lmdb``.

    Returns:
        ``(loader, len(dataset))``.
    """
    dataset = Objv_LMDBDataset_MV_Compressed_for_lmdb(
        file_path,
        reso,
        reso_encoder,
        test=False,
        preprocess=preprocess,
        load_depth=load_depth,
        imgnet_normalize=imgnet_normalize,
        dataset_size=dataset_size)
    logger.log('dataset_cls: {}, dataset size: {}'.format(
        trainer_name, len(dataset)))

    loader_kwargs = dict(
        shuffle=False,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
        pin_memory=True,
    )
    # DataLoader rejects prefetch_factor and persistent_workers=True when
    # num_workers == 0, so only request them for multi-process loading.
    if num_workers > 0:
        loader_kwargs.update(prefetch_factor=2, persistent_workers=True)
    loader = DataLoader(dataset, **loader_kwargs)
    return loader, len(dataset)
def load_memory_data(
        file_path="",
        reso=64,
        reso_encoder=224,
        batch_size=1,
        num_workers=1,
        preprocess=None,
        imgnet_normalize=True,
        **kwargs):
    """Load an overfitting dataset fully into memory and yield batches forever.

    A ``NovelViewObjverseDataset`` is built with ``overfitting=True`` and
    ``load_depth=True``. All of it is collated into one batch, and batches are
    then sampled from that in-memory copy to avoid per-step IO.

    * ``kwargs['gs_cam_format']`` truthy: each step picks ``batch_size`` random
      items. Tensors in ``'c'``/``'nv_c'`` are reshaped to
      ``[batch_size // 4, 4, ...]`` (so ``batch_size`` must be a multiple of 4),
      and ``tanfov`` is reduced to a Python scalar taken from the first entry.
    * otherwise: each step yields a contiguous random window of
      ``batch_size`` items (requires ``len(dataset) >= batch_size``).
    """
    dataset = NovelViewObjverseDataset(file_path,
                                       reso,
                                       reso_encoder,
                                       preprocess=preprocess,
                                       load_depth=True,
                                       test=False,
                                       overfitting=True,
                                       imgnet_normalize=imgnet_normalize,
                                       overfitting_bs=batch_size,
                                       **kwargs)
    logger.log('!!!!!!! memory dataset size: {} !!!!!!'.format(len(dataset)))

    loader = DataLoader(
        dataset,
        batch_size=len(dataset),
        num_workers=num_workers,
        drop_last=False,
        shuffle=False,
    )
    all_data: dict = next(iter(loader))  # the whole dataset as one batch

    if kwargs.get('gs_cam_format', False):  # gs rendering pipeline
        while True:
            indices = torch.randperm(len(dataset))[:batch_size]
            batch_c = collections.defaultdict(dict)
            for k in ['c', 'nv_c']:
                for k_c, v_c in all_data[k].items():
                    if isinstance(v_c, torch.Tensor):
                        batch_c[k][k_c] = torch.index_select(
                            v_c, dim=0, index=indices).reshape(
                                batch_size // 4, 4,
                                *v_c.shape[1:]).float()
                    else:
                        # Non-tensor entries have no .float() (the old code
                        # crashed here); pass them through, as done below.
                        batch_c[k][k_c] = v_c
            batch_c['c']['tanfov'] = batch_c['c']['tanfov'][0][0].item()
            batch_c['nv_c']['tanfov'] = batch_c['nv_c']['tanfov'][0][0].item()

            batch_data = {}
            for k, v in all_data.items():
                if k not in ['c', 'nv_c']:
                    batch_data[k] = torch.index_select(
                        v, dim=0, index=indices).float() if isinstance(
                            v, torch.Tensor) else v
            yield {
                **batch_data,
                **batch_c,
            }
    else:
        while True:
            start_idx = np.random.randint(0, len(dataset) - batch_size + 1)
            yield {
                k: v[start_idx:start_idx + batch_size]
                for k, v in all_data.items()
            }
def read_dnormal(normald_path, cond_pos, h=None, w=None):
    """Read the depth channel of a normal+depth G-buffer image.

    The image is read with ``cv2.IMREAD_UNCHANGED`` and channel index 3 is
    taken as depth. Depth values below ``||cond_pos|| - 0.867`` (camera
    distance minus ~sqrt(3)/2) are set to 0.

    Args:
        normald_path: image path (callers in this file pass ``*_nd.exr``).
        cond_pos: camera position; callers pass ``c2w[:3, 3:]``, a (3, 1)
            array.
        h, w: optional output height/width; if ``h`` is given, ``w`` must
            be given too.

    Returns:
        float32 torch tensor of shape ``(h, w)`` if resized, else ``(H, W)``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``normald_path``.
    """
    # For a (3, 1) column, the matrix 2-norm equals its Euclidean length.
    cond_cam_dis = np.linalg.norm(cond_pos, 2)
    near = 0.867  # sqrt(3) * 0.5
    near_distance = cond_cam_dis - near

    normald = cv2.imread(normald_path, cv2.IMREAD_UNCHANGED)
    if normald is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(
            f'cannot read normal/depth image: {normald_path}')
    normald = normald.astype(np.float32)

    depth = normald[..., 3:]
    depth[depth < near_distance] = 0

    if h is not None:
        assert w is not None
        # cv2.resize takes dsize as (width, height). The old (h, w) order
        # only worked because callers pass square sizes.
        depth = cv2.resize(depth, (w, h), interpolation=cv2.INTER_NEAREST)
    else:
        depth = depth[..., 0]
    return torch.from_numpy(depth).float()
def get_intri(target_im=None, h=None, w=None, normalize=False):
    """Build the 3x3 pinhole intrinsics used for the G-buffer Objaverse renders.

    The focal length 1422.222 px is defined for a 1024 px render and is scaled
    by ``h / 1024``. The principal point is the image centre ``(w/2, h/2)``.

    Args:
        target_im: optional image array; if given, ``h, w`` come from its shape.
        h, w: image size, required when ``target_im`` is None.
        normalize: divide by ``h`` so the centre becomes [0.5, 0.5] (EG3D
            renderer convention).

    Returns:
        float numpy array of shape (3, 3).

    NOTE: ``K[:6]`` slices *rows* of the 3x3 array, so normalization divides
    every row, including K[2, 2], which becomes ``1 / h`` instead of 1.
    Downstream camera vectors (e.g. the hard-coded 0.0039 == 1/256 in
    RealMVDataset) depend on this, so it is kept.
    """
    if target_im is not None:
        h, w = target_im.shape[:2]
    else:
        assert (h is not None and w is not None)

    focal_raw = 1422.222
    raw_resolution = 1024
    focal = focal_raw * h / raw_resolution
    K = np.array([
        [focal, 0, w / 2],
        [0, focal, h / 2],
        [0, 0, 1],
    ])
    if normalize:
        K[:6] /= h
    return K
def convert_pose(C2W):
    """Negate the y and z camera axes of a 4x4 camera-to-world matrix.

    Right-multiplies by ``diag(1, -1, -1, 1)``, i.e. negates columns 1 and 2,
    following the RichDreamer G-Objaverse depth-warp example:
    https://github.com/modelscope/richdreamer/blob/c3d9a77fa15fc42dbae12c2d41d64aaec14efd37/dataset/gobjaverse/depth_warp_example.py#L402

    Returns:
        The result wrapped with ``torch.from_numpy``.
    """
    axis_flip = np.diag([1.0, -1.0, -1.0, 1.0])
    flipped = np.matmul(C2W, axis_flip)
    return torch.from_numpy(flipped)
def read_camera_matrix_single(json_file):
    """Read a Blender-convention 4x4 camera-to-world matrix from JSON.

    The JSON entries ``x``, ``y``, ``z`` and ``origin`` become columns 0-3 of
    the top three rows, and the bottom row stays ``[0, 0, 0, 1]``. The axes
    are used exactly as stored; no y/z negation is applied.
    """
    with open(json_file, 'r', encoding='utf8') as reader:
        cam_json = json.load(reader)

    camera_matrix = np.eye(4)  # blender-based
    for col, key in enumerate(('x', 'y', 'z', 'origin')):
        camera_matrix[:3, col] = np.array(cam_json[key])
    return camera_matrix
def unity2blender(normal):
    """Remap normal-vector channels from Unity to Blender axes.

    For input channels ``(n0, n1, ..., n_last)`` the first three output
    channels are ``(-n_last, -n0, n1)``. The input is not modified, and any
    channels beyond index 2 keep their input values.
    """
    remapped = normal.copy()
    n0 = normal[..., 0]
    n1 = normal[..., 1]
    n_last = normal[..., -1]
    remapped[..., 0] = -n_last
    remapped[..., 1] = -n0
    remapped[..., 2] = n1
    return remapped
def blender2midas(img):
    '''Negate channel 0, channel 1 and the last channel of ``img`` in place.

    Intended conversion: Blender (rub) -> MiDaS (lub).
    NOTE(review): rub -> lub would only need the x flip; this also negates
    y and the last channel. Confirm against the consumer of its output.

    Returns the same (mutated) array.
    '''
    for channel in (0, 1, -1):
        img[..., channel] = -img[..., channel]
    return img
def current_milli_time():
    """Return the current wall-clock time in milliseconds, rounded to an int."""
    seconds = time.time()
    return round(seconds * 1000)
# modified from ShapeNet class
class MultiViewObjverseDataset(Dataset):
    """Per-view G-buffer Objaverse dataset read directly from disk.

    Instances are directories ``<root>/<dict_dir>/<ins_dir>/campos_512_v4``.
    ``<root>`` is ``file_path``, or ``file_path/<subset>`` when
    ``shuffle_across_cls``. Each view ``<v>`` inside an instance provides
    ``<v>/<v>.png`` (RGBA), ``<v>/<v>.json`` (camera) and ``<v>/<v>_nd.exr``
    (normal + depth). One dataset item is one view.
    """

    def __init__(
            self,
            file_path,
            reso,
            reso_encoder,
            preprocess=None,
            classes=False,
            load_depth=False,
            test=False,
            scene_scale=1,
            overfitting=False,
            imgnet_normalize=True,
            dataset_size=-1,
            overfitting_bs=-1,
            interval=1,
            plucker_embedding=False,
            shuffle_across_cls=False,
            wds_split=1,  # 4 splits to accelerate preprocessing
            four_view_for_latent=False,
            single_view_for_i23d=False,
            load_extra_36_view=False,
            gs_cam_format=False,
            **kwargs):
        """Index all instances/views and set up encoder-input normalization.

        Args:
            file_path: dataset root (layout described on the class).
            reso: resolution of ``img``/``depth``; also used for the intrinsics.
            reso_encoder: resolution of the encoder input and Plücker rays.
            preprocess: optional callable applied to the PIL image instead
                of the built-in resize + normalize for the encoder input.
            classes: ignored; class conditioning is not supported.
            overfitting: keep only the first instance.
            imgnet_normalize: normalize the encoder input with ImageNet
                mean/std, else with mean = std = 0.5.
            shuffle_across_cls: read the enabled subsets under ``file_path``
                and shuffle instances/views with a time-based seed; otherwise
                treat ``file_path`` as a single class and sort.
            wds_split: 1-based shard index (out of 4) to index; ignored when
                ``four_view_for_latent``.
            four_view_for_latent: use the fixed views 25, 0, 9, 18, 27, 33.
            single_view_for_i23d: use view 2 only.
            gs_cam_format: return the camera as a 3DGS dict instead of a
                25-d vector.
            load_depth, scene_scale, plucker_embedding, load_extra_36_view:
                stored as attributes only; depth is always loaded.
            test, dataset_size, overfitting_bs, interval, **kwargs: unused.
        """
        self.load_extra_36_view = load_extra_36_view
        self.gs_cam_format = gs_cam_format
        self.four_view_for_latent = four_view_for_latent
        self.single_view_for_i23d = single_view_for_i23d
        self.file_path = file_path
        self.overfitting = overfitting
        self.scene_scale = scene_scale
        self.reso = reso
        self.reso_encoder = reso_encoder
        self.classes = False
        self.load_depth = load_depth
        self.preprocess = preprocess
        self.plucker_embedding = plucker_embedding
        # Flattened normalized 3x3 intrinsics for a reso x reso image.
        self.intrinsics = get_intri(h=self.reso, w=self.reso,
                                    normalize=True).reshape(9)

        assert not self.classes, "Not support class condition now."

        self.dataset_name = Path(self.file_path).stem.split('_')[0]
        self.zfar = 100.0
        self.znear = 0.01

        if shuffle_across_cls:
            self.ins_list = []
            # Only 'Animals' is enabled. The other subsets ('daily-used',
            # 'Food', 'Plants', 'Electronics', 'BuildingsOutdoor',
            # 'Human-Shape', 'Transportations_tar', 'Furnitures') are off.
            for subset in ['Animals']:
                self.ins_list += self._load_single_cls_instances(
                    os.path.join(self.file_path, subset))
            # Time-based seed: a different order on every construction.
            current_time = int(current_milli_time())
            random.seed(current_time)
            random.shuffle(self.ins_list)
        else:  # preprocess single class
            self.ins_list = sorted(
                self._load_single_cls_instances(self.file_path))

        if overfitting:
            self.ins_list = self.ins_list[:1]

        self.rgb_list = []
        self.pose_list = []
        self.depth_list = []
        self.data_ins_list = []
        self.instance_data_length = -1

        with open('./datasets/text_captions_cap3d.json',
                  encoding='utf-8') as f:
            self.caption_data = json.load(f)

        self.shuffle_across_cls = shuffle_across_cls

        if four_view_for_latent:
            self.wds_split_all = 1  # ! when dumping latent
            ins_list_to_process = self.ins_list
        else:
            self.wds_split_all = 4
            ratio_size = len(self.ins_list) // self.wds_split_all + 1
            ins_list_to_process = self.ins_list[ratio_size *
                                                (wds_split - 1):ratio_size *
                                                wds_split]

        for ins in ins_list_to_process:
            if self.four_view_for_latent:
                # ! v=6 version infer latent
                cur_all_fname = [
                    f'{idx:05d}' for idx in [25, 0, 9, 18, 27, 33]
                ]
            elif self.single_view_for_i23d:
                cur_all_fname = [f'{idx:05d}'
                                 for idx in [2]]  # ! furniture side view
            else:
                # use the full set of views for training
                cur_all_fname = [t.split('.')[0] for t in os.listdir(ins)]

            # NOTE(review): this applies to all three branches above, so the
            # fixed view order 25, 0, 9, ... is sorted unless
            # shuffle_across_cls. Confirm that is intended.
            if shuffle_across_cls:
                random.seed(current_time)
                random.shuffle(cur_all_fname)
            else:
                cur_all_fname = sorted(cur_all_fname)

            # Every instance must expose as many views as the first one;
            # instances with missing views are logged and skipped.
            if self.instance_data_length == -1:
                self.instance_data_length = len(cur_all_fname)
            elif len(cur_all_fname) != self.instance_data_length:
                self._log_missing_instance(ins)
                continue

            self.pose_list += [
                os.path.join(ins, fname, fname + '.json')
                for fname in cur_all_fname
            ]
            self.rgb_list += [
                os.path.join(ins, fname, fname + '.png')
                for fname in cur_all_fname
            ]
            self.depth_list += [
                os.path.join(ins, fname, fname + '_nd.exr')
                for fname in cur_all_fname
            ]
            self.data_ins_list += [ins] * len(cur_all_fname)

        # ! setup normalization
        transformations = [
            transforms.ToTensor(),  # [0,1] range
        ]
        if imgnet_normalize:
            transformations.append(
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))  # type: ignore
            )
        else:
            transformations.append(
                transforms.Normalize((0.5, 0.5, 0.5),
                                     (0.5, 0.5, 0.5)))  # type: ignore
        self.normalize = transforms.Compose(transformations)

    @staticmethod
    def _load_single_cls_instances(file_path):
        """Return ``<file_path>/<dict_dir>/<ins_dir>/campos_512_v4`` paths (listdir order)."""
        return [
            os.path.join(file_path, dict_dir, ins_dir, 'campos_512_v4')
            for dict_dir in os.listdir(file_path)
            for ins_dir in os.listdir(os.path.join(file_path, dict_dir))
        ]

    @staticmethod
    def _log_missing_instance(ins, log_path='missing_ins.txt'):
        """Append the parent of ``ins`` (dropping "campos_512_v4") to ``log_path``."""
        # ``ins`` is a str path; the old ``Path(ins.parent)`` raised
        # AttributeError here.
        with open(log_path, 'a', encoding='utf-8') as f:
            f.write(str(Path(ins).parent) + '\n')

    def get_source_cw2wT(self, source_cameras_view_to_world):
        """Quaternion of the transposed top-left 3x3 of a view-to-world matrix."""
        return matrix_to_quaternion(
            source_cameras_view_to_world[:3, :3].transpose(0, 1))

    def c_to_3dgs_format(self, pose):
        """Convert a 25-d camera vector into the dict used for 3DGS rendering.

        Args:
            pose: numpy array; ``pose[:16]`` is the row-major 4x4 c2w and
                ``pose[16]`` the focal length normalized to image size 1.
        """
        # TODO, switch to torch version (batched later)
        c2w = pose[:16].reshape(4, 4)
        w2c = np.linalg.inv(c2w)
        R = np.transpose(
            w2c[:3, :3])  # R is stored transposed due to 'glm' in CUDA code
        T = w2c[:3, 3]

        fx = pose[16]
        FovX = focal2fov(fx, 1)
        FovY = focal2fov(fx, 1)
        tanfovx = math.tan(FovX * 0.5)
        tanfovy = math.tan(FovY * 0.5)
        assert tanfovx == tanfovy

        trans = np.array([0.0, 0.0, 0.0])
        scale = 1.0

        world_view_transform = torch.tensor(
            getWorld2View2(R, T, trans, scale)).transpose(0, 1)
        projection_matrix = getProjectionMatrix(znear=self.znear,
                                                zfar=self.zfar,
                                                fovX=FovX,
                                                fovY=FovY).transpose(0, 1)
        full_proj_transform = (world_view_transform.unsqueeze(0).bmm(
            projection_matrix.unsqueeze(0))).squeeze(0)
        camera_center = world_view_transform.inverse()[3, :3]
        view_world_transform = torch.tensor(
            getView2World(R, T, trans, scale)).transpose(0, 1)

        c = {}
        c["source_cv2wT_quat"] = self.get_source_cw2wT(view_world_transform)
        c.update(
            cam_view=world_view_transform,
            cam_view_proj=full_proj_transform,
            cam_pos=camera_center,
            tanfov=tanfovx,  # TODO, fix in the renderer
            orig_pose=torch.from_numpy(pose),
            orig_c2w=torch.from_numpy(c2w),
            orig_w2c=torch.from_numpy(w2c),
        )
        return c  # dict for gs rendering

    def __len__(self):
        return len(self.rgb_list)

    def load_bbox(self, mask):
        """Bounding box of the nonzero entries of a 2-D mask.

        Returns float32 ``[min_row, min_col, max_row, max_col]``. Fails on an
        empty mask (reduction over zero elements).
        """
        nonzero_value = torch.nonzero(mask)
        max_row, max_col = nonzero_value.max(dim=0)[0]
        min_row, min_col = nonzero_value.min(dim=0)[0]
        return torch.tensor([min_row, min_col, max_row, max_col],
                            dtype=torch.float32)

    def __getitem__(self, idx):
        return self._read_data(idx)

    def gen_rays(self, c2w):
        """Per-pixel ray origins and unit directions at ``reso_encoder``.

        Pixel centres are normalized to [0, 1], unprojected with the
        normalized intrinsics (+z forward, OpenCV convention) and rotated by
        ``c2w``. Side effect: sets ``self.h`` and ``self.w``.

        Args:
            c2w: numpy 4x4 camera-to-world matrix.

        Returns:
            ``(origins, dirs)``, each float32 ``[h, w, 3]``.
        """
        self.h = self.reso_encoder
        self.w = self.reso_encoder
        yy, xx = torch.meshgrid(
            torch.arange(self.h, dtype=torch.float32) + 0.5,
            torch.arange(self.w, dtype=torch.float32) + 0.5,
            indexing='ij')

        # normalize to 0-1 pixel range
        yy = yy / self.h
        xx = xx / self.w

        cx, cy, fx, fy = self.intrinsics[2], self.intrinsics[
            5], self.intrinsics[0], self.intrinsics[4]

        c2w = torch.from_numpy(c2w).float()

        xx = (xx - cx) / fx
        yy = (yy - cy) / fy
        zz = torch.ones_like(xx)
        dirs = torch.stack((xx, yy, zz), dim=-1)  # OpenCV convention
        dirs /= torch.norm(dirs, dim=-1, keepdim=True)
        dirs = dirs.reshape(-1, 3, 1)
        del xx, yy, zz
        dirs = (c2w[None, :3, :3] @ dirs)[..., 0]

        origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous()
        origins = origins.view(self.h, self.w, 3)
        dirs = dirs.view(self.h, self.w, 3)
        return origins, dirs

    def _read_data(self, idx):
        """Load view ``idx``.

        Returns a dict with:
            img_to_encoder: encoder RGB, then Plücker rays (6 ch), then
                depth (1 ch), concatenated along dim 0.
            img: RGB ``[3, reso, reso]`` in [-1, 1].
            c: 25-d camera tensor (16 c2w + 9 intrinsics) or a 3DGS dict.
            depth, depth_mask, bbox: depth at ``reso``, its >0 mask, its bbox.
            caption: caption looked up from ``text_captions_cap3d.json``.
            rays_plucker: ``[reso_encoder, reso_encoder, 6]``.
            ins: instance id relative to ``file_path``.
        """
        rgb_fname = self.rgb_list[idx]
        pose_fname = self.pose_list[idx]

        raw_img = imageio.imread(rgb_fname)
        # Composite RGBA over a white background.
        alpha_mask = raw_img[..., -1:] / 255
        raw_img = alpha_mask * raw_img[..., :3] + (
            1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255
        raw_img = raw_img.astype(
            np.uint8)  # otherwise, float64 won't call ToTensor()

        if self.preprocess is None:
            img_to_encoder = cv2.resize(raw_img,
                                        (self.reso_encoder, self.reso_encoder),
                                        interpolation=cv2.INTER_LANCZOS4)
            img_to_encoder = img_to_encoder[..., :3]
            img_to_encoder = self.normalize(img_to_encoder)
        else:
            # Context manager closes the file handle Image.open leaves open.
            with Image.open(rgb_fname) as pil_img:
                img_to_encoder = self.preprocess(pil_img)  # clip

        img = cv2.resize(raw_img, (self.reso, self.reso),
                         interpolation=cv2.INTER_LANCZOS4)
        img = torch.from_numpy(img)[..., :3].permute(
            2, 0, 1
        ) / 127.5 - 1  #[3, reso, reso], normalize to [-1,1], follow triplane range

        c2w = read_camera_matrix_single(pose_fname)  # [4, 4]

        depth = read_dnormal(self.depth_list[idx], c2w[:3, 3:], self.reso,
                             self.reso)
        bbox = self.load_bbox(depth > 0)

        # plucker
        rays_o, rays_d = self.gen_rays(c2w)
        rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d],
                                 dim=-1)  # [h, w, 6]

        img_to_encoder = torch.cat(
            [img_to_encoder, rays_plucker.permute(2, 0, 1)],
            0).float()  # concat in C dim

        # ! add depth as input. Reuse ``depth`` when the resolutions match
        # instead of reading and resizing the same EXR a second time.
        if self.reso_encoder == self.reso:
            normalized_depth = depth.unsqueeze(0)
        else:
            normalized_depth = read_dnormal(self.depth_list[idx],
                                            c2w[:3, 3:], self.reso_encoder,
                                            self.reso_encoder).unsqueeze(0)
        img_to_encoder = torch.cat([img_to_encoder, normalized_depth],
                                   0)  # concat in C dim

        c = np.concatenate([c2w.reshape(16), self.intrinsics],
                           axis=0).reshape(25).astype(
                               np.float32)  # 25, no '1' dim needed.
        if self.gs_cam_format:
            c = self.c_to_3dgs_format(c)
        else:
            c = torch.from_numpy(c)

        ins = str(
            (Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent)
        if self.shuffle_across_cls:
            # Drop the leading subset component before the caption lookup.
            caption = self.caption_data['/'.join(ins.split('/')[1:])]
        else:
            caption = self.caption_data[ins]

        return {
            'img_to_encoder': img_to_encoder,
            'img': img,
            'c': c,
            'depth': depth,
            'depth_mask': depth > 0,
            'bbox': bbox,
            'caption': caption,
            'rays_plucker': rays_plucker,  # cam embedding used in lgm
            'ins': ins,  # placeholder
        }
class RealDataset(Dataset):
    """Evaluation dataset of real input images found in a single folder.

    Images are files in ``file_path`` whose name contains ``-input`` and whose
    extension is ``.png`` or ``.jpg`` (case-insensitive), in sorted order.
    Each item is ``{'img': tensor[3, reso, reso] in [-1, 1]}``. RGBA images
    are composited over white and grayscale images are expanded to RGB.
    """

    def __init__(
        self,
        file_path,
        reso,
        reso_encoder,
        preprocess=None,
        classes=False,
        load_depth=False,
        test=False,
        scene_scale=1,
        overfitting=False,
        imgnet_normalize=True,
        dataset_size=-1,
        overfitting_bs=-1,
        interval=1,
        plucker_embedding=False,
        shuffle_across_cls=False,
        wds_split=1,  # 4 splits to accelerate preprocessing
    ) -> None:
        """Collect the input image paths and set up normalization.

        ``imgnet_normalize`` must be truthy (asserted). Arguments other than
        ``file_path``/``reso``/``reso_encoder`` are stored or ignored.
        """
        super().__init__()
        self.file_path = file_path
        self.overfitting = overfitting
        self.scene_scale = scene_scale
        self.reso = reso
        self.reso_encoder = reso_encoder
        self.classes = False
        self.load_depth = load_depth
        self.preprocess = preprocess
        self.plucker_embedding = plucker_embedding

        self.rgb_list = self._list_input_images(self.file_path)

        # ! setup normalization
        transformations = [
            transforms.ToTensor(),  # [0,1] range
        ]
        assert imgnet_normalize
        if imgnet_normalize:
            transformations.append(
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))  # type: ignore
            )
        else:
            transformations.append(
                transforms.Normalize((0.5, 0.5, 0.5),
                                     (0.5, 0.5, 0.5)))  # type: ignore
        self.normalize = transforms.Compose(transformations)

    @staticmethod
    def _list_input_images(file_path):
        """Sorted paths of ``*-input*.png``/``.jpg`` files directly in ``file_path``."""
        # splitext instead of ``name.split('.')[1]``: that raised IndexError
        # for names without a dot and tested the wrong part of "a.b.png".
        names = sorted(
            name for name in os.listdir(file_path)
            if '-input' in name
            and os.path.splitext(name)[1].lower() in ('.png', '.jpg'))
        return [os.path.join(file_path, name) for name in names]

    def __len__(self):
        return len(self.rgb_list)

    def __getitem__(self, index) -> Any:
        rgb_fname = self.rgb_list[index]

        raw_img = imageio.imread(rgb_fname)
        if raw_img.ndim == 2:
            # Grayscale: replicate to 3 channels so the RGB slicing below works.
            raw_img = np.repeat(raw_img[..., None], 3, axis=-1)
        if raw_img.shape[-1] == 4:
            # Composite RGBA over a white background.
            alpha_mask = raw_img[..., 3:4] / 255.0
            bg_white = np.ones_like(alpha_mask) * 255.0
            raw_img = raw_img[..., :3] * alpha_mask + (1 - alpha_mask) * bg_white
            raw_img = raw_img.astype(np.uint8)

        img = cv2.resize(raw_img, (self.reso, self.reso),
                         interpolation=cv2.INTER_LANCZOS4)
        img = torch.from_numpy(img)[..., :3].permute(
            2, 0, 1
        ) / 127.5 - 1  #[3, reso, reso], normalize to [-1,1], follow triplane range

        return {
            'img': img,
        }
class RealMVDataset(Dataset):
def __init__(
self,
file_path,
reso,
reso_encoder,
preprocess=None,
classes=False,
load_depth=False,
test=False,
scene_scale=1,
overfitting=False,
imgnet_normalize=True,
dataset_size=-1,
overfitting_bs=-1,
interval=1,
plucker_embedding=False,
shuffle_across_cls=False,
wds_split=1, # 4 splits to accelerate preprocessing
) -> None:
super().__init__()
self.file_path = file_path
self.overfitting = overfitting
self.scene_scale = scene_scale
self.reso = reso
self.reso_encoder = reso_encoder
self.classes = False
self.load_depth = load_depth
self.preprocess = preprocess
self.plucker_embedding = plucker_embedding
self.rgb_list = []
all_fname = [
t for t in os.listdir(self.file_path)
if t.split('.')[1] in ['png', 'jpg']
]
# all_fname = [name for name in all_fname if '-input' in name ]
# all_fname = [name for name in all_fname if 'sorting_board-input' in name ]
all_fname = [name for name in all_fname if 'teasure_chest-input' in name ]
# all_fname = [name for name in all_fname if 'bubble_mart_blue-input' in name ]
# all_fname = [name for name in all_fname if 'chair_comfort-input' in name ]
self.rgb_list += ([
os.path.join(self.file_path, fname) for fname in all_fname
])
# if len(self.rgb_list) == 1:
# # placeholder
# self.rgb_list = self.rgb_list * 40
# ! setup normalizataion
transformations = [
transforms.ToTensor(), # [0,1] range
]
# load zero123pp pose
# '''
camera = torch.load('assets/objv_eval_pose.pt', map_location='cpu') # 40, 25
# v=6 render
# zero123pp_pose = torch.load('assets/input_cameras.pt', map_location='cpu')[0]
# zero123pp_pose = torch.load('assets/input_cameras_1.5.pt', map_location='cpu')[0]
# V = zero123pp_pose.shape[0]
# zero123pp_pose = torch.cat([zero123pp_pose[:, :12], # change radius to 1.5 here
# torch.Tensor([0,0,0,1]).reshape(1,4).repeat_interleave(V, 0).to(camera),
# camera[:V, 16:]],
# 1)
azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float)
elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float)
# zero123pp_pose, _ = generate_input_camera(1.6, [[elevations[i], azimuths[i]] for i in range(6)], fov=30)
zero123pp_pose, _ = generate_input_camera(1.8, [[elevations[i], azimuths[i]] for i in range(6)], fov=30)
K = torch.Tensor([1.3889, 0.0000, 0.5000, 0.0000, 1.3889, 0.5000, 0.0000, 0.0000, 0.0039]).to(zero123pp_pose) # keeps the same
# st()
zero123pp_pose = torch.cat([zero123pp_pose.reshape(6,-1), K.unsqueeze(0).repeat(6,1)], dim=-1)
# ! directly adopt gt input
# self.indices = np.array([0,2,4,5])
# eval_camera = zero123pp_pose[self.indices]
# self.eval_camera = torch.cat([torch.zeros_like(eval_camera[0:1]),eval_camera], 0) # first c not used as condition here, just placeholder
# ! adopt mv-diffusion output as input.
# self.indices = np.array([1,0,2,4,5])
self.indices = np.array([0,1,2,3,4,5])
eval_camera = zero123pp_pose[self.indices].float().cpu().numpy() # for normalization
# eval_camera = zero123pp_pose[self.indices]
# self.eval_camera = eval_camera
# self.eval_camera = torch.cat([torch.zeros_like(eval_camera[0:1]),eval_camera], 0) # first c not used as condition here, just placeholder
# # * normalize here
# self.eval_camera = self.normalize_camera(eval_camera, eval_camera[0:1]) # the first img is not used.
# ! no normalization required
self.eval_camera = eval_camera
# st()
# self.eval_camera = self.eval_camera + np.random.randn(*self.eval_camera.shape) * 0.04 - 0.02
# '''
# self.eval_camera = torch.load('/nas/shared/V2V/yslan/logs/nips24/LSGM/t23d/FM/9cls/i23d/dit-XL2-MV-ampFast/gpu1-batch10-lr0-bf16-qknorm-fixinpRange-debug/2590000_c.pt')[1:6].float().cpu().numpy()
# self.eval_camera = self.normalize_camera(self.eval_camera, self.eval_camera[0:1])
# self.eval_camera = self.eval_camera + np.random.randn(*self.eval_camera.shape) * 0.04 - 0.02
# pass
# FIXME check how sensitive it is: requries training augmentation.
# self.eval_camera = torch.load('/nas/shared/V2V/yslan/logs/nips24/LSGM/t23d/FM/9cls/i23d/dit-XL2-MV-ampFast/gpu1-batch10-lr0-bf16-qknorm-fixinpRange-debug/2590000_c.pt')[1:]
def normalize_camera(self, c, c_frame0):
    """Re-express flattened cameras relative to a single reference camera.

    Each row of ``c`` starts with a row-major 4x4 pose (entries 0-15); any
    further entries (intrinsics) are copied through unchanged. Every pose is
    left-multiplied by ``F @ inv(P0)``, where ``P0`` is the pose stored in
    ``c_frame0`` (which must hold exactly one row) and ``F`` is the identity
    with its z-translation set to ``-||t0||`` (``t0`` = translation of
    ``P0``). The reference camera itself therefore maps to ``F``: identity
    rotation placed at ``(0, 0, -||t0||)``.

    Args:
        c: numpy array of shape (B, >=16).
        c_frame0: numpy array of shape (1, >=16).

    Returns:
        numpy array with the same number of rows and columns as ``c``.
    """
    num_cams = c.shape[0]
    poses = c[:, :16].reshape(num_cams, 4, 4)
    ref_pose = c_frame0[:, :16].reshape(1, 4, 4)
    ref_inv = np.repeat(np.linalg.inv(ref_pose), num_cams, 0)

    # The target distance is read from the reference camera itself, since the
    # g-buffer renders use a dynamic (per-object) camera radius.
    radius = np.linalg.norm(ref_pose[:, :3, 3], axis=-1, keepdims=False)
    target = np.repeat(np.eye(4)[None], 1, axis=0)
    target[:, 2, -1] = -radius

    rebase = target @ ref_inv
    new_poses = np.repeat(rebase, 1, axis=0) @ poses
    return np.concatenate([new_poses.reshape(num_cams, 16), c[:, 16:]],
                          axis=-1)
def __len__(self):
    """Return the number of entries in ``self.rgb_list`` (one item per file)."""
    files = self.rgb_list
    return len(files)
def __getitem__(self, index) -> Any:
    """Load one multi-view image grid and pair it with the fixed eval cameras.

    The file actually read is ``self.rgb_list[index]`` with ``'-input'``
    removed from the path. It is split into a 3 x 2 grid of views (3 rows,
    2 columns, row-major order). The views selected by ``self.indices`` are
    each passed through ``recenter`` with an all-ones mask and
    ``border_ratio=0.2``. The stack is then converted to a channels-first
    tensor scaled to [-1, 1].

    NOTE(review): no channel slicing is done, so an RGBA grid would yield
    4-channel views — confirm the grid files are RGB.

    Returns:
        dict with ``'img'`` (tensor, views x channels x h x w) and ``'c'``
        (``self.eval_camera``).
    """
    grid_path = self.rgb_list[index].replace('-input', '')
    grid = imageio.imread(grid_path)

    views = rearrange(grid, '(n h) (m w) c -> (n m) h w c', n=3,
                      m=2)[self.indices]
    centered = np.stack(
        [recenter(view, np.ones_like(view), border_ratio=0.2)
         for view in views],
        axis=0)

    chw = rearrange(centered, 'b h w c -> b c h w')
    img = torch.from_numpy(chw) / 127.5 - 1
    return {'img': img, 'c': self.eval_camera}
class NovelViewObjverseDataset(MultiViewObjverseDataset):
    """Novel-view prediction variant of ``MultiViewObjverseDataset``.

    ``__getitem__`` returns the parent's sample for ``idx`` merged with a
    second parent sample. The second sample is drawn uniformly at random from
    the same block of ``instance_data_length`` consecutive indices, i.e. from
    the same instance. Its keys are prefixed with ``nv_``.
    """

    def __init__(self,
                 file_path,
                 reso,
                 reso_encoder,
                 preprocess=None,
                 classes=False,
                 load_depth=False,
                 test=False,
                 scene_scale=1,
                 overfitting=False,
                 imgnet_normalize=True,
                 dataset_size=-1,
                 overfitting_bs=-1,
                 **kwargs):
        # Every argument is forwarded positionally, in the parent's order.
        super().__init__(file_path, reso, reso_encoder, preprocess, classes,
                         load_depth, test, scene_scale, overfitting,
                         imgnet_normalize, dataset_size, overfitting_bs,
                         **kwargs)

    def __getitem__(self, idx):
        sample = super().__getitem__(idx)

        # Any view of the same instance may be chosen, including idx itself.
        span = self.instance_data_length
        first_of_instance = (idx // span) * span
        novel = super().__getitem__(first_of_instance +
                                    random.randint(0, span - 1))

        sample.update({f'nv_{key}': value for key, value in novel.items()})
        return sample
class MultiViewObjverseDatasetforLMDB(MultiViewObjverseDataset):
    """Per-view Objaverse reader returning raw arrays plus a Cap3D caption.

    Each item is a dict with:
    - the RGB image (RGBA composited on white) resized to ``reso`` as uint8;
    - the flattened 25-d camera as a float32 tensor;
    - depth, and a foreground bbox derived from ``depth > 0``;
    - the instance path, and its caption.
    Presumably used to pack data into LMDB (per the class name) — confirm.
    """

    def __init__(
        self,
        file_path,
        reso,
        reso_encoder,
        preprocess=None,
        classes=False,
        load_depth=False,
        test=False,
        scene_scale=1,
        overfitting=False,
        imgnet_normalize=True,
        dataset_size=-1,
        overfitting_bs=-1,
        shuffle_across_cls=False,
        wds_split=1,
        four_view_for_latent=False,
    ):
        super().__init__(file_path,
                         reso,
                         reso_encoder,
                         preprocess,
                         classes,
                         load_depth,
                         test,
                         scene_scale,
                         overfitting,
                         imgnet_normalize,
                         dataset_size,
                         overfitting_bs,
                         shuffle_across_cls=shuffle_across_cls,
                         wds_split=wds_split,
                         four_view_for_latent=four_view_for_latent)

        assert self.reso == 256

        # Explicit encoding: the default would depend on the process locale.
        with open(
                '/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json',
                encoding='utf-8') as f:
            self.caption_data = json.load(f)

    def __len__(self):
        return super().__len__()

    def __getitem__(self, idx):
        rgb_fname = self.rgb_list[idx]
        pose_fname = self.pose_list[idx]

        raw_img = imageio.imread(rgb_fname)
        if raw_img.shape[-1] == 4:  # composite RGBA onto a white background
            alpha_mask = raw_img[..., -1:] / 255  # [0, 1]
            raw_img = alpha_mask * raw_img[..., :3] + (
                1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255
            raw_img = raw_img.astype(np.uint8)

        raw_img = cv2.resize(raw_img, (self.reso, self.reso),
                             interpolation=cv2.INTER_LANCZOS4)

        # 16 pose entries followed by the intrinsics -> flat 25-d vector.
        c2w = read_camera_matrix_single(pose_fname)
        c = np.concatenate([c2w.reshape(16), self.intrinsics],
                           axis=0).reshape(25).astype(np.float32)
        c = torch.from_numpy(c)

        depth = read_dnormal(self.depth_list[idx], c2w[:3, 3:], self.reso,
                             self.reso)
        bbox = self.load_bbox(depth > 0)

        ins = str(
            (Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent)
        if self.shuffle_across_cls:
            # Captions are keyed without the leading class directory.
            caption = self.caption_data['/'.join(ins.split('/')[1:])]
        else:
            caption = self.caption_data[ins]

        return {
            'raw_img': raw_img,
            'c': c,
            'depth': depth,
            'bbox': bbox,
            'ins': ins,
            'caption': caption,
        }
class Objv_LMDBDataset_MV_Compressed(LMDBDataset_MV_Compressed):
    """Objaverse multi-view LMDB dataset (compressed entries) with captions.

    Each instance occupies ``instance_data_length`` (40) consecutive indices.
    ``idx_to_ins_mapping.json`` inside ``lmdb_path`` maps a sample index
    (as a string) to the key used in the Cap3D caption file.
    """

    def __init__(self,
                 lmdb_path,
                 reso,
                 reso_encoder,
                 imgnet_normalize=True,
                 dataset_size=-1,
                 test=False,
                 **kwargs):
        super().__init__(lmdb_path,
                         reso,
                         reso_encoder,
                         imgnet_normalize,
                         dataset_size=dataset_size,
                         **kwargs)
        self.instance_data_length = 40  # ! could save some key attributes in LMDB

        if test:
            # Length limited to one instance's worth of indices.
            self.length = self.instance_data_length
        elif dataset_size > 0:
            self.length = dataset_size * self.instance_data_length

        # Explicit encoding: the default would depend on the process locale.
        with open(
                '/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json',
                encoding='utf-8') as f:
            self.caption_data = json.load(f)

        with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json'),
                  encoding='utf-8') as f:
            self.idx_to_ins_mapping = json.load(f)

    def _load_data(self, idx):
        """Load, post-process and caption the sample at ``idx``."""
        raw_img, depth, c, bbox = self._load_lmdb_data(idx)
        caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]]
        return {
            **self._post_process_sample(raw_img, depth),
            'c': c,
            # bbox is rescaled by reso / 512 and stored as uint8.
            # NOTE(review): values above 255 wrap around, e.g. if reso > 256
            # and the bbox reaches the 512-pixel scale.
            'bbox': (bbox * (self.reso / 512.0)).astype(np.uint8),
            'caption': caption
        }

    def __getitem__(self, idx):
        return self._load_data(idx)
class Objv_LMDBDataset_MV_NoCompressed(Objv_LMDBDataset_MV_Compressed):
    """Variant that reads entries via ``_load_lmdb_data_no_decompress``.

    Entries are presumably stored uncompressed (per the loader's name) —
    confirm. Post-processing and captioning match the compressed parent.
    """

    def __init__(self,
                 lmdb_path,
                 reso,
                 reso_encoder,
                 imgnet_normalize=True,
                 dataset_size=-1,
                 test=False,
                 **kwargs):
        super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
                         dataset_size, test, **kwargs)

    def _load_data(self, idx):
        """Load, post-process and caption the sample at ``idx``."""
        raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx)
        caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]]
        # The unreachable trailing ``return {}`` was removed.
        return {
            **self._post_process_sample(raw_img, depth),
            'c': c,
            'bbox': (bbox * (self.reso / 512.0)).astype(np.uint8),
            'caption': caption
        }
class Objv_LMDBDataset_NV_NoCompressed(Objv_LMDBDataset_MV_NoCompressed):
    """Novel-view variant: pairs each sample with a random view of its instance."""

    def __init__(self,
                 lmdb_path,
                 reso,
                 reso_encoder,
                 imgnet_normalize=True,
                 dataset_size=-1,
                 test=False,
                 **kwargs):
        super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
                         dataset_size, test, **kwargs)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` plus ``nv_``-prefixed keys of a novel view.

        The novel view is drawn uniformly from the same block of
        ``instance_data_length`` indices, and may coincide with ``idx``.

        Raises:
            NotImplementedError(idx): if loading the novel view fails. The
                original error is chained as ``__cause__``.
        """
        input_view = self._load_data(idx)

        novel_idx = ((idx // self.instance_data_length) *
                     self.instance_data_length +
                     random.randint(0, self.instance_data_length - 1))
        try:
            novel_view = self._load_data(novel_idx)
        except Exception as e:
            # Keep the exception type callers may catch, but chain the cause.
            raise NotImplementedError(idx) from e

        input_view.update({f'nv_{k}': v for k, v in novel_view.items()})
        return input_view
class Objv_LMDBDataset_MV_Compressed_for_lmdb(LMDBDataset_MV_Compressed):
    """Compressed Objaverse LMDB reader returning the raw decoded arrays.

    ``__getitem__`` returns the output of ``_load_lmdb_data`` as a dict, with
    no post-processing and no caption lookup. Captions and the
    idx-to-instance mapping are still loaded in ``__init__``.
    """

    def __init__(self,
                 lmdb_path,
                 reso,
                 reso_encoder,
                 imgnet_normalize=True,
                 dataset_size=-1,
                 test=False,
                 **kwargs):
        super().__init__(lmdb_path,
                         reso,
                         reso_encoder,
                         imgnet_normalize,
                         dataset_size=dataset_size,
                         **kwargs)
        self.instance_data_length = 40  # ! could save some key attributes in LMDB

        if test:
            self.length = self.instance_data_length
        elif dataset_size > 0:
            self.length = dataset_size * self.instance_data_length

        # Explicit encoding: the default would depend on the process locale.
        with open(
                '/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json',
                encoding='utf-8') as f:
            self.caption_data = json.load(f)

        with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json'),
                  encoding='utf-8') as f:
            self.idx_to_ins_mapping = json.load(f)

    def load_bbox(self, mask):
        """Return ``[top, left, bottom, right]`` of the non-zero region of ``mask``.

        ``mask`` must be a 2-D tensor. The last two entries are the maximum
        row and column indices (inclusive), not a height and width. An
        all-zero mask makes the min/max reductions fail on an empty tensor.
        """
        coords = torch.nonzero(mask)
        bottom, right = coords.max(dim=0)[0]
        top, left = coords.min(dim=0)[0]
        return torch.tensor([top, left, bottom, right], dtype=torch.float32)

    def __getitem__(self, idx):
        raw_img, depth, c, bbox = self._load_lmdb_data(idx)
        return {'raw_img': raw_img, 'depth': depth, 'c': c, 'bbox': bbox}
class Objv_LMDBDataset_NV_Compressed(Objv_LMDBDataset_MV_Compressed):
    """Novel-view variant: pairs each sample with a random view of its instance."""

    def __init__(self,
                 lmdb_path,
                 reso,
                 reso_encoder,
                 imgnet_normalize=True,
                 dataset_size=-1,
                 **kwargs):
        super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize,
                         dataset_size, **kwargs)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` plus ``nv_``-prefixed keys of a novel view.

        The novel view is drawn uniformly from the same block of
        ``instance_data_length`` indices, and may coincide with ``idx``.

        Raises:
            NotImplementedError(idx): if loading the novel view fails. The
                original error is chained as ``__cause__``.
        """
        input_view = self._load_data(idx)

        novel_idx = ((idx // self.instance_data_length) *
                     self.instance_data_length +
                     random.randint(0, self.instance_data_length - 1))
        try:
            novel_view = self._load_data(novel_idx)
        except Exception as e:
            # Keep the exception type callers may catch, but chain the cause.
            raise NotImplementedError(idx) from e

        input_view.update({f'nv_{k}': v for k, v in novel_view.items()})
        return input_view
#
# WebDataset (tar shard) loading
def load_wds_ResampledShard(file_path,
                            batch_size,
                            num_workers,
                            reso,
                            reso_encoder,
                            test=False,
                            preprocess=None,
                            imgnet_normalize=True,
                            plucker_embedding=False,
                            decode_encode_img_only=False,
                            load_instance=False,
                            mv_input=False,
                            split_chunk_input=False,
                            duplicate_sample=True,
                            append_depth=False,
                            gs_cam_format=False,
                            orthog_duplicate=False,
                            **kwargs):
    """Build a WebDataset loader that streams Objaverse samples from tar shards.

    Shards are resampled via ``wds.ResampledShards``. Each ``sample.pyd``
    entry is decoded by ``PostProcess.decode_zip``. Decoded entries are then
    turned into input/novel-view pairs, or into per-instance dicts when
    ``load_instance`` is True.

    Args:
        file_path: shard URL/pattern, or a list of them (each expanded with
            ``wds.shardlists.expand_source``).
        batch_size: batch size of the returned loader; must be 1 when
            ``load_instance`` is True.
        num_workers: ``wds.WebLoader`` worker count.
        reso / reso_encoder: output and encoder-input resolutions.
        imgnet_normalize: normalize the encoder input with ImageNet mean/std;
            otherwise mean/std 0.5.
        plucker_embedding / append_depth: append Plücker rays / depth as
            extra encoder-input channels.
        decode_encode_img_only: skip depth decoding and use placeholders.
        load_instance: load whole instances instead of view pairs.
        mv_input: apply ``PostProcess.prepare_mv_input`` to each batch.
        split_chunk_input: each shard entry holds 8 views instead of 40.
        duplicate_sample / orthog_duplicate: how a chunk's views are split
            into input and novel views.
        gs_cam_format: convert cameras to the Gaussian-splatting dict format.
        test, preprocess, **kwargs: accepted but not used in this function.

    Returns:
        the configured ``wds.WebLoader``.
    """
    # return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd']

    class PostProcess:
        """Decoding / post-processing callables mapped over the pipeline."""

        def __init__(
            self,
            reso,
            reso_encoder,
            imgnet_normalize,
            plucker_embedding,
            decode_encode_img_only,
            mv_input,
            split_chunk_input,
            duplicate_sample,
            append_depth,
            gs_cam_format,
            orthog_duplicate,
        ) -> None:
            self.gs_cam_format = gs_cam_format
            self.append_depth = append_depth
            self.plucker_embedding = plucker_embedding
            self.decode_encode_img_only = decode_encode_img_only
            self.duplicate_sample = duplicate_sample
            self.orthog_duplicate = orthog_duplicate

            # Clip planes for the Gaussian-splatting projection matrix.
            self.zfar = 100.0
            self.znear = 0.01

            transformations = []
            # The chunked path feeds tensors already scaled to [0, 1] (see
            # _post_process_sample_batch), so ToTensor is only added for the
            # per-view numpy path.
            if not split_chunk_input:
                transformations.append(transforms.ToTensor())

            if imgnet_normalize:
                transformations.append(
                    transforms.Normalize((0.485, 0.456, 0.406),
                                         (0.229, 0.224, 0.225))  # type: ignore
                )
            else:
                transformations.append(
                    transforms.Normalize((0.5, 0.5, 0.5),
                                         (0.5, 0.5, 0.5)))  # type: ignore

            self.normalize = transforms.Compose(transformations)

            self.reso_encoder = reso_encoder
            self.reso = reso
            self.instance_data_length = 40  # views per instance
            self.mv_input = mv_input
            self.split_chunk_input = split_chunk_input  # 8
            # Number of views per decoded shard entry (see decode_zip).
            self.chunk_size = 8 if split_chunk_input else 40
            if split_chunk_input:
                self.pair_per_instance = 1
            else:
                self.pair_per_instance = 4 if mv_input else 2  # check whether improves IO
def gen_rays(self, c):
    """Compute per-pixel ray origins and unit directions for one camera.

    ``c`` is a flat numpy vector. Entries 0-15 are a row-major 4x4
    camera-to-world matrix. Entries 16-24 are a row-major 3x3 intrinsic
    matrix ``[fx, 0, cx, 0, fy, cy, 0, 0, 1]``, which must be expressed in
    normalized [0, 1] image coordinates to match the pixel grid below.
    Rays pass through the pixel centres of a ``reso_encoder`` x
    ``reso_encoder`` grid, using the OpenCV convention (+z forward).
    Side effect: sets ``self.h`` and ``self.w``.

    Returns:
        (origins, dirs): float32 tensors of shape (h, w, 3).
    """
    self.h = self.w = self.reso_encoder
    K = c[16:]
    cam2world = torch.from_numpy(c[:16].reshape(4, 4)).float()
    fx, fy, cx, cy = K[0], K[4], K[2], K[5]

    # Pixel-centre coordinates normalized to [0, 1].
    rows = (torch.arange(self.h, dtype=torch.float32) + 0.5) / self.h
    cols = (torch.arange(self.w, dtype=torch.float32) + 0.5) / self.w
    v, u = torch.meshgrid(rows, cols, indexing='ij')

    cam_dirs = torch.stack(((u - cx) / fx, (v - cy) / fy, torch.ones_like(u)),
                           dim=-1)
    cam_dirs = cam_dirs / torch.norm(cam_dirs, dim=-1, keepdim=True)

    world_dirs = (cam2world[None, :3, :3] @ cam_dirs.reshape(-1, 3, 1))[..., 0]
    origins = cam2world[None, :3, 3].expand(self.h * self.w, -1).contiguous()
    return (origins.view(self.h, self.w, 3),
            world_dirs.view(self.h, self.w, 3))
def _post_process_batch_sample(
self, sample): # sample is an instance batch here
caption, ins = sample[-2:]
instance_samples = []
for instance_idx in range(sample[0].shape[0]):
instance_samples.append(
self._post_process_sample(item[instance_idx]
for item in sample[:-2]))
return (*instance_samples, caption, ins)
def _post_process_sample(self, data_sample):
    """Turn one view's raw arrays into encoder input, target image and depth.

    Args:
        data_sample: iterable unpacking to ``(raw_img, depth, c, bbox)`` for
            a single view. ``raw_img`` is an HWC image and ``c`` the flat
            numpy camera consumed by ``gen_rays``.

    Returns:
        ``(img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox)``.
        ``img`` is CHW in [-1, 1] at ``reso``. With ``decode_encode_img_only``,
        both depth outputs are the unresized input ``depth``.
    """
    raw_img, depth, c, bbox = data_sample

    # Rescale bbox by reso / 256.
    # NOTE(review): the uint8 cast wraps values above 255.
    bbox = (bbox * (self.reso / 256)).astype(np.uint8)

    encoder_in = raw_img
    if raw_img.shape[-2] != self.reso_encoder:
        encoder_in = cv2.resize(raw_img,
                                (self.reso_encoder, self.reso_encoder),
                                interpolation=cv2.INTER_LANCZOS4)
    encoder_in = self.normalize(encoder_in)

    if self.plucker_embedding:
        origins, directions = self.gen_rays(c)
        plucker = torch.cat(
            [torch.cross(origins, directions, dim=-1), directions],
            dim=-1).permute(2, 0, 1)  # [h, w, 6] -> 6, h, w
        encoder_in = torch.cat([encoder_in, plucker], 0)

    target = cv2.resize(raw_img, (self.reso, self.reso),
                        interpolation=cv2.INTER_LANCZOS4)
    target = torch.from_numpy(target).permute(2, 0, 1) / 127.5 - 1

    if self.decode_encode_img_only:
        depth_out = mask_out = depth
    else:
        depth_out, mask_out = resize_depth_mask(depth, self.reso)

    return (encoder_in, target, mask_out, depth_out, c, bbox)
def _post_process_sample_batch(self, data_sample):
# raw_img, depth, c, bbox, caption, ins = data_sample
raw_img, depth, c, bbox = data_sample
bbox = (bbox * (self.reso / 256)).astype(
np.uint8) # normalize bbox to the reso range
assert raw_img.shape[-2] == self.reso_encoder
# img_to_encoder = cv2.resize(
# raw_img, (self.reso_encoder, self.reso_encoder),
# interpolation=cv2.INTER_LANCZOS4)
# else:
# img_to_encoder = raw_img
raw_img = torch.from_numpy(raw_img).permute(0, 3, 1,
2) / 255.0 # [0,1]
img_to_encoder = self.normalize(raw_img)
if self.plucker_embedding:
rays_plucker = []
for idx in range(c.shape[0]):
rays_o, rays_d = self.gen_rays(c[idx])
rays_plucker.append(
torch.cat(
[torch.cross(rays_o, rays_d, dim=-1), rays_d],
dim=-1).permute(2, 0, 1)) # [h, w, 6] -> 6,h,w
rays_plucker = torch.stack(rays_plucker, 0)
img_to_encoder = torch.cat([img_to_encoder, rays_plucker],
1) # concat in C dim
if self.append_depth:
normalized_depth = torch.from_numpy(depth).clone().unsqueeze(
1) # min=0
# normalized_depth -= torch.min(normalized_depth) # always 0 here
# normalized_depth /= torch.max(normalized_depth)
# normalized_depth = normalized_depth.unsqueeze(1) * 2 - 1 # normalize to [-1,1]
img_to_encoder = torch.cat([img_to_encoder, normalized_depth],
1) # concat in C dim
# img = cv2.resize(raw_img, (self.reso, self.reso),
# interpolation=cv2.INTER_LANCZOS4)
# img = torch.from_numpy(raw_img).permute(2, 0, 1) / 127.5 - 1
# st()
if raw_img.shape[-1] != self.reso:
img = torch.nn.functional.interpolate(
input=raw_img,
size=(self.reso, self.reso),
mode='bilinear',
align_corners=False,
) * 2 - 1 # [-1,1] range
else:
img = raw_img * 2 - 1
if self.decode_encode_img_only:
depth_reso, fg_mask_reso = depth, depth
else:
depth_reso, fg_mask_reso = resize_depth_mask_Tensor(
torch.from_numpy(depth), self.reso)
# if not self.gs_cam_format: # otherwise still playing with np format later
c = torch.from_numpy(c)
return (img_to_encoder, img, fg_mask_reso, depth_reso, c,
torch.from_numpy(bbox))
def rand_sample_idx(self):
    """Return a uniformly random view index in ``[0, instance_data_length)``."""
    return random.randrange(0, self.instance_data_length)
def rand_pair(self):
    """Return a lazy generator of two independent ``rand_sample_idx()`` draws."""
    draws = range(2)
    return (self.rand_sample_idx() for _draw in draws)
def paired_post_process(self, sample):
    """Build ``pair_per_instance`` (input, novel-view) pairs from one instance.

    ``sample`` is ``(*fields, caption, ins)``, each field indexed by view.
    For every pair, two view indices are drawn with ``rand_pair``, and each
    is post-processed with ``_post_process_sample``.

    Returns:
        A flat tuple: all input-view outputs, then all novel-view outputs,
        then ``caption``, then ``ins``.
    """
    caption, ins = sample[-2:]
    fields = sample[:-2]
    inputs, novels = [], []
    for _ in range(self.pair_per_instance):
        cano_idx, nv_idx = self.rand_pair()
        inputs.extend(self._post_process_sample(f[cano_idx] for f in fields))
        novels.extend(self._post_process_sample(f[nv_idx] for f in fields))
    return (*inputs, *novels, caption, ins)
def get_source_cw2wT(self, source_cameras_view_to_world):
    """Return the quaternion of the transposed top-left 3x3 block of a matrix.

    In ``c_to_3dgs_format`` this is applied to the transposed view-to-world
    matrix.
    """
    rotation_block = source_cameras_view_to_world[:3, :3]
    return matrix_to_quaternion(rotation_block.transpose(0, 1))
def c_to_3dgs_format(self, pose):
    """Build the camera dict consumed by the Gaussian-splatting renderer.

    Args:
        pose: flat numpy camera. Entries 0-15 are a row-major 4x4
            camera-to-world matrix. Entry 16 is the focal length, normalized
            to an image size of 1; it is used for both axes.

    Returns:
        dict with keys ``source_cv2wT_quat``, ``cam_view``,
        ``cam_view_proj``, ``cam_pos``, ``tanfov``, ``orig_pose``,
        ``orig_c2w`` and ``orig_w2c``.
    """
    # TODO, switch to torch version (batched later)
    cam2world = pose[:16].reshape(4, 4)
    world2cam = np.linalg.inv(cam2world)
    R = np.transpose(world2cam[:3, :3])  # stored transposed due to 'glm' in CUDA code
    T = world2cam[:3, 3]

    # Only the x focal length is read, so the field of view is isotropic.
    FovX = focal2fov(pose[16], 1)
    FovY = focal2fov(pose[16], 1)
    tanfovx, tanfovy = math.tan(FovX * 0.5), math.tan(FovY * 0.5)
    assert tanfovx == tanfovy

    no_translation = np.array([0.0, 0.0, 0.0])
    unit_scale = 1.0
    view_world_transform = torch.tensor(
        getView2World(R, T, no_translation, unit_scale)).transpose(0, 1)
    world_view_transform = torch.tensor(
        getWorld2View2(R, T, no_translation, unit_scale)).transpose(0, 1)
    projection_matrix = getProjectionMatrix(znear=self.znear,
                                            zfar=self.zfar,
                                            fovX=FovX,
                                            fovY=FovY).transpose(0, 1)
    full_proj_transform = world_view_transform.unsqueeze(0).bmm(
        projection_matrix.unsqueeze(0)).squeeze(0)
    camera_center = world_view_transform.inverse()[3, :3]

    return {
        "source_cv2wT_quat": self.get_source_cw2wT(view_world_transform),
        "cam_view": world_view_transform,
        "cam_view_proj": full_proj_transform,
        "cam_pos": camera_center,
        "tanfov": tanfovx,  # TODO, fix in the renderer
        "orig_pose": torch.from_numpy(pose),
        "orig_c2w": torch.from_numpy(cam2world),
        "orig_w2c": torch.from_numpy(world2cam),
    }
def paired_post_process_chunk(self, sample):
    """Split an 8-view chunk into input and novel-view halves.

    ``sample`` is ``(*fields, caption, ins)``, and every field holds 8 views.
    - ``duplicate_sample`` False: only the first 4 views are processed, and
      the result is used as both the input and the novel half.
    - ``duplicate_sample`` True: all 8 views are processed and shuffled with
      one random permutation, and the first part becomes the input.
      Without ``orthog_duplicate``, views are split 4 / 4. With it, two
      independent 8-permutations are concatenated and split 8 / 8.

    Returns:
        ``(*input_fields, *novel_fields, caption, ins)``.
    """
    caption, ins = sample[-2:]
    assert sample[0].shape[0] == 8  # random chunks
    fields = sample[:-2]

    if not self.duplicate_sample:
        # Avoids the shuffle to shorten processing time; the novel half is a
        # placeholder.
        processed = self._post_process_sample_batch(item[:4] for item in fields)
        return (*processed, *processed, caption, ins)

    processed = self._post_process_sample_batch(item for item in fields)
    if self.orthog_duplicate:
        perm = torch.cat([torch.randperm(8), torch.randperm(8)])
        split = 8
    else:
        perm = torch.randperm(8)
        split = 4

    # One shared permutation keeps the fields of each view aligned.
    shuffled = [torch.index_select(t, dim=0, index=perm) for t in processed]
    inputs = [t[:split] for t in shuffled]
    novels = [t[split:] for t in shuffled]
    return (*inputs, *novels, caption, ins)
def single_sample_create_dict(self, sample, prefix=''):
    """Pack a 6-field sample into a dict whose keys are prefixed by ``prefix``.

    ``sample`` unpacks to
    ``(img_to_encoder, img, depth_mask, depth, c, bbox)``.

    With ``gs_cam_format``, ``c`` (B x V x 25) is converted per camera with
    ``c_to_3dgs_format``. Tensor entries are stacked back to B x V x ...;
    non-tensor entries are taken from the first camera.
    """
    img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample

    if self.gs_cam_format:
        # TODO, can optimize later after model converges
        B, V, _ = c.shape
        flat_cams = rearrange(c, 'B V C -> (B V) C').cpu().numpy()
        per_cam = [self.c_to_3dgs_format(pose) for pose in flat_cams]
        first = per_cam[0]
        c = {}
        for key in first.keys():
            if isinstance(first[key], torch.Tensor):
                stacked = torch.stack([cam[key] for cam in per_cam])
                c[key] = rearrange(stacked, '(B V) ... -> B V ...', B=B, V=V)
            else:
                c[key] = first[key]

    return {
        f'{prefix}img_to_encoder': img_to_encoder,
        f'{prefix}img': img,
        f'{prefix}depth_mask': fg_mask_reso,
        f'{prefix}depth': depth_reso,
        f'{prefix}c': c,
        f'{prefix}bbox': bbox,
    }
def single_instance_sample_create_dict(self, sample, prfix=''):
    """Stack the 40 per-view entries of one instance into a single sample dict.

    ``sample`` must have 42 entries: 40 view entries, then caption and ins.
    For each view entry ``item``, ``item[0]`` is indexed by field (6 fields).
    Each field is ``torch.stack``-ed across the 40 views and packed by
    ``single_sample_create_dict`` with no prefix. ``prfix`` is unused.
    """
    assert len(sample) == 42
    columns = [[item[0][field] for item in sample[:40]] for field in range(6)]
    packed = self.single_sample_create_dict(
        (torch.stack(column) for column in columns), prefix='')
    return {**packed, 'caption': sample[-2], 'ins': sample[-1]}
def decode_zip(self, sample_pyd, shape=(256, 256)):
    """Decode one LZ4-compressed shard entry into arrays and strings.

    Args:
        sample_pyd: the entry's dict, or a tuple whose first element is that
            dict (as produced by ``wds.to_tuple``).
        shape: spatial size of each decoded depth map.

    Returns:
        ``(raw_img, depth, c, bbox, caption, ins)``.
        - ``c``: float32 array of shape (chunk_size, 25).
        - ``bbox``: float32 array of shape (chunk_size, 4).
        - ``depth``: float32 array of shape (chunk_size, *shape).
        With ``decode_encode_img_only``, depth is not decoded and an all-zero
        float32 placeholder is returned instead.
    """
    if isinstance(sample_pyd, tuple):
        sample_pyd = sample_pyd[0]
    assert isinstance(sample_pyd, dict)

    raw_img = decompress_and_open_image_gzip(
        sample_pyd['raw_img'],
        is_img=True,
        decompress=True,
        decompress_fn=lz4.frame.decompress)

    caption = sample_pyd['caption'].decode('utf-8')
    ins = sample_pyd['ins'].decode('utf-8')

    c = decompress_array(sample_pyd['c'], (
        self.chunk_size,
        25,
    ),
                         np.float32,
                         decompress=True,
                         decompress_fn=lz4.frame.decompress)

    bbox = decompress_array(
        sample_pyd['bbox'],
        (
            self.chunk_size,
            4,
        ),
        np.float32,
        decompress=True,
        decompress_fn=lz4.frame.decompress)

    if self.decode_encode_img_only:
        # Skip depth decoding to save loading time. The placeholder is
        # float32 like decoded depth; the previous float64 default upcast
        # tensors it was concatenated with (e.g. via append_depth).
        depth = np.zeros(shape=(self.chunk_size, *shape), dtype=np.float32)
    else:
        depth = decompress_array(sample_pyd['depth'],
                                 (self.chunk_size, *shape),
                                 np.float32,
                                 decompress=True,
                                 decompress_fn=lz4.frame.decompress)

    return raw_img, depth, c, bbox, caption, ins
def create_dict(self, sample):
    """Assemble paired input / novel-view dicts from a flat post-processed tuple.

    ``sample`` has this layout:
    - ``6 * pair_per_instance`` input-view fields;
    - then ``6 * pair_per_instance`` novel-view fields;
    - then caption, then ins.
    Each pair is added twice, once as (input, novel) and once swapped
    (novel, input), so both roles are used. Fields are ``torch.cat``-ed along
    dim 0. Novel-view keys get the ``nv_`` prefix.
    """
    n_pairs = self.pair_per_instance
    nv_offset = 6 * n_pairs
    cano_cols = [[] for _ in range(6)]
    nv_cols = [[] for _ in range(6)]

    for pair in range(n_pairs):
        lo, hi = 6 * pair, 6 * (pair + 1)
        cano = sample[lo:hi]
        nv = sample[nv_offset + lo:nv_offset + hi]
        for field in range(6):
            # ! cycle input/output view for more pairs
            cano_cols[field] += [cano[field], nv[field]]
            nv_cols[field] += [nv[field], cano[field]]

    cano_dict = self.single_sample_create_dict(
        (torch.cat(col, 0) for col in cano_cols), prefix='')
    nv_dict = self.single_sample_create_dict(
        (torch.cat(col, 0) for col in nv_cols), prefix='nv_')
    return {**cano_dict, **nv_dict, 'caption': sample[-2], 'ins': sample[-1]}
def prepare_mv_input(self, sample):
    """Reorder a batch dict in place for multi-view input; returns the same dict.

    - ``split_chunk_input``: every tensor value is flattened from
      (b, f, c, ...) to (b*f, c, ...), with f = 4, or 8 with
      ``orthog_duplicate``. Non-tensor values (e.g. a gs-format ``c`` dict)
      are left untouched.
    - Otherwise: every value except ``ins`` / ``caption`` is reordered along
      dim 0 so that rows ``i, i+bs, i+2*bs, ...`` of each instance ``i`` are
      contiguous. Then every key containing ``'nv'`` is rolled by 4 inside
      each instance's ``chunk_size`` rows.

    NOTE(review): indentation was lost in this copy. The nv-roll loop is
    placed inside the ``else`` branch here (the ``chunk_size`` it uses is
    computed before the chunked reshape); confirm against upstream.
    NOTE(review): a non-tensor value (e.g. gs-format ``c``) reaching
    ``torch.index_select`` in the else branch would fail.
    """
    bs = len(sample['caption'])  # number of instances
    chunk_size = sample['img'].shape[0] // bs

    if self.split_chunk_input:
        for k, v in sample.items():
            if isinstance(v, torch.Tensor):
                sample[k] = rearrange(v,
                                      "b f c ... -> (b f) c ...",
                                      f=4 if not self.orthog_duplicate
                                      else 8).contiguous()
    else:
        for k, v in sample.items():
            if k not in ['ins', 'caption']:
                # Gather rows i, i+bs, i+2*bs, ... for each instance i.
                rolled_idx = torch.LongTensor(
                    list(
                        itertools.chain.from_iterable(
                            list(range(i, sample['img'].shape[0], bs))
                            for i in range(bs))))

                v = torch.index_select(v, dim=0, index=rolled_idx)
                sample[k] = v

        # ! shift nv: roll each instance's novel-view rows by 4.
        for k, v in sample.items():
            if 'nv' in k:
                rolled_idx = torch.LongTensor(
                    list(
                        itertools.chain.from_iterable(
                            list(
                                np.roll(
                                    np.arange(i * chunk_size, (i + 1) *
                                              chunk_size), 4)
                                for i in range(bs)))))

                v = torch.index_select(v, dim=0, index=rolled_idx)
                sample[k] = v

    return sample
post_process_cls = PostProcess(
    reso,
    reso_encoder,
    imgnet_normalize=imgnet_normalize,
    plucker_embedding=plucker_embedding,
    decode_encode_img_only=decode_encode_img_only,
    mv_input=mv_input,
    split_chunk_input=split_chunk_input,
    duplicate_sample=duplicate_sample,
    append_depth=append_depth,
    gs_cam_format=gs_cam_format,
    orthog_duplicate=orthog_duplicate,
)

# ! add shuffling
# Expand a list of shard URL patterns into concrete shard URLs; a single
# string is passed through and expanded by ResampledShards.
if isinstance(file_path, list):  # lst of shard urls
    all_shards = []
    for url_path in file_path:
        all_shards.extend(wds.shardlists.expand_source(url_path))
    logger.log('all_shards', all_shards)
else:
    all_shards = file_path  # to be expanded

if not load_instance:  # during reconstruction training, load pair
    if not split_chunk_input:
        # 40-view entries: random (input, novel) view pairs per instance.
        dataset = wds.DataPipeline(
            wds.ResampledShards(all_shards),  # url_shard
            # at this point we have an iterator over all the shards
            wds.shuffle(50),
            wds.split_by_worker,  # if multi-node
            wds.tarfile_to_samples(),
            # add wds.split_by_node here if you are using multiple nodes
            wds.shuffle(
                1000
            ),  # shuffles in the memory, leverage large RAM for more efficient loading
            wds.decode(wds.autodecode.basichandlers),  # TODO
            wds.to_tuple(
                "sample.pyd"),  # extract the pyd from top level dict
            wds.map(post_process_cls.decode_zip),
            wds.map(post_process_cls.paired_post_process
                    ),  # create input-novelview paired samples
            wds.batched(
                16,
                partial=True,
            )  # streaming more data at once, and rebatch later
        )
    else:
        # 8-view chunk entries: split each chunk into input / novel halves.
        dataset = wds.DataPipeline(
            wds.ResampledShards(all_shards),  # url_shard
            # at this point we have an iterator over all the shards
            wds.shuffle(100),
            wds.split_by_worker,  # if multi-node
            wds.tarfile_to_samples(),
            # add wds.split_by_node here if you are using multiple nodes
            wds.shuffle(
                250,
            ),  # shuffles in the memory, leverage large RAM for more efficient loading
            wds.decode(wds.autodecode.basichandlers),  # TODO
            wds.to_tuple(
                "sample.pyd"),  # extract the pyd from top level dict
            wds.map(post_process_cls.decode_zip),
            wds.map(post_process_cls.paired_post_process_chunk
                    ),  # create input-novelview paired samples
            wds.batched(
                20,
                partial=True,
            )  # streaming more data at once, and rebatch later
        )

    # Workers emit pre-batched groups; unbatch, reshuffle, rebatch to
    # batch_size, then assemble the paired dict.
    loader_shard = wds.WebLoader(
        dataset,
        num_workers=num_workers,
        drop_last=False,
        batch_size=None,
        shuffle=False,
        persistent_workers=num_workers
        > 0).unbatched().shuffle(250).batched(batch_size).map(
            post_process_cls.create_dict)

    if mv_input:
        loader_shard = loader_shard.map(post_process_cls.prepare_mv_input)

else:  # load single instance during test/eval
    assert batch_size == 1

    dataset = wds.DataPipeline(
        wds.ResampledShards(all_shards),  # url_shard
        # at this point we have an iterator over all the shards
        wds.shuffle(50),
        wds.split_by_worker,  # if multi-node
        wds.tarfile_to_samples(),
        # add wds.split_by_node here if you are using multiple nodes
        wds.detshuffle(
            100
        ),  # shuffles in the memory, leverage large RAM for more efficient loading
        wds.decode(wds.autodecode.basichandlers),  # TODO
        wds.to_tuple("sample.pyd"),  # extract the pyd from top level dict
        wds.map(post_process_cls.decode_zip),
        wds.map(post_process_cls._post_process_batch_sample),
        wds.batched(
            2,
            partial=True,
        )  # streaming more data at once, and rebatch later
    )

    loader_shard = wds.WebLoader(
        dataset,
        num_workers=num_workers,
        drop_last=False,
        batch_size=None,
        shuffle=False,
        persistent_workers=num_workers
        > 0).unbatched().shuffle(200).batched(batch_size).map(
            post_process_cls.single_instance_sample_create_dict)

    # # https://github.com/webdataset/webdataset/issues/187

return loader_shard
# WebDataset (tar shard) loading, `diff` variant
def load_wds_diff_ResampledShard(file_path,
                                 batch_size,
                                 num_workers,
                                 reso,
                                 reso_encoder,
                                 test=False,
                                 preprocess=None,
                                 imgnet_normalize=True,
                                 plucker_embedding=False,
                                 decode_encode_img_only=False,
                                 mv_latent_dir='',
                                 **kwargs):
    """Build a WebDataset loader yielding images paired with cached latents.

    Every tar sample's ``sample.pyd`` entry is decoded into
    ``(raw_img, c, caption, ins)``. The instance's latent is then read from
    ``<mv_latent_dir>/<ins>/latent.npy``. Samples are streamed in chunks of
    80, reshuffled in a 2500-element buffer, rebatched to ``batch_size`` and
    turned into dicts with keys ``latent``, ``img``, ``c``, ``caption`` and
    ``ins``.

    Args:
        file_path: a shard spec (handed to ``wds.ResampledShards``) or a list
            of specs, each expanded with ``wds.shardlists.expand_source``.
        batch_size: size of the final batches.
        num_workers: WebLoader worker count; workers persist when > 0.
        reso, reso_encoder: output / encoder resolutions kept on the
            per-sample processor.
        imgnet_normalize: ImageNet mean/std for ``normalize``; otherwise
            0.5/0.5.
        plucker_embedding, decode_encode_img_only: stored on the processor.
        mv_latent_dir: root directory of the per-instance ``latent.npy``
            files.
        test, preprocess, **kwargs: accepted for call compatibility, unused.

    Returns:
        The configured ``wds.WebLoader`` pipeline.
    """

    class PostProcess:
        """Per-sample decoding and tensor conversion for this loader."""

        def __init__(
            self,
            reso,
            reso_encoder,
            imgnet_normalize,
            plucker_embedding,
            decode_encode_img_only,
            mv_latent_dir,
        ) -> None:
            self.plucker_embedding = plucker_embedding
            self.mv_latent_dir = mv_latent_dir
            self.decode_encode_img_only = decode_encode_img_only

            # ToTensor gives [0, 1]; then either ImageNet stats or a plain
            # [-1, 1] mapping.
            if imgnet_normalize:
                mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
            else:
                mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
            self.normalize = transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize(mean, std)])  # type: ignore

            self.reso_encoder = reso_encoder
            self.reso = reso
            # Index range used by rand_sample_idx.
            self.instance_data_length = 40
            # Pairs drawn per instance in paired_post_process.
            self.pair_per_instance = 2

        def get_rays_kiui(self, c, opengl=True):
            """Per-pixel rays on the encoder grid from a flat camera ``c``.

            ``c[:16]`` is read as a 4x4 pose; ``pose.device`` is accessed, so
            ``c`` is expected to be a torch tensor. The focal length is the
            hard-coded rendering default (525 px at a 512 px image) rescaled
            by ``self.reso / 512``.
            """
            h, w = self.reso_encoder, self.reso_encoder
            pose = c[:16].reshape(4, 4)

            scale = self.reso / (256 * 2)  # 256 = default principal point
            fx = 525 * scale
            fy = 525 * scale

            grid_x, grid_y = torch.meshgrid(
                torch.arange(w, device=pose.device),
                torch.arange(h, device=pose.device),
                indexing="xy",
            )
            grid_x = grid_x.flatten()
            grid_y = grid_y.flatten()

            half_w = w * 0.5
            half_h = h * 0.5
            sign = -1.0 if opengl else 1.0  # OpenGL looks down -z, y up
            planar = torch.stack(
                [
                    (grid_x - half_w + 0.5) / fx,
                    (grid_y - half_h + 0.5) / fy * sign,
                ],
                dim=-1,
            )
            camera_dirs = F.pad(planar, (0, 1), value=sign)  # [hw, 3]

            rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1)  # [hw, 3]
            rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)  # [hw, 3]
            return (rays_o.view(h, w, 3),
                    safe_normalize(rays_d).view(h, w, 3))

        def gen_rays(self, c):
            """Ray origins/directions (OpenCV convention) on the encoder grid.

            ``c`` is a flat array whose first 16 entries form the 4x4
            camera-to-world matrix (converted with ``torch.from_numpy``, so a
            numpy array is expected). Entries 16: are read as a row-major K
            (fx=[0], cx=[2], fy=[4], cy=[5]). Pixel centres are normalised to
            [0, 1] before K is applied. Sets ``self.h``/``self.w``.
            """
            intrinsics, pose = c[16:], c[:16].reshape(4, 4)
            self.h = self.reso_encoder
            self.w = self.reso_encoder

            yy, xx = torch.meshgrid(
                torch.arange(self.h, dtype=torch.float32) + 0.5,
                torch.arange(self.w, dtype=torch.float32) + 0.5,
                indexing='ij')
            yy = yy / self.h
            xx = xx / self.w

            fx, fy = intrinsics[0], intrinsics[4]
            cx, cy = intrinsics[2], intrinsics[5]
            c2w = torch.from_numpy(pose).float()

            xx = (xx - cx) / fx
            yy = (yy - cy) / fy
            dirs = torch.stack((xx, yy, torch.ones_like(xx)), dim=-1)
            dirs = dirs / torch.norm(dirs, dim=-1, keepdim=True)
            dirs = (c2w[None, :3, :3] @ dirs.reshape(-1, 3, 1))[..., 0]

            origins = c2w[None, :3, 3].expand(self.h * self.w,
                                              -1).contiguous()
            return (origins.view(self.h, self.w, 3),
                    dirs.view(self.h, self.w, 3))

        def _post_process_sample(self, data_sample):
            """Turn ``(raw_img, c, caption, ins)`` into
            ``(latent, img, c, caption, ins)``.

            ``img`` is ``raw_img`` (no resize) permuted to CHW and mapped with
            ``x / 127.5 - 1``, i.e. to [-1, 1] assuming 0-255 pixel values.
            """
            raw_img, c, caption, ins = data_sample
            img = torch.from_numpy(raw_img).permute(2, 0, 1) / 127.5 - 1
            latent = np.load(Path(self.mv_latent_dir, ins, 'latent.npy'))
            return (latent, img, c, caption, ins)

        def rand_sample_idx(self):
            """Uniform random index in ``[0, instance_data_length)``."""
            return random.randint(0, self.instance_data_length - 1)

        def rand_pair(self):
            """Lazily draw two independent random indices."""
            return (self.rand_sample_idx() for _slot in range(2))

        def paired_post_process(self, sample):
            """Draw ``pair_per_instance`` (input, novel-view) pairs from one sample.

            Every field of ``sample[:-2]`` is indexed per view; the last two
            entries are ``caption, ins``. Not wired into the pipeline below.
            """
            caption, ins = sample[-2:]
            per_view_fields = sample[:-2]
            inputs, novel_views = [], []
            for _ in range(self.pair_per_instance):
                cano_idx, nv_idx = self.rand_pair()
                cano = self._post_process_sample(
                    field[cano_idx] for field in per_view_fields)
                novel = self._post_process_sample(
                    field[nv_idx] for field in per_view_fields)
                inputs.extend(cano)
                novel_views.extend(novel)
            return (*inputs, *novel_views, caption, ins)

        def single_sample_create_dict(self, sample, prefix=''):
            """Name the five fields of a (batched) sample; ``prefix`` is unused."""
            latent, img, c, caption, ins = sample
            return dict(latent=latent, img=img, c=c, caption=caption, ins=ins)

        def decode_zip(self, sample_pyd, shape=(256, 256)):
            """Decode one ``sample.pyd`` dict into ``(raw_img, c, caption, ins)``.

            ``raw_img`` and ``c`` are lz4-compressed; ``c`` is a float32
            array of 25 values. ``shape`` is unused.
            """
            if isinstance(sample_pyd, tuple):
                sample_pyd = sample_pyd[0]  # unwrap wds.to_tuple output
            assert isinstance(sample_pyd, dict)

            raw_img = decompress_and_open_image_gzip(
                sample_pyd['raw_img'],
                is_img=True,
                decompress=True,
                decompress_fn=lz4.frame.decompress)
            caption = sample_pyd['caption'].decode('utf-8')
            ins = sample_pyd['ins'].decode('utf-8')
            c = decompress_array(sample_pyd['c'], (25, ),
                                 np.float32,
                                 decompress=True,
                                 decompress_fn=lz4.frame.decompress)
            return raw_img, c, caption, ins

        def create_dict(self, sample):
            """Final map stage: batched tuple -> dict."""
            return self.single_sample_create_dict(sample, prefix='')

    post_process_cls = PostProcess(
        reso,
        reso_encoder,
        imgnet_normalize=imgnet_normalize,
        plucker_embedding=plucker_embedding,
        decode_encode_img_only=decode_encode_img_only,
        mv_latent_dir=mv_latent_dir,
    )

    if not isinstance(file_path, list):
        all_shards = file_path  # brace-spec string, expanded by wds
    else:
        all_shards = [
            shard for url_path in file_path
            for shard in wds.shardlists.expand_source(url_path)
        ]
        logger.log('all_shards', all_shards)

    stages = [
        wds.ResampledShards(all_shards),  # endless iterator over shard urls
        wds.shuffle(50),
        wds.split_by_worker,  # split shards across DataLoader workers
        wds.tarfile_to_samples(),
        wds.detshuffle(15000),  # large in-memory shuffle buffer
        wds.decode(wds.autodecode.basichandlers),
        wds.to_tuple("sample.pyd"),
        wds.map(post_process_cls.decode_zip),
        wds.map(post_process_cls._post_process_sample),
        wds.batched(80, partial=True),  # stream in chunks, rebatch below
    ]
    dataset = wds.DataPipeline(*stages)

    loader = wds.WebLoader(
        dataset,
        num_workers=num_workers,
        drop_last=False,
        batch_size=None,
        shuffle=False,
        persistent_workers=num_workers > 0,
    )
    # https://github.com/webdataset/webdataset/issues/187
    rebatched = loader.unbatched().shuffle(2500).batched(batch_size)
    return rebatched.map(post_process_cls.create_dict)
def load_wds_data(
        file_path="",
        reso=64,
        reso_encoder=224,
        batch_size=1,
        num_workers=6,
        plucker_embedding=False,
        decode_encode_img_only=False,
        load_wds_diff=False,
        load_wds_latent=False,
        load_instance=False,  # for evaluation
        mv_input=False,
        split_chunk_input=False,
        duplicate_sample=True,
        mv_latent_dir='',
        append_depth=False,
        gs_cam_format=False,
        orthog_duplicate=False,
        **args):
    """Endlessly yield batches from one of this module's WebDataset loaders.

    Dispatch:
      * ``load_wds_diff``   -> ``load_wds_diff_ResampledShard``
        (asserts ``num_workers == 0``);
      * ``load_wds_latent`` -> ``load_wds_latent_ResampledShard``;
      * otherwise           -> ``load_wds_ResampledShard``.

    This is a generator function. The loader is constructed (and the
    assertion checked) only on the first ``next()``. Iteration restarts the
    loader whenever it is exhausted. Extra ``**args`` are ignored.
    """
    # Keyword arguments forwarded to every loader variant.
    common = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        reso=reso,
        reso_encoder=reso_encoder,
        plucker_embedding=plucker_embedding,
        decode_encode_img_only=decode_encode_img_only,
        mv_input=mv_input,
        split_chunk_input=split_chunk_input,
    )

    if load_wds_diff:
        # on aliyun, worker=0 performs much much faster
        assert num_workers == 0
        loader = load_wds_diff_ResampledShard(
            file_path,
            **common,
            append_depth=append_depth,
            mv_latent_dir=mv_latent_dir,
            gs_cam_format=gs_cam_format,
            orthog_duplicate=orthog_duplicate,
        )
    elif load_wds_latent:
        # diffusion training on cached latents
        loader = load_wds_latent_ResampledShard(file_path, **common)
    else:
        loader = load_wds_ResampledShard(
            file_path,
            **common,
            load_instance=load_instance,
            duplicate_sample=duplicate_sample,
            append_depth=append_depth,
            gs_cam_format=gs_cam_format,
            orthog_duplicate=orthog_duplicate,
        )

    while True:
        yield from loader
class PostProcess_forlatent:
    """Per-sample helpers for the cached-latent WebDataset loader.

    ``load_wds_latent_ResampledShard`` wires only ``decode_zip`` and
    ``create_dict`` into its pipeline. The image / ray helpers are not used
    there.
    """

    def __init__(
        self,
        reso,
        reso_encoder,
        imgnet_normalize,
        plucker_embedding,
        decode_encode_img_only,
    ) -> None:
        """
        Args:
            reso: side length the output ``img`` is resized to.
            reso_encoder: side length of the encoder input and the ray grid.
            imgnet_normalize: normalize with ImageNet mean/std; otherwise use
                mean = std = 0.5 (maps [0, 1] to [-1, 1]).
            plucker_embedding: append 6 Plucker ray channels to the encoder
                input in ``_post_process_sample``.
            decode_encode_img_only: stored as-is; not read by this class.
        """
        self.plucker_embedding = plucker_embedding
        self.decode_encode_img_only = decode_encode_img_only
        transformations = [
            transforms.ToTensor(),  # [0,1] range
        ]
        if imgnet_normalize:
            transformations.append(
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))  # type: ignore
            )
        else:
            transformations.append(
                transforms.Normalize((0.5, 0.5, 0.5),
                                     (0.5, 0.5, 0.5)))  # type: ignore
        self.normalize = transforms.Compose(transformations)
        self.reso_encoder = reso_encoder
        self.reso = reso
        # Index range used by rand_sample_idx.
        self.instance_data_length = 40
        # Pairs drawn per instance in paired_post_process.
        self.pair_per_instance = 2  # check whether improves IO

    def gen_rays(self, c):
        """Per-pixel ray origins and unit directions on the encoder grid.

        Mirrors ``gen_rays`` of the diff loader's ``PostProcess``, which this
        class previously lacked. Without it, ``_post_process_sample`` raised
        AttributeError when ``plucker_embedding`` was enabled.

        Args:
            c: flat camera vector (numpy array or torch tensor). ``c[:16]`` is
                the row-major 4x4 camera-to-world matrix. ``c[16:]`` is read as
                a row-major K with fx=[0], cx=[2], fy=[4], cy=[5]. Because pixel
                centres are normalised to [0, 1] below, K is presumably in
                normalised image units -- confirm against the data writer.

        Returns:
            (origins, dirs): float32 tensors of shape
            (reso_encoder, reso_encoder, 3). Directions use the OpenCV
            convention (+z forward) and are unit length.
        """
        h = w = self.reso_encoder
        c2w = torch.as_tensor(c[:16], dtype=torch.float32).reshape(4, 4)
        intrinsics = torch.as_tensor(c[16:], dtype=torch.float32)
        fx, fy = intrinsics[0], intrinsics[4]
        cx, cy = intrinsics[2], intrinsics[5]

        # Pixel centres normalised to the [0, 1] range.
        yy, xx = torch.meshgrid(
            torch.arange(h, dtype=torch.float32) + 0.5,
            torch.arange(w, dtype=torch.float32) + 0.5,
            indexing='ij')
        xx = (xx / w - cx) / fx
        yy = (yy / h - cy) / fy

        dirs = torch.stack((xx, yy, torch.ones_like(xx)), dim=-1)
        dirs = dirs / torch.norm(dirs, dim=-1, keepdim=True)
        # Rotate camera-space directions into world space.
        dirs = (c2w[None, :3, :3] @ dirs.reshape(-1, 3, 1))[..., 0]
        origins = c2w[None, :3, 3].expand(h * w, -1).contiguous()
        return origins.view(h, w, 3), dirs.view(h, w, 3)

    def _resize_square(self, image, size):
        """Lanczos-resize an HWC image to ``size x size``.

        Returns ``image`` unchanged if it is already that size; OpenCV would
        only copy it in that case.
        """
        if tuple(image.shape[:2]) == (size, size):
            return image
        return cv2.resize(image, (size, size),
                          interpolation=cv2.INTER_LANCZOS4)

    def _post_process_sample(self, data_sample):
        """Build encoder input and target image from ``(raw_img, c, caption, ins)``.

        ``raw_img`` is presumably an HWC uint8 array (it is fed to ``cv2`` and
        scaled with ``/ 127.5 - 1``) -- confirm with the producer.

        Returns:
            (img_to_encoder, img, c, caption, ins). ``img_to_encoder`` is the
            normalized encoder-resolution image, with 6 Plucker channels
            (cross(o, d), d) appended when ``plucker_embedding`` is set.
            ``img`` is the ``reso``-sized image as CHW in [-1, 1].
        """
        raw_img, c, caption, ins = data_sample
        img_to_encoder = self.normalize(
            self._resize_square(raw_img, self.reso_encoder))
        if self.plucker_embedding:
            rays_o, rays_d = self.gen_rays(c)
            rays_plucker = torch.cat(
                [torch.cross(rays_o, rays_d, dim=-1), rays_d],
                dim=-1).permute(2, 0, 1)  # [h, w, 6] -> 6,h,w
            img_to_encoder = torch.cat([img_to_encoder, rays_plucker], 0)
        img = self._resize_square(raw_img, self.reso)
        img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1
        return (img_to_encoder, img, c, caption, ins)

    def rand_sample_idx(self):
        """Uniform random index in ``[0, instance_data_length)``."""
        return random.randint(0, self.instance_data_length - 1)

    def rand_pair(self):
        """Lazily draw two independent random indices."""
        return (self.rand_sample_idx() for _ in range(2))

    def paired_post_process(self, sample):
        """Draw ``pair_per_instance`` (input, novel-view) pairs from one sample.

        Every field of ``sample[:-2]`` is indexed per view; the last two
        entries are ``caption, ins``.

        NOTE(review): this class's ``decode_zip`` returns
        ``(latent, caption, c)``, not the per-view layout assumed here, and
        this method is not wired into any pipeline in this section.
        """
        all_inp_list = []
        all_nv_list = []
        caption, ins = sample[-2:]
        for _ in range(self.pair_per_instance):
            cano_idx, nv_idx = self.rand_pair()
            cano_sample = self._post_process_sample(
                item[cano_idx] for item in sample[:-2])
            nv_sample = self._post_process_sample(item[nv_idx]
                                                  for item in sample[:-2])
            all_inp_list.extend(cano_sample)
            all_nv_list.extend(nv_sample)
        return (*all_inp_list, *all_nv_list, caption, ins)

    def single_sample_create_dict(self, sample, prefix=''):
        """Name the fields of an ``_post_process_sample`` result.

        ``prefix`` is accepted for compatibility and unused.
        """
        img_to_encoder, img, c, caption, ins = sample
        return {
            'img_to_encoder': img_to_encoder,
            'img': img,
            'c': c,
            'caption': caption,
            'ins': ins
        }

    def decode_zip(self, sample_pyd, shape=(256, 256)):
        """Extract ``(latent, caption, c)`` from one ``sample.pyd`` dict.

        ``caption`` is decoded from UTF-8 bytes; ``latent`` and ``c`` are
        returned as stored. ``shape`` is unused.
        """
        if isinstance(sample_pyd, tuple):
            sample_pyd = sample_pyd[0]  # unwrap wds.to_tuple output
        assert isinstance(sample_pyd, dict)
        latent = sample_pyd['latent']
        caption = sample_pyd['caption'].decode('utf-8')
        c = sample_pyd['c']
        return latent, caption, c

    def create_dict(self, sample):
        """Final map stage: batched ``(latent, caption, c)`` -> dict."""
        return {
            'latent': sample[0],
            'caption': sample[1],
            'c': sample[2],
        }
# WebDataset loader for tar shards that store pre-computed latents
def load_wds_latent_ResampledShard(file_path,
                                   batch_size,
                                   num_workers,
                                   reso,
                                   reso_encoder,
                                   test=False,
                                   preprocess=None,
                                   imgnet_normalize=True,
                                   plucker_embedding=False,
                                   decode_encode_img_only=False,
                                   **kwargs):
    """Build a WebDataset loader over shards that already contain latents.

    Each sample's ``sample.pyd`` dict is reduced by
    ``PostProcess_forlatent.decode_zip`` to ``(latent, caption, c)``.
    Samples are streamed in chunks of 150, reshuffled in a 1000-element
    buffer, rebatched to ``batch_size`` and mapped to dicts with keys
    ``latent``, ``caption`` and ``c``.

    Args:
        file_path: a shard spec (handed to ``wds.ResampledShards``) or a list
            of specs, each expanded with ``wds.shardlists.expand_source``.
        batch_size: size of the final batches.
        num_workers: WebLoader worker count; workers persist when > 0.
        reso, reso_encoder, imgnet_normalize, plucker_embedding,
        decode_encode_img_only: forwarded to ``PostProcess_forlatent``.
        test, preprocess, **kwargs: accepted for call compatibility, unused.

    Returns:
        The configured ``wds.WebLoader`` pipeline.
    """
    processor = PostProcess_forlatent(
        reso,
        reso_encoder,
        imgnet_normalize=imgnet_normalize,
        plucker_embedding=plucker_embedding,
        decode_encode_img_only=decode_encode_img_only,
    )

    if not isinstance(file_path, list):
        all_shards = file_path  # brace-spec string, expanded by wds
    else:
        all_shards = [
            shard for url_path in file_path
            for shard in wds.shardlists.expand_source(url_path)
        ]
        logger.log('all_shards', all_shards)

    stages = [
        wds.ResampledShards(all_shards),  # endless iterator over shard urls
        wds.shuffle(50),
        wds.split_by_worker,  # split shards across DataLoader workers
        wds.tarfile_to_samples(),
        wds.detshuffle(2500),  # in-memory shuffle buffer
        wds.decode(wds.autodecode.basichandlers),
        wds.to_tuple("sample.pyd"),
        wds.map(processor.decode_zip),
        wds.batched(150, partial=True),  # stream in chunks, rebatch below
    ]
    dataset = wds.DataPipeline(*stages)

    loader = wds.WebLoader(
        dataset,
        num_workers=num_workers,
        drop_last=False,
        batch_size=None,
        shuffle=False,
        persistent_workers=num_workers > 0,
    )
    # https://github.com/webdataset/webdataset/issues/187
    rebatched = loader.unbatched().shuffle(1000).batched(batch_size)
    return rebatched.map(processor.create_dict)