import os
import kiui
from kiui.op import recenter
import collections
import math
import time
import itertools
import pickle
from typing import Any
import lmdb
import cv2
import imageio
import numpy as np
from PIL import Image
import Imath
import OpenEXR
from pdb import set_trace as st
from pathlib import Path
import torchvision
from nsr.camera_utils import generate_input_camera
from einops import rearrange, repeat
from functools import partial
import io
import gzip
import random
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch.utils.data.distributed import DistributedSampler
import lz4.frame
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
from utils.general_utils import PILtoTorch, matrix_to_quaternion
from guided_diffusion import logger
import json
import webdataset as wds
from .shapenet import LMDBDataset, LMDBDataset_MV_Compressed, decompress_and_open_image_gzip, decompress_array
from kiui.op import safe_normalize
from utils.gs_utils.graphics_utils import getWorld2View2, getProjectionMatrix, getView2World
def fov2focal(fov, pixels):
    return pixels / (2 * math.tan(fov / 2))
def focal2fov(focal, pixels):
    return 2 * math.atan(pixels / (2 * focal))
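# Sanity-check example for the two helpers above (illustrative values only): with
# fov = math.radians(30) and a 512-pixel image, fov2focal gives 512 / (2 * tan(15 deg))
# ~= 955.4 px, and focal2fov(955.4, 512) recovers ~0.5236 rad (30 deg); fov is in radians.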
def resize_depth_mask(depth_to_resize, resolution):
    depth_resized = cv2.resize(depth_to_resize, (resolution, resolution),
                               interpolation=cv2.INTER_LANCZOS4)
    # interpolation=cv2.INTER_AREA)
    return depth_resized, depth_resized > 0  # type: ignore
def resize_depth_mask_Tensor(depth_to_resize, resolution):
    if depth_to_resize.shape[-1] != resolution:
        depth_resized = torch.nn.functional.interpolate(
            input=depth_to_resize.unsqueeze(1),
            size=(resolution, resolution),
            # mode='bilinear',
            mode='nearest',  # align_corners must stay unset for 'nearest' mode
        ).squeeze(1)
    else:
        depth_resized = depth_to_resize
    return depth_resized, depth_resized > 0  # type: ignore
def load_dataset( | |
file_path="", | |
reso=64, | |
reso_encoder=224, | |
batch_size=1, | |
# shuffle=True, | |
num_workers=6, | |
load_depth=False, | |
preprocess=None, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
trainer_name='input_rec', | |
use_lmdb=False, | |
use_wds=False, | |
use_lmdb_compressed=False, | |
infi_sampler=True): | |
# st() | |
# dataset_cls = { | |
# 'input_rec': MultiViewDataset, | |
# 'nv': NovelViewDataset, | |
# }[trainer_name] | |
# st() | |
if use_wds: | |
return load_wds_data(file_path, reso, reso_encoder, batch_size, | |
num_workers) | |
if use_lmdb: | |
logger.log('using LMDB dataset') | |
# dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later. | |
if use_lmdb_compressed: | |
if 'nv' in trainer_name: | |
dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
else: | |
dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
else: | |
if 'nv' in trainer_name: | |
dataset_cls = Objv_LMDBDataset_NV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
else: | |
dataset_cls = Objv_LMDBDataset_MV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
# dataset = dataset_cls(file_path) | |
else: | |
if 'nv' in trainer_name: | |
dataset_cls = NovelViewObjverseDataset | |
else: | |
dataset_cls = MultiViewObjverseDataset # 1.5-2iter/s | |
dataset = dataset_cls(file_path, | |
reso, | |
reso_encoder, | |
test=False, | |
preprocess=preprocess, | |
load_depth=load_depth, | |
imgnet_normalize=imgnet_normalize, | |
dataset_size=dataset_size) | |
logger.log('dataset_cls: {}, dataset size: {}'.format( | |
trainer_name, len(dataset))) | |
loader = DataLoader(dataset, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
drop_last=False, | |
pin_memory=True, | |
persistent_workers=num_workers > 0, | |
shuffle=False) | |
return loader | |
def load_data( | |
file_path="", | |
reso=64, | |
reso_encoder=224, | |
batch_size=1, | |
# shuffle=True, | |
num_workers=6, | |
load_depth=False, | |
preprocess=None, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
trainer_name='input_rec', | |
use_lmdb=False, | |
use_wds=False, | |
use_lmdb_compressed=False, | |
plucker_embedding=False, | |
infi_sampler=True): | |
if use_lmdb: | |
logger.log('using LMDB dataset') | |
# dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later. | |
if use_lmdb_compressed: | |
if 'nv' in trainer_name: | |
dataset_cls = Objv_LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
else: | |
dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
else: | |
if 'nv' in trainer_name: | |
dataset_cls = Objv_LMDBDataset_NV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
else: | |
dataset_cls = Objv_LMDBDataset_MV_NoCompressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
else: | |
if 'nv' in trainer_name: | |
dataset_cls = NovelViewObjverseDataset # 1.5-2iter/s | |
else: | |
dataset_cls = MultiViewObjverseDataset | |
dataset = dataset_cls(file_path, | |
reso, | |
reso_encoder, | |
test=False, | |
preprocess=preprocess, | |
load_depth=load_depth, | |
imgnet_normalize=imgnet_normalize, | |
dataset_size=dataset_size, | |
plucker_embedding=plucker_embedding) | |
logger.log('dataset_cls: {}, dataset size: {}'.format( | |
trainer_name, len(dataset))) | |
# st() | |
if infi_sampler: | |
train_sampler = DistributedSampler(dataset=dataset, | |
shuffle=True, | |
drop_last=True) | |
loader = DataLoader(dataset, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
drop_last=True, | |
pin_memory=True, | |
persistent_workers=num_workers > 0, | |
sampler=train_sampler) | |
while True: | |
yield from loader | |
# else: | |
# # loader = DataLoader(dataset, | |
# # batch_size=batch_size, | |
# # num_workers=num_workers, | |
# # drop_last=False, | |
# # pin_memory=True, | |
# # persistent_workers=num_workers > 0, | |
# # shuffle=False) | |
# st() | |
# return dataset | |
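# A hedged usage sketch for load_data above (path and argument values are placeholders;
# DistributedSampler assumes torch.distributed has already been initialized):
#   data = load_data(file_path='/path/to/gobjaverse', reso=128, reso_encoder=224,
#                    batch_size=4, trainer_name='nv_rec', load_depth=True)
#   batch = next(data)  # with infi_sampler=True, load_data yields batches indefinitely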
def load_eval_data( | |
file_path="", | |
reso=64, | |
reso_encoder=224, | |
batch_size=1, | |
num_workers=1, | |
load_depth=False, | |
preprocess=None, | |
imgnet_normalize=True, | |
interval=1, | |
use_lmdb=False, | |
plucker_embedding=False, | |
load_real=False, | |
four_view_for_latent=False, | |
shuffle_across_cls=False, | |
load_extra_36_view=False, | |
gs_cam_format=False, | |
single_view_for_i23d=False, | |
**kwargs, | |
): | |
if use_lmdb: | |
logger.log('using LMDB dataset') | |
dataset_cls = Objv_LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
dataset = dataset_cls(file_path, | |
reso, | |
reso_encoder, | |
test=True, | |
preprocess=preprocess, | |
load_depth=load_depth, | |
imgnet_normalize=imgnet_normalize, | |
interval=interval) | |
elif load_real: | |
dataset = RealDataset(file_path, | |
# dataset = RealMVDataset(file_path, | |
reso, | |
reso_encoder, | |
preprocess=preprocess, | |
load_depth=load_depth, | |
test=True, | |
imgnet_normalize=imgnet_normalize, | |
interval=interval, | |
plucker_embedding=plucker_embedding) | |
else: | |
dataset = MultiViewObjverseDataset( | |
file_path, | |
reso, | |
reso_encoder, | |
preprocess=preprocess, | |
load_depth=load_depth, | |
test=True, | |
imgnet_normalize=imgnet_normalize, | |
interval=interval, | |
plucker_embedding=plucker_embedding, | |
four_view_for_latent=four_view_for_latent, | |
load_extra_36_view=load_extra_36_view, | |
shuffle_across_cls=shuffle_across_cls, | |
gs_cam_format=gs_cam_format, | |
single_view_for_i23d=single_view_for_i23d, | |
) | |
print('eval dataset size: {}'.format(len(dataset))) | |
# train_sampler = DistributedSampler(dataset=dataset) | |
loader = DataLoader( | |
dataset, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
drop_last=False, | |
shuffle=False, | |
) | |
# sampler=train_sampler) | |
return loader | |
def load_data_for_lmdb( | |
file_path="", | |
reso=64, | |
reso_encoder=224, | |
batch_size=1, | |
# shuffle=True, | |
num_workers=6, | |
load_depth=False, | |
preprocess=None, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
trainer_name='input_rec', | |
shuffle_across_cls=False, | |
four_view_for_latent=False, | |
wds_split=1): | |
# st() | |
# dataset_cls = { | |
# 'input_rec': MultiViewDataset, | |
# 'nv': NovelViewDataset, | |
# }[trainer_name] | |
# if 'nv' in trainer_name: | |
# dataset_cls = NovelViewDataset | |
# else: | |
# dataset_cls = MultiViewDataset | |
# st() | |
dataset_cls = MultiViewObjverseDatasetforLMDB | |
dataset = dataset_cls(file_path, | |
reso, | |
reso_encoder, | |
test=False, | |
preprocess=preprocess, | |
load_depth=load_depth, | |
imgnet_normalize=imgnet_normalize, | |
dataset_size=dataset_size, | |
shuffle_across_cls=shuffle_across_cls, | |
wds_split=wds_split, | |
four_view_for_latent=four_view_for_latent) | |
logger.log('dataset_cls: {}, dataset size: {}'.format( | |
trainer_name, len(dataset))) | |
# train_sampler = DistributedSampler(dataset=dataset, shuffle=True, drop_last=True) | |
loader = DataLoader( | |
dataset, | |
shuffle=False, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
drop_last=False, | |
# prefetch_factor=2, | |
# prefetch_factor=3, | |
pin_memory=True, | |
persistent_workers=num_workers > 0, | |
) | |
# sampler=train_sampler) | |
# while True: | |
# yield from loader | |
return loader, dataset.dataset_name, len(dataset) | |
def load_lmdb_for_lmdb( | |
file_path="", | |
reso=64, | |
reso_encoder=224, | |
batch_size=1, | |
# shuffle=True, | |
num_workers=6, | |
load_depth=False, | |
preprocess=None, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
trainer_name='input_rec'): | |
# st() | |
# dataset_cls = { | |
# 'input_rec': MultiViewDataset, | |
# 'nv': NovelViewDataset, | |
# }[trainer_name] | |
# if 'nv' in trainer_name: | |
# dataset_cls = NovelViewDataset | |
# else: | |
# dataset_cls = MultiViewDataset | |
# st() | |
dataset_cls = Objv_LMDBDataset_MV_Compressed_for_lmdb | |
dataset = dataset_cls(file_path, | |
reso, | |
reso_encoder, | |
test=False, | |
preprocess=preprocess, | |
load_depth=load_depth, | |
imgnet_normalize=imgnet_normalize, | |
dataset_size=dataset_size) | |
logger.log('dataset_cls: {}, dataset size: {}'.format( | |
trainer_name, len(dataset))) | |
# train_sampler = DistributedSampler(dataset=dataset, shuffle=True, drop_last=True) | |
loader = DataLoader( | |
dataset, | |
shuffle=False, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
drop_last=False, | |
prefetch_factor=2, | |
# prefetch_factor=3, | |
pin_memory=True, | |
persistent_workers=True, | |
) | |
# sampler=train_sampler) | |
# while True: | |
# yield from loader | |
return loader, len(dataset) | |
def load_memory_data( | |
file_path="", | |
reso=64, | |
reso_encoder=224, | |
batch_size=1, | |
num_workers=1, | |
# load_depth=True, | |
preprocess=None, | |
imgnet_normalize=True, | |
**kwargs): | |
# load a single-instance into the memory to speed up training IO | |
# dataset = MultiViewObjverseDataset(file_path, | |
dataset = NovelViewObjverseDataset(file_path, | |
reso, | |
reso_encoder, | |
preprocess=preprocess, | |
load_depth=True, | |
test=False, | |
overfitting=True, | |
imgnet_normalize=imgnet_normalize, | |
overfitting_bs=batch_size, | |
**kwargs) | |
logger.log('!!!!!!! memory dataset size: {} !!!!!!'.format(len(dataset))) | |
# train_sampler = DistributedSampler(dataset=dataset) | |
loader = DataLoader( | |
dataset, | |
batch_size=len(dataset), | |
num_workers=num_workers, | |
drop_last=False, | |
shuffle=False, | |
) | |
all_data: dict = next( | |
iter(loader) | |
) # torchvision.utils.save_image(all_data['img'], 'gt.jpg', normalize=True, value_range=(-1,1)) | |
if kwargs.get('gs_cam_format', False): # gs rendering pipeline | |
# ! load V=4 images for training in a batch. | |
while True: | |
# indices = torch.randperm(len(dataset))[:4] | |
indices = torch.randperm( | |
len(dataset))[:batch_size] # all instances | |
# indices2 = torch.randperm(len(dataset))[:] # all instances | |
batch_c = collections.defaultdict(dict) | |
for k in ['c', 'nv_c']: | |
for k_c, v_c in all_data[k].items(): | |
batch_c[k][k_c] = torch.index_select( | |
v_c, dim=0, index=indices).reshape( | |
batch_size // | |
4, 4, *v_c.shape[1:]).float() if isinstance( | |
v_c, torch.Tensor) else v_c.float() # float | |
batch_c['c']['tanfov'] = batch_c['c']['tanfov'][0][0].item() | |
batch_c['nv_c']['tanfov'] = batch_c['nv_c']['tanfov'][0][0].item() | |
batch_data = {} | |
for k, v in all_data.items(): | |
if k not in ['c', 'nv_c']: | |
batch_data[k] = torch.index_select( | |
v, dim=0, index=indices).float() if isinstance( | |
v, torch.Tensor) else v # float | |
yield { | |
**batch_data, | |
**batch_c, | |
} | |
else: | |
while True: | |
start_idx = np.random.randint(0, len(dataset) - batch_size + 1) | |
yield { | |
k: v[start_idx:start_idx + batch_size] | |
for k, v in all_data.items() | |
} | |
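# NOTE: load_memory_data reads a single overfitting instance fully into memory once, then
# yields batches sampled or sliced from it forever; callers draw batches with next()
# rather than iterating the generator to exhaustion.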
def read_dnormal(normald_path, cond_pos, h=None, w=None): | |
cond_cam_dis = np.linalg.norm(cond_pos, 2) | |
    near = 0.867  # ~= sqrt(3) * 0.5
near_distance = cond_cam_dis - near | |
normald = cv2.imread(normald_path, cv2.IMREAD_UNCHANGED).astype(np.float32) | |
depth = normald[..., 3:] | |
depth[depth < near_distance] = 0 | |
if h is not None: | |
assert w is not None | |
# depth = cv2.resize(depth, (h, w)) # 512,512, 1 -> self.reso, self.reso | |
depth = cv2.resize(depth, (h, w), interpolation=cv2.INTER_NEAREST | |
) # 512,512, 1 -> self.reso, self.reso | |
else: | |
depth = depth[..., 0] | |
return torch.from_numpy(depth).float() | |
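# NOTE: the *_nd.exr buffers read above appear to pack per-pixel normals in channels 0-2
# and a depth channel in channel 3 (only the depth channel is consumed here); pixels
# closer than near_distance are zeroed out as background.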
def get_intri(target_im=None, h=None, w=None, normalize=False): | |
if target_im is None: | |
assert (h is not None and w is not None) | |
else: | |
h, w = target_im.shape[:2] | |
fx = fy = 1422.222 | |
res_raw = 1024 | |
f_x = f_y = fx * h / res_raw | |
K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3) | |
if normalize: # center is [0.5, 0.5], eg3d renderer tradition | |
K[:6] /= h | |
# print("intr: ", K) | |
return K | |
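# NOTE: with normalize=True the intrinsics follow the eg3d convention (principal point at
# 0.5 in normalized image coordinates). Downstream, the 25-dim camera label is the
# flattened 4x4 c2w (16 values) concatenated with this flattened 3x3 K (9 values), and
# gen_rays below only consumes fx, fy, cx, cy from it.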
def convert_pose(C2W): | |
# https://github.com/modelscope/richdreamer/blob/c3d9a77fa15fc42dbae12c2d41d64aaec14efd37/dataset/gobjaverse/depth_warp_example.py#L402 | |
flip_yz = np.eye(4) | |
flip_yz[1, 1] = -1 | |
flip_yz[2, 2] = -1 | |
C2W = np.matmul(C2W, flip_yz) | |
return torch.from_numpy(C2W) | |
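# NOTE: the diag(1, -1, -1, 1) right-multiplication above flips the camera-frame y and z
# axes, i.e. the usual Blender/OpenGL <-> OpenCV camera-convention switch.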
def read_camera_matrix_single(json_file): | |
with open(json_file, 'r', encoding='utf8') as reader: | |
json_content = json.load(reader) | |
''' | |
# NOTE that different from unity2blender experiments. | |
camera_matrix = np.eye(4) | |
camera_matrix[:3, 0] = np.array(json_content['x']) | |
camera_matrix[:3, 1] = -np.array(json_content['y']) | |
camera_matrix[:3, 2] = -np.array(json_content['z']) | |
camera_matrix[:3, 3] = np.array(json_content['origin']) | |
''' | |
camera_matrix = np.eye(4) # blender-based | |
camera_matrix[:3, 0] = np.array(json_content['x']) | |
camera_matrix[:3, 1] = np.array(json_content['y']) | |
camera_matrix[:3, 2] = np.array(json_content['z']) | |
camera_matrix[:3, 3] = np.array(json_content['origin']) | |
# print(camera_matrix) | |
# ''' | |
# return convert_pose(camera_matrix) | |
return camera_matrix | |
def unity2blender(normal): | |
normal_clone = normal.copy() | |
normal_clone[..., 0] = -normal[..., -1] | |
normal_clone[..., 1] = -normal[..., 0] | |
normal_clone[..., 2] = normal[..., 1] | |
return normal_clone | |
def blender2midas(img): | |
'''Blender: rub | |
midas: lub | |
''' | |
img[..., 0] = -img[..., 0] | |
img[..., 1] = -img[..., 1] | |
img[..., -1] = -img[..., -1] | |
return img | |
def current_milli_time(): | |
return round(time.time() * 1000) | |
# modified from ShapeNet class | |
class MultiViewObjverseDataset(Dataset): | |
def __init__( | |
self, | |
file_path, | |
reso, | |
reso_encoder, | |
preprocess=None, | |
classes=False, | |
load_depth=False, | |
test=False, | |
scene_scale=1, | |
overfitting=False, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
overfitting_bs=-1, | |
interval=1, | |
plucker_embedding=False, | |
shuffle_across_cls=False, | |
wds_split=1, # 4 splits to accelerate preprocessing | |
four_view_for_latent=False, | |
single_view_for_i23d=False, | |
load_extra_36_view=False, | |
gs_cam_format=False, | |
**kwargs): | |
self.load_extra_36_view = load_extra_36_view | |
# st() | |
self.gs_cam_format = gs_cam_format | |
        self.four_view_for_latent = four_view_for_latent  # export a fixed subset of views for latent reconstruction (currently 6 views, see below)
self.single_view_for_i23d = single_view_for_i23d | |
self.file_path = file_path | |
self.overfitting = overfitting | |
self.scene_scale = scene_scale | |
self.reso = reso | |
self.reso_encoder = reso_encoder | |
self.classes = False | |
self.load_depth = load_depth | |
self.preprocess = preprocess | |
self.plucker_embedding = plucker_embedding | |
self.intrinsics = get_intri(h=self.reso, w=self.reso, | |
normalize=True).reshape(9) | |
        assert not self.classes, "class condition is not supported yet."
dataset_name = Path(self.file_path).stem.split('_')[0] | |
self.dataset_name = dataset_name | |
self.zfar = 100.0 | |
self.znear = 0.01 | |
# if test: | |
# self.ins_list = sorted(os.listdir(self.file_path))[0:1] # the first 1 instance for evaluation reference. | |
# else: | |
# ! TODO, read from list? | |
def load_single_cls_instances(file_path): | |
ins_list = [] # the first 1 instance for evaluation reference. | |
for dict_dir in os.listdir(file_path)[:]: | |
for ins_dir in os.listdir(os.path.join(file_path, dict_dir)): | |
# self.ins_list.append(os.path.join(self.file_path, dict_dir, ins_dir,)) | |
ins_list.append( | |
os.path.join(file_path, dict_dir, ins_dir, | |
'campos_512_v4')) | |
return ins_list | |
if shuffle_across_cls: | |
self.ins_list = [] | |
# for subset in ['Animals', 'Transportations_tar', 'Furnitures']: | |
# for subset in ['Furnitures']: | |
# selected subset for training | |
for subset in [ # ! around 17W instances in total. MVImageNet is the next thing to deal with? Later. | |
# 'daily-used', | |
# 'Food', | |
# 'Plants', | |
# 'Electronics', | |
# 'BuildingsOutdoor', | |
# 'Human-Shape', | |
'Animals', | |
# 'Transportations_tar', | |
# 'Furnitures', | |
]: # selected subset for training | |
self.ins_list += load_single_cls_instances( | |
os.path.join(self.file_path, subset)) | |
# st() | |
current_time = int(current_milli_time() | |
) # randomly shuffle given current time | |
random.seed(current_time) | |
random.shuffle(self.ins_list) | |
else: # preprocess single class | |
self.ins_list = load_single_cls_instances(self.file_path) | |
self.ins_list = sorted(self.ins_list) | |
# if test: | |
# self.ins_list = self.ins_list[0:1] | |
if overfitting: | |
self.ins_list = self.ins_list[:1] | |
self.rgb_list = [] | |
self.pose_list = [] | |
self.depth_list = [] | |
self.data_ins_list = [] | |
self.instance_data_length = -1 | |
with open( | |
# '/mnt/yslan/objaverse/richdreamer/dataset/text_captions_cap3d.json', | |
'./datasets/text_captions_cap3d.json', | |
) as f: | |
self.caption_data = json.load(f) | |
self.shuffle_across_cls = shuffle_across_cls | |
# for ins in self.ins_list[47000:]: | |
if four_view_for_latent: | |
self.wds_split_all = 1 # ! when dumping latent | |
ins_list_to_process = self.ins_list | |
else: | |
self.wds_split_all = 4 | |
# self.wds_split_all = 8 # ! 8 cls in total | |
all_ins_size = len(self.ins_list) | |
ratio_size = all_ins_size // self.wds_split_all + 1 | |
ins_list_to_process = self.ins_list[ratio_size * | |
(wds_split - 1):ratio_size * | |
wds_split] | |
# st() | |
for ins in ins_list_to_process: | |
# ins = os.path.join( | |
# # self.file_path, ins , 'campos_512_v4' | |
# self.file_path, ins , | |
# # 'compos_512_v4' | |
# ) | |
# cur_rgb_path = os.path.join(self.file_path, ins, 'compos_512_v4') | |
# cur_pose_path = os.path.join(self.file_path, ins, 'pose') | |
# st() | |
# ][:27]) | |
if self.four_view_for_latent: | |
# ! v=6 version infer latent | |
cur_all_fname = [f'{idx:05d}' for idx in [25,0,9,18,27,33]] | |
elif self.single_view_for_i23d: | |
cur_all_fname = [f'{idx:05d}' | |
for idx in [2]] # ! furniture side view | |
else: | |
cur_all_fname = [t.split('.')[0] for t in os.listdir(ins) | |
] # use full set for training | |
if shuffle_across_cls: | |
random.seed(current_time) | |
random.shuffle(cur_all_fname) | |
else: | |
cur_all_fname = sorted(cur_all_fname) | |
if self.instance_data_length == -1: | |
self.instance_data_length = len(cur_all_fname) | |
else: | |
                try:  # data missing?
                    assert len(cur_all_fname) == self.instance_data_length
                except AssertionError:
                    with open('missing_ins.txt', 'a') as f:
                        f.write(str(Path(ins).parent) +
                                '\n')  # drop the trailing "campos_512_v4"
                    continue
self.pose_list += ([ | |
os.path.join(ins, fname, fname + '.json') | |
for fname in cur_all_fname | |
]) | |
self.rgb_list += ([ | |
os.path.join(ins, fname, fname + '.png') | |
for fname in cur_all_fname | |
]) | |
self.depth_list += ([ | |
os.path.join(ins, fname, fname + '_nd.exr') | |
for fname in cur_all_fname | |
]) | |
self.data_ins_list += ([ins] * len(cur_all_fname)) | |
# check | |
        # ! setup normalization
transformations = [ | |
transforms.ToTensor(), # [0,1] range | |
] | |
if imgnet_normalize: | |
transformations.append( | |
transforms.Normalize((0.485, 0.456, 0.406), | |
(0.229, 0.224, 0.225)) # type: ignore | |
) | |
else: | |
transformations.append( | |
transforms.Normalize((0.5, 0.5, 0.5), | |
(0.5, 0.5, 0.5))) # type: ignore | |
self.normalize = transforms.Compose(transformations) | |
def get_source_cw2wT(self, source_cameras_view_to_world): | |
return matrix_to_quaternion( | |
source_cameras_view_to_world[:3, :3].transpose(0, 1)) | |
def c_to_3dgs_format(self, pose): | |
# TODO, switch to torch version (batched later) | |
        c2w = pose[:16].reshape(4, 4)  # 4x4 c2w
# ! load cam | |
w2c = np.linalg.inv(c2w) | |
R = np.transpose( | |
w2c[:3, :3]) # R is stored transposed due to 'glm' in CUDA code | |
T = w2c[:3, 3] | |
fx = pose[16] | |
FovX = focal2fov(fx, 1) | |
FovY = focal2fov(fx, 1) | |
tanfovx = math.tan(FovX * 0.5) | |
tanfovy = math.tan(FovY * 0.5) | |
assert tanfovx == tanfovy | |
trans = np.array([0.0, 0.0, 0.0]) | |
scale = 1.0 | |
world_view_transform = torch.tensor(getWorld2View2(R, T, trans, | |
scale)).transpose( | |
0, 1) | |
projection_matrix = getProjectionMatrix(znear=self.znear, | |
zfar=self.zfar, | |
fovX=FovX, | |
fovY=FovY).transpose(0, 1) | |
full_proj_transform = (world_view_transform.unsqueeze(0).bmm( | |
projection_matrix.unsqueeze(0))).squeeze(0) | |
camera_center = world_view_transform.inverse()[3, :3] | |
view_world_transform = torch.tensor(getView2World(R, T, trans, | |
scale)).transpose( | |
0, 1) | |
# item.update(viewpoint_cam=[viewpoint_cam]) | |
c = {} | |
c["source_cv2wT_quat"] = self.get_source_cw2wT(view_world_transform) | |
c.update( | |
# projection_matrix=projection_matrix, # K | |
cam_view=world_view_transform, # world_view_transform | |
cam_view_proj=full_proj_transform, # full_proj_transform | |
cam_pos=camera_center, | |
tanfov=tanfovx, # TODO, fix in the renderer | |
# orig_c2w=c2w, | |
# orig_w2c=w2c, | |
orig_pose=torch.from_numpy(pose), | |
orig_c2w=torch.from_numpy(c2w), | |
orig_w2c=torch.from_numpy(w2c), | |
# tanfovy=tanfovy, | |
) | |
return c # dict for gs rendering | |
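    # A minimal sketch of how the dict returned above is typically consumed by a
    # 3DGS-style rasterizer (the render() signature below is an assumption, not defined here):
    #   cam = self.c_to_3dgs_format(c_25)  # c_25: flattened 4x4 c2w + 3x3 intrinsics
    #   render(..., viewmatrix=cam['cam_view'], projmatrix=cam['cam_view_proj'],
    #          campos=cam['cam_pos'], tanfovx=cam['tanfov'], tanfovy=cam['tanfov'])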
def __len__(self): | |
return len(self.rgb_list) | |
def load_bbox(self, mask): | |
# st() | |
nonzero_value = torch.nonzero(mask) | |
height, width = nonzero_value.max(dim=0)[0] | |
top, left = nonzero_value.min(dim=0)[0] | |
bbox = torch.tensor([top, left, height, width], dtype=torch.float32) | |
return bbox | |
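    # NOTE: despite the variable names above, load_bbox returns [min_row, min_col, max_row,
    # max_col] of the foreground mask, i.e. a [top, left, bottom, right] box in pixel indices.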
def __getitem__(self, idx): | |
# try: | |
data = self._read_data(idx) | |
return data | |
# except Exception as e: | |
# with open('error_log.txt', 'a') as f: | |
# f.write(str(e) + '\n') | |
# with open('error_idx.txt', 'a') as f: | |
# f.write(str(self.data_ins_list[idx]) + '\n') | |
# print(e, flush=True) | |
# return {} | |
def gen_rays(self, c2w): | |
# Generate rays | |
self.h = self.reso_encoder | |
self.w = self.reso_encoder | |
yy, xx = torch.meshgrid( | |
torch.arange(self.h, dtype=torch.float32) + 0.5, | |
torch.arange(self.w, dtype=torch.float32) + 0.5, | |
indexing='ij') | |
# normalize to 0-1 pixel range | |
yy = yy / self.h | |
xx = xx / self.w | |
# K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3) | |
cx, cy, fx, fy = self.intrinsics[2], self.intrinsics[ | |
5], self.intrinsics[0], self.intrinsics[4] | |
# cx *= self.w | |
# cy *= self.h | |
# f_x = f_y = fx * h / res_raw | |
c2w = torch.from_numpy(c2w).float() | |
xx = (xx - cx) / fx | |
yy = (yy - cy) / fy | |
zz = torch.ones_like(xx) | |
dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention | |
dirs /= torch.norm(dirs, dim=-1, keepdim=True) | |
dirs = dirs.reshape(-1, 3, 1) | |
del xx, yy, zz | |
# st() | |
dirs = (c2w[None, :3, :3] @ dirs)[..., 0] | |
origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous() | |
origins = origins.view(self.h, self.w, 3) | |
dirs = dirs.view(self.h, self.w, 3) | |
return origins, dirs | |
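    # NOTE: the (origins, dirs) returned here are turned into a 6-channel Pluecker-style
    # ray embedding in _read_data below: cross(rays_o, rays_d) concatenated with rays_d.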
def _read_data(self, idx): | |
rgb_fname = self.rgb_list[idx] | |
pose_fname = self.pose_list[idx] | |
raw_img = imageio.imread(rgb_fname) | |
# ! RGBD | |
alpha_mask = raw_img[..., -1:] / 255 | |
raw_img = alpha_mask * raw_img[..., :3] + ( | |
1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255 | |
raw_img = raw_img.astype( | |
np.uint8) # otherwise, float64 won't call ToTensor() | |
# return raw_img | |
# st() | |
if self.preprocess is None: | |
img_to_encoder = cv2.resize(raw_img, | |
(self.reso_encoder, self.reso_encoder), | |
interpolation=cv2.INTER_LANCZOS4) | |
# interpolation=cv2.INTER_AREA) | |
img_to_encoder = img_to_encoder[ | |
..., :3] #[3, reso_encoder, reso_encoder] | |
img_to_encoder = self.normalize(img_to_encoder) | |
else: | |
img_to_encoder = self.preprocess(Image.open(rgb_fname)) # clip | |
# return img_to_encoder | |
img = cv2.resize(raw_img, (self.reso, self.reso), | |
interpolation=cv2.INTER_LANCZOS4) | |
# interpolation=cv2.INTER_AREA) | |
# img_sr = cv2.resize(raw_img, (512, 512), interpolation=cv2.INTER_AREA) | |
# img_sr = cv2.resize(raw_img, (256, 256), interpolation=cv2.INTER_AREA) # just as refinement, since eg3d uses 64->128 final resolution | |
# img_sr = cv2.resize(raw_img, (128, 128), interpolation=cv2.INTER_AREA) # just as refinement, since eg3d uses 64->128 final resolution | |
# img_sr = cv2.resize( | |
# raw_img, (128, 128), interpolation=cv2.INTER_LANCZOS4 | |
# ) # just as refinement, since eg3d uses 64->128 final resolution | |
# img = torch.from_numpy(img)[..., :3].permute( | |
# 2, 0, 1) / 255.0 #[3, reso, reso] | |
img = torch.from_numpy(img)[..., :3].permute( | |
2, 0, 1 | |
) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range | |
# img_sr = torch.from_numpy(img_sr)[..., :3].permute( | |
# 2, 0, 1 | |
# ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range | |
c2w = read_camera_matrix_single(pose_fname) #[1, 4, 4] -> [1, 16] | |
# c = np.concatenate([c2w, self.intrinsics], axis=0).reshape(25) # 25, no '1' dim needed. | |
# return c2w | |
# if self.load_depth: | |
# depth, depth_mask, depth_mask_sr = read_dnormal(self.depth_list[idx], | |
# try: | |
depth = read_dnormal(self.depth_list[idx], c2w[:3, 3:], self.reso, | |
self.reso) | |
# return depth | |
# except: | |
# # print(self.depth_list[idx]) | |
# raise NotImplementedError(self.depth_list[idx]) | |
# if depth | |
# try: | |
bbox = self.load_bbox(depth > 0) | |
# except: | |
# print(rgb_fname) | |
# return {} | |
# st() | |
# plucker | |
rays_o, rays_d = self.gen_rays(c2w) | |
rays_plucker = torch.cat([torch.cross(rays_o, rays_d, dim=-1), rays_d], | |
dim=-1) # [h, w, 6] | |
img_to_encoder = torch.cat( | |
[img_to_encoder, rays_plucker.permute(2, 0, 1)], | |
0).float() # concat in C dim | |
# ! add depth as input | |
normalized_depth = read_dnormal(self.depth_list[idx], c2w[:3, 3:], | |
self.reso_encoder, | |
self.reso_encoder).unsqueeze(0) | |
# normalized_depth = depth.unsqueeze(0) # min=0 | |
img_to_encoder = torch.cat([img_to_encoder, normalized_depth], | |
0) # concat in C dim | |
c = np.concatenate([c2w.reshape(16), self.intrinsics], | |
axis=0).reshape(25).astype( | |
np.float32) # 25, no '1' dim needed. | |
if self.gs_cam_format: | |
c = self.c_to_3dgs_format(c) | |
else: | |
c = torch.from_numpy(c) | |
ret_dict = { | |
# 'rgb_fname': rgb_fname, | |
'img_to_encoder': img_to_encoder, | |
'img': img, | |
'c': c, | |
# 'img_sr': img_sr, | |
# 'ins_name': self.data_ins_list[idx] | |
} | |
ins = str( | |
(Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent) | |
if self.shuffle_across_cls: | |
caption = self.caption_data['/'.join(ins.split('/')[1:])] | |
else: | |
caption = self.caption_data[ins] | |
ret_dict.update({ | |
'depth': depth, | |
'depth_mask': depth > 0, | |
# 'depth_mask_sr': depth_mask_sr, | |
'bbox': bbox, | |
'caption': caption, | |
'rays_plucker': rays_plucker, # cam embedding used in lgm | |
'ins': ins, # placeholder | |
}) | |
return ret_dict | |
class RealDataset(Dataset): | |
def __init__( | |
self, | |
file_path, | |
reso, | |
reso_encoder, | |
preprocess=None, | |
classes=False, | |
load_depth=False, | |
test=False, | |
scene_scale=1, | |
overfitting=False, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
overfitting_bs=-1, | |
interval=1, | |
plucker_embedding=False, | |
shuffle_across_cls=False, | |
wds_split=1, # 4 splits to accelerate preprocessing | |
) -> None: | |
super().__init__() | |
self.file_path = file_path | |
self.overfitting = overfitting | |
self.scene_scale = scene_scale | |
self.reso = reso | |
self.reso_encoder = reso_encoder | |
self.classes = False | |
self.load_depth = load_depth | |
self.preprocess = preprocess | |
self.plucker_embedding = plucker_embedding | |
self.rgb_list = [] | |
        all_fname = [
            t for t in os.listdir(self.file_path)
            if t.split('.')[-1] in ['png', 'jpg']
        ]
        all_fname = [name for name in all_fname if '-input' in name]
self.rgb_list += ([ | |
os.path.join(self.file_path, fname) for fname in all_fname | |
]) | |
# if len(self.rgb_list) == 1: | |
# # placeholder | |
# self.rgb_list = self.rgb_list * 40 | |
        # ! setup normalization
transformations = [ | |
transforms.ToTensor(), # [0,1] range | |
] | |
assert imgnet_normalize | |
if imgnet_normalize: | |
transformations.append( | |
transforms.Normalize((0.485, 0.456, 0.406), | |
(0.229, 0.224, 0.225)) # type: ignore | |
) | |
else: | |
transformations.append( | |
transforms.Normalize((0.5, 0.5, 0.5), | |
(0.5, 0.5, 0.5))) # type: ignore | |
self.normalize = transforms.Compose(transformations) | |
# camera = torch.load('eval_pose.pt', map_location='cpu') | |
# self.eval_camera = camera | |
# pre-cache | |
# self.calc_rays_plucker() | |
def __len__(self): | |
return len(self.rgb_list) | |
def __getitem__(self, index) -> Any: | |
# return super().__getitem__(index) | |
rgb_fname = self.rgb_list[index] | |
# ! preprocess, normalize | |
raw_img = imageio.imread(rgb_fname) | |
# interpolation=cv2.INTER_AREA) | |
if raw_img.shape[-1] == 4: | |
alpha_mask = raw_img[..., 3:4] / 255.0 | |
bg_white = np.ones_like(alpha_mask) * 255.0 | |
raw_img = raw_img[..., :3] * alpha_mask + ( | |
1 - alpha_mask) * bg_white #[3, reso_encoder, reso_encoder] | |
raw_img = raw_img.astype(np.uint8) | |
# raw_img = recenter(raw_img, np.ones_like(raw_img), border_ratio=0.2) | |
# log gt | |
img = cv2.resize(raw_img, (self.reso, self.reso), | |
interpolation=cv2.INTER_LANCZOS4) | |
img = torch.from_numpy(img)[..., :3].permute( | |
2, 0, 1 | |
) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range | |
ret_dict = { | |
# 'rgb_fname': rgb_fname, | |
# 'img_to_encoder': | |
# img_to_encoder.unsqueeze(0).repeat_interleave(40, 0), | |
'img': img, | |
# 'c': self.eval_camera, # TODO, get pre-calculated samples | |
# 'ins': 'placeholder', | |
# 'bbox': 'placeholder', | |
# 'caption': 'placeholder', | |
} | |
        # ! repeat as an instance
return ret_dict | |
class RealMVDataset(Dataset): | |
def __init__( | |
self, | |
file_path, | |
reso, | |
reso_encoder, | |
preprocess=None, | |
classes=False, | |
load_depth=False, | |
test=False, | |
scene_scale=1, | |
overfitting=False, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
overfitting_bs=-1, | |
interval=1, | |
plucker_embedding=False, | |
shuffle_across_cls=False, | |
wds_split=1, # 4 splits to accelerate preprocessing | |
) -> None: | |
super().__init__() | |
self.file_path = file_path | |
self.overfitting = overfitting | |
self.scene_scale = scene_scale | |
self.reso = reso | |
self.reso_encoder = reso_encoder | |
self.classes = False | |
self.load_depth = load_depth | |
self.preprocess = preprocess | |
self.plucker_embedding = plucker_embedding | |
self.rgb_list = [] | |
        all_fname = [
            t for t in os.listdir(self.file_path)
            if t.split('.')[-1] in ['png', 'jpg']
        ]
# all_fname = [name for name in all_fname if '-input' in name ] | |
# all_fname = [name for name in all_fname if 'sorting_board-input' in name ] | |
all_fname = [name for name in all_fname if 'teasure_chest-input' in name ] | |
# all_fname = [name for name in all_fname if 'bubble_mart_blue-input' in name ] | |
# all_fname = [name for name in all_fname if 'chair_comfort-input' in name ] | |
self.rgb_list += ([ | |
os.path.join(self.file_path, fname) for fname in all_fname | |
]) | |
# if len(self.rgb_list) == 1: | |
# # placeholder | |
# self.rgb_list = self.rgb_list * 40 | |
        # ! setup normalization
transformations = [ | |
transforms.ToTensor(), # [0,1] range | |
] | |
# load zero123pp pose | |
# ''' | |
camera = torch.load('assets/objv_eval_pose.pt', map_location='cpu') # 40, 25 | |
# v=6 render | |
# zero123pp_pose = torch.load('assets/input_cameras.pt', map_location='cpu')[0] | |
# zero123pp_pose = torch.load('assets/input_cameras_1.5.pt', map_location='cpu')[0] | |
# V = zero123pp_pose.shape[0] | |
# zero123pp_pose = torch.cat([zero123pp_pose[:, :12], # change radius to 1.5 here | |
# torch.Tensor([0,0,0,1]).reshape(1,4).repeat_interleave(V, 0).to(camera), | |
# camera[:V, 16:]], | |
# 1) | |
azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float) | |
elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float) | |
# zero123pp_pose, _ = generate_input_camera(1.6, [[elevations[i], azimuths[i]] for i in range(6)], fov=30) | |
zero123pp_pose, _ = generate_input_camera(1.8, [[elevations[i], azimuths[i]] for i in range(6)], fov=30) | |
K = torch.Tensor([1.3889, 0.0000, 0.5000, 0.0000, 1.3889, 0.5000, 0.0000, 0.0000, 0.0039]).to(zero123pp_pose) # keeps the same | |
# st() | |
zero123pp_pose = torch.cat([zero123pp_pose.reshape(6,-1), K.unsqueeze(0).repeat(6,1)], dim=-1) | |
# ! directly adopt gt input | |
# self.indices = np.array([0,2,4,5]) | |
# eval_camera = zero123pp_pose[self.indices] | |
# self.eval_camera = torch.cat([torch.zeros_like(eval_camera[0:1]),eval_camera], 0) # first c not used as condition here, just placeholder | |
# ! adopt mv-diffusion output as input. | |
# self.indices = np.array([1,0,2,4,5]) | |
self.indices = np.array([0,1,2,3,4,5]) | |
eval_camera = zero123pp_pose[self.indices].float().cpu().numpy() # for normalization | |
# eval_camera = zero123pp_pose[self.indices] | |
# self.eval_camera = eval_camera | |
# self.eval_camera = torch.cat([torch.zeros_like(eval_camera[0:1]),eval_camera], 0) # first c not used as condition here, just placeholder | |
# # * normalize here | |
# self.eval_camera = self.normalize_camera(eval_camera, eval_camera[0:1]) # the first img is not used. | |
# ! no normalization required | |
self.eval_camera = eval_camera | |
# st() | |
# self.eval_camera = self.eval_camera + np.random.randn(*self.eval_camera.shape) * 0.04 - 0.02 | |
# ''' | |
# self.eval_camera = torch.load('/nas/shared/V2V/yslan/logs/nips24/LSGM/t23d/FM/9cls/i23d/dit-XL2-MV-ampFast/gpu1-batch10-lr0-bf16-qknorm-fixinpRange-debug/2590000_c.pt')[1:6].float().cpu().numpy() | |
# self.eval_camera = self.normalize_camera(self.eval_camera, self.eval_camera[0:1]) | |
# self.eval_camera = self.eval_camera + np.random.randn(*self.eval_camera.shape) * 0.04 - 0.02 | |
# pass | |
        # FIXME check how sensitive it is: requires training augmentation.
# self.eval_camera = torch.load('/nas/shared/V2V/yslan/logs/nips24/LSGM/t23d/FM/9cls/i23d/dit-XL2-MV-ampFast/gpu1-batch10-lr0-bf16-qknorm-fixinpRange-debug/2590000_c.pt')[1:] | |
def normalize_camera(self, c, c_frame0): | |
        # assert c.shape[0] == self.chunk_size  # 8 or 10
        B = c.shape[0]
        camera_poses = c[:, :16].reshape(B, 4, 4)  # (B, 4, 4) c2w
canonical_camera_poses = c_frame0[:, :16].reshape(1, 4, 4) | |
inverse_canonical_pose = np.linalg.inv(canonical_camera_poses) | |
inverse_canonical_pose = np.repeat(inverse_canonical_pose, B, 0) | |
cam_radius = np.linalg.norm( | |
c_frame0[:, :16].reshape(1, 4, 4)[:, :3, 3], | |
axis=-1, | |
keepdims=False) # since g-buffer adopts dynamic radius here. | |
frame1_fixed_pos = np.repeat(np.eye(4)[None], 1, axis=0) | |
frame1_fixed_pos[:, 2, -1] = -cam_radius | |
transform = frame1_fixed_pos @ inverse_canonical_pose | |
new_camera_poses = np.repeat( | |
transform, 1, axis=0 | |
) @ camera_poses # [V, 4, 4]. np.repeat() is th.repeat_interleave() | |
c = np.concatenate([new_camera_poses.reshape(B, 16), c[:, 16:]], | |
axis=-1) | |
return c | |
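    # NOTE: normalize_camera re-expresses all poses relative to the first frame: frame 0 is
    # mapped to a canonical camera on the -z axis at its original radius, and the same rigid
    # transform is applied to the remaining views, so the conditioning poses are independent
    # of the absolute world frame.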
def __len__(self): | |
return len(self.rgb_list) | |
def __getitem__(self, index) -> Any: | |
# return super().__getitem__(index) | |
rgb_fname = self.rgb_list[index] | |
# ! if loading training imgs | |
# raw_img = imageio.imread(rgb_fname) | |
# if raw_img.shape[-1] == 4: | |
# alpha_mask = raw_img[..., 3:4] / 255.0 | |
# bg_white = np.ones_like(alpha_mask) * 255.0 | |
# raw_img = raw_img[..., :3] * alpha_mask + ( | |
# 1 - alpha_mask) * bg_white #[3, reso_encoder, reso_encoder] | |
# raw_img = raw_img.astype(np.uint8) | |
# img = rearrange(raw_img, 'h (n w) c -> n c h w', n=6)[self.indices] | |
# img = torch.from_numpy(img) / 127.5 - 1 # | |
# ''' | |
# img = cv2.resize(raw_img, (320,320), | |
# interpolation=cv2.INTER_LANCZOS4) # for easy concat | |
# img = torch.from_numpy(img)[..., :3].permute( | |
# 2, 0, 1 | |
# ) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range | |
# img = rearrange(raw_img, 'h (n w) c -> n c h w', n=6)[1:6] # ! if loading gt inp views | |
# ''' | |
# ! if loading mv-diff output views | |
mv_img = imageio.imread(rgb_fname.replace('-input', '')) | |
# st() | |
# mv_img = rearrange(mv_img, '(n h) (m w) c -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320) | |
mv_img = rearrange(mv_img, '(n h) (m w) c -> (n m) h w c', n=3, m=2)[self.indices] # (6, 3, 320, 320) | |
mv_img = np.stack([recenter(img, np.ones_like(img), border_ratio=0.2) for img in mv_img], axis=0) | |
# mv_img = np.stack([recenter(img, np.ones_like(img), border_ratio=0.3) for img in mv_img], axis=0) | |
# mv_img = np.stack([recenter(img, np.ones_like(img), border_ratio=0.1) for img in mv_img], axis=0) | |
# mv_img = np.stack([recenter(img, np.ones_like(img), border_ratio=0.4) for img in mv_img], axis=0) | |
mv_img = rearrange(mv_img, 'b h w c -> b c h w') # to torch tradition | |
mv_img = torch.from_numpy(mv_img) / 127.5 - 1 | |
# img = torch.cat([img[None], mv_img], 0) | |
img = mv_img # ! directly adopt output views for reconstruction/generation | |
# ''' | |
ret_dict = { | |
'img': img, | |
'c': self.eval_camera | |
} | |
return ret_dict | |
class NovelViewObjverseDataset(MultiViewObjverseDataset): | |
"""novel view prediction version. | |
""" | |
def __init__(self, | |
file_path, | |
reso, | |
reso_encoder, | |
preprocess=None, | |
classes=False, | |
load_depth=False, | |
test=False, | |
scene_scale=1, | |
overfitting=False, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
overfitting_bs=-1, | |
**kwargs): | |
super().__init__(file_path, reso, reso_encoder, preprocess, classes, | |
load_depth, test, scene_scale, overfitting, | |
imgnet_normalize, dataset_size, overfitting_bs, | |
**kwargs) | |
def __getitem__(self, idx): | |
input_view = super().__getitem__( | |
idx) # get previous input view results | |
# get novel view of the same instance | |
novel_view = super().__getitem__( | |
(idx // self.instance_data_length) * self.instance_data_length + | |
random.randint(0, self.instance_data_length - 1)) | |
# assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance' | |
input_view.update({f'nv_{k}': v for k, v in novel_view.items()}) | |
return input_view | |
class MultiViewObjverseDatasetforLMDB(MultiViewObjverseDataset): | |
def __init__( | |
self, | |
file_path, | |
reso, | |
reso_encoder, | |
preprocess=None, | |
classes=False, | |
load_depth=False, | |
test=False, | |
scene_scale=1, | |
overfitting=False, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
overfitting_bs=-1, | |
shuffle_across_cls=False, | |
wds_split=1, | |
four_view_for_latent=False, | |
): | |
super().__init__(file_path, | |
reso, | |
reso_encoder, | |
preprocess, | |
classes, | |
load_depth, | |
test, | |
scene_scale, | |
overfitting, | |
imgnet_normalize, | |
dataset_size, | |
overfitting_bs, | |
shuffle_across_cls=shuffle_across_cls, | |
wds_split=wds_split, | |
four_view_for_latent=four_view_for_latent) | |
assert self.reso == 256 | |
with open( | |
'/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json' | |
) as f: | |
self.caption_data = json.load(f) | |
lmdb_path = '/cpfs01/user/yangpeiqing.p/yslan/data/Furnitures_uncompressed/' | |
# with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f: | |
# self.idx_to_ins_mapping = json.load(f) | |
def __len__(self): | |
return super().__len__() | |
# return 100 # for speed debug | |
def __getitem__(self, idx): | |
# ret_dict = super().__getitem__(idx) | |
rgb_fname = self.rgb_list[idx] | |
pose_fname = self.pose_list[idx] | |
raw_img = imageio.imread(rgb_fname) # [..., :3] | |
# assert raw_img.shape[-1] == 4 | |
if raw_img.shape[-1] == 4: # ! set bg to white | |
alpha_mask = raw_img[..., -1:] / 255 # [0,1] | |
raw_img = alpha_mask * raw_img[..., :3] + ( | |
1 - alpha_mask) * np.ones_like(raw_img[..., :3]) * 255 | |
raw_img = raw_img.astype(np.uint8) | |
raw_img = cv2.resize(raw_img, (self.reso, self.reso), | |
interpolation=cv2.INTER_LANCZOS4) | |
c2w = read_camera_matrix_single(pose_fname) #[1, 4, 4] -> [1, 16] | |
c = np.concatenate([c2w.reshape(16), self.intrinsics], | |
axis=0).reshape(25).astype( | |
np.float32) # 25, no '1' dim needed. | |
c = torch.from_numpy(c) | |
# c = np.concatenate([c2w, self.intrinsics], axis=0).reshape(25) # 25, no '1' dim needed. | |
# if self.load_depth: | |
# depth, depth_mask, depth_mask_sr = read_dnormal(self.depth_list[idx], | |
# try: | |
depth = read_dnormal(self.depth_list[idx], c2w[:3, 3:], self.reso, | |
self.reso) | |
# except: | |
# # print(self.depth_list[idx]) | |
# raise NotImplementedError(self.depth_list[idx]) | |
# if depth | |
# try: | |
bbox = self.load_bbox(depth > 0) | |
ins = str( | |
(Path(self.data_ins_list[idx]).relative_to(self.file_path)).parent) | |
if self.shuffle_across_cls: | |
caption = self.caption_data['/'.join(ins.split('/')[1:])] | |
else: | |
caption = self.caption_data[ins] | |
ret_dict = { | |
'raw_img': raw_img, | |
'c': c, | |
'depth': depth, | |
# 'depth_mask': depth_mask, # 64x64 here? | |
'bbox': bbox, | |
'ins': ins, | |
'caption': caption, | |
# 'fname': rgb_fname, | |
} | |
return ret_dict | |
class Objv_LMDBDataset_MV_Compressed(LMDBDataset_MV_Compressed): | |
def __init__(self, | |
lmdb_path, | |
reso, | |
reso_encoder, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
test=False, | |
**kwargs): | |
super().__init__(lmdb_path, | |
reso, | |
reso_encoder, | |
imgnet_normalize, | |
dataset_size=dataset_size, | |
**kwargs) | |
self.instance_data_length = 40 # ! could save some key attributes in LMDB | |
if test: | |
self.length = self.instance_data_length | |
elif dataset_size > 0: | |
self.length = dataset_size * self.instance_data_length | |
# load caption data, and idx-to-ins mapping | |
with open( | |
'/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json' | |
) as f: | |
self.caption_data = json.load(f) | |
with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f: | |
self.idx_to_ins_mapping = json.load(f) | |
def _load_data(self, idx): | |
# ''' | |
raw_img, depth, c, bbox = self._load_lmdb_data(idx) | |
# raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx) | |
# resize depth and bbox | |
caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]] | |
return { | |
**self._post_process_sample(raw_img, depth), | |
'c': c, | |
'bbox': (bbox * (self.reso / 512.0)).astype(np.uint8), | |
# 'bbox': (bbox*(self.reso/256.0)).astype(np.uint8), # TODO, double check 512 in wds? | |
'caption': caption | |
} | |
# ''' | |
# raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx) | |
# st() | |
# return {} | |
def __getitem__(self, idx): | |
return self._load_data(idx) | |
class Objv_LMDBDataset_MV_NoCompressed(Objv_LMDBDataset_MV_Compressed): | |
def __init__(self, | |
lmdb_path, | |
reso, | |
reso_encoder, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
test=False, | |
**kwargs): | |
super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, | |
dataset_size, test, **kwargs) | |
def _load_data(self, idx): | |
# ''' | |
raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx) | |
# resize depth and bbox | |
caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]] | |
return { | |
**self._post_process_sample(raw_img, depth), 'c': c, | |
'bbox': (bbox * (self.reso / 512.0)).astype(np.uint8), | |
'caption': caption | |
} | |
return {} | |
class Objv_LMDBDataset_NV_NoCompressed(Objv_LMDBDataset_MV_NoCompressed): | |
def __init__(self, | |
lmdb_path, | |
reso, | |
reso_encoder, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
test=False, | |
**kwargs): | |
super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, | |
dataset_size, test, **kwargs) | |
def __getitem__(self, idx): | |
input_view = self._load_data(idx) # get previous input view results | |
# get novel view of the same instance | |
try: | |
novel_view = self._load_data( | |
(idx // self.instance_data_length) * | |
self.instance_data_length + | |
random.randint(0, self.instance_data_length - 1)) | |
except Exception as e: | |
raise NotImplementedError(idx) | |
# assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance' | |
input_view.update({f'nv_{k}': v for k, v in novel_view.items()}) | |
return input_view | |
class Objv_LMDBDataset_MV_Compressed_for_lmdb(LMDBDataset_MV_Compressed): | |
def __init__(self, | |
lmdb_path, | |
reso, | |
reso_encoder, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
test=False, | |
**kwargs): | |
super().__init__(lmdb_path, | |
reso, | |
reso_encoder, | |
imgnet_normalize, | |
dataset_size=dataset_size, | |
**kwargs) | |
self.instance_data_length = 40 # ! could save some key attributes in LMDB | |
if test: | |
self.length = self.instance_data_length | |
elif dataset_size > 0: | |
self.length = dataset_size * self.instance_data_length | |
# load caption data, and idx-to-ins mapping | |
with open( | |
'/cpfs01/shared/V2V/V2V_hdd/yslan/aigc3d/text_captions_cap3d.json' | |
) as f: | |
self.caption_data = json.load(f) | |
with open(os.path.join(lmdb_path, 'idx_to_ins_mapping.json')) as f: | |
self.idx_to_ins_mapping = json.load(f) | |
# def _load_data(self, idx): | |
# # ''' | |
# raw_img, depth, c, bbox = self._load_lmdb_data(idx) | |
# # resize depth and bbox | |
# caption = self.caption_data[self.idx_to_ins_mapping[str(idx)]] | |
# # st() | |
# return { | |
# **self._post_process_sample(raw_img, depth), 'c': c, | |
# 'bbox': (bbox*(self.reso/512.0)).astype(np.uint8), | |
# 'caption': caption | |
# } | |
# # ''' | |
# # raw_img, depth, c, bbox = self._load_lmdb_data_no_decompress(idx) | |
# # st() | |
# # return {} | |
def load_bbox(self, mask): | |
# st() | |
nonzero_value = torch.nonzero(mask) | |
height, width = nonzero_value.max(dim=0)[0] | |
top, left = nonzero_value.min(dim=0)[0] | |
bbox = torch.tensor([top, left, height, width], dtype=torch.float32) | |
return bbox | |
def __getitem__(self, idx): | |
raw_img, depth, c, bbox = self._load_lmdb_data(idx) | |
return {'raw_img': raw_img, 'depth': depth, 'c': c, 'bbox': bbox} | |
class Objv_LMDBDataset_NV_Compressed(Objv_LMDBDataset_MV_Compressed): | |
def __init__(self, | |
lmdb_path, | |
reso, | |
reso_encoder, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
**kwargs): | |
super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, | |
dataset_size, **kwargs) | |
def __getitem__(self, idx): | |
input_view = self._load_data(idx) # get previous input view results | |
# get novel view of the same instance | |
try: | |
novel_view = self._load_data( | |
(idx // self.instance_data_length) * | |
self.instance_data_length + | |
random.randint(0, self.instance_data_length - 1)) | |
except Exception as e: | |
raise NotImplementedError(idx) | |
# assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance' | |
input_view.update({f'nv_{k}': v for k, v in novel_view.items()}) | |
return input_view | |
# test tar loading
def load_wds_ResampledShard(file_path, | |
batch_size, | |
num_workers, | |
reso, | |
reso_encoder, | |
test=False, | |
preprocess=None, | |
imgnet_normalize=True, | |
plucker_embedding=False, | |
decode_encode_img_only=False, | |
load_instance=False, | |
mv_input=False, | |
split_chunk_input=False, | |
duplicate_sample=True, | |
append_depth=False, | |
gs_cam_format=False, | |
orthog_duplicate=False, | |
**kwargs): | |
# return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd'] | |
class PostProcess: | |
def __init__( | |
self, | |
reso, | |
reso_encoder, | |
imgnet_normalize, | |
plucker_embedding, | |
decode_encode_img_only, | |
mv_input, | |
split_chunk_input, | |
duplicate_sample, | |
append_depth, | |
gs_cam_format, | |
orthog_duplicate, | |
) -> None: | |
self.gs_cam_format = gs_cam_format | |
self.append_depth = append_depth | |
self.plucker_embedding = plucker_embedding | |
self.decode_encode_img_only = decode_encode_img_only | |
self.duplicate_sample = duplicate_sample | |
self.orthog_duplicate = orthog_duplicate | |
self.zfar = 100.0 | |
self.znear = 0.01 | |
transformations = [] | |
if not split_chunk_input: | |
transformations.append(transforms.ToTensor()) | |
if imgnet_normalize: | |
transformations.append( | |
transforms.Normalize((0.485, 0.456, 0.406), | |
(0.229, 0.224, 0.225)) # type: ignore | |
) | |
else: | |
transformations.append( | |
transforms.Normalize((0.5, 0.5, 0.5), | |
(0.5, 0.5, 0.5))) # type: ignore | |
self.normalize = transforms.Compose(transformations) | |
self.reso_encoder = reso_encoder | |
self.reso = reso | |
self.instance_data_length = 40 | |
# self.pair_per_instance = 1 # compat | |
self.mv_input = mv_input | |
self.split_chunk_input = split_chunk_input # 8 | |
self.chunk_size = 8 if split_chunk_input else 40 | |
# st() | |
if split_chunk_input: | |
self.pair_per_instance = 1 | |
else: | |
self.pair_per_instance = 4 if mv_input else 2 # check whether improves IO | |
def gen_rays(self, c): | |
# Generate rays | |
intrinsics, c2w = c[16:], c[:16].reshape(4, 4) | |
self.h = self.reso_encoder | |
self.w = self.reso_encoder | |
yy, xx = torch.meshgrid( | |
torch.arange(self.h, dtype=torch.float32) + 0.5, | |
torch.arange(self.w, dtype=torch.float32) + 0.5, | |
indexing='ij') | |
# normalize to 0-1 pixel range | |
yy = yy / self.h | |
xx = xx / self.w | |
# K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3) | |
cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[ | |
0], intrinsics[4] | |
# cx *= self.w | |
# cy *= self.h | |
# f_x = f_y = fx * h / res_raw | |
c2w = torch.from_numpy(c2w).float() | |
xx = (xx - cx) / fx | |
yy = (yy - cy) / fy | |
zz = torch.ones_like(xx) | |
dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention | |
dirs /= torch.norm(dirs, dim=-1, keepdim=True) | |
dirs = dirs.reshape(-1, 3, 1) | |
del xx, yy, zz | |
# st() | |
dirs = (c2w[None, :3, :3] @ dirs)[..., 0] | |
origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous() | |
origins = origins.view(self.h, self.w, 3) | |
dirs = dirs.view(self.h, self.w, 3) | |
return origins, dirs | |
def _post_process_batch_sample( | |
self, sample): # sample is an instance batch here | |
caption, ins = sample[-2:] | |
instance_samples = [] | |
for instance_idx in range(sample[0].shape[0]): | |
instance_samples.append( | |
self._post_process_sample(item[instance_idx] | |
for item in sample[:-2])) | |
return (*instance_samples, caption, ins) | |
def _post_process_sample(self, data_sample): | |
# raw_img, depth, c, bbox, caption, ins = data_sample | |
raw_img, depth, c, bbox = data_sample | |
bbox = (bbox * (self.reso / 256)).astype( | |
np.uint8) # normalize bbox to the reso range | |
if raw_img.shape[-2] != self.reso_encoder: | |
img_to_encoder = cv2.resize( | |
raw_img, (self.reso_encoder, self.reso_encoder), | |
interpolation=cv2.INTER_LANCZOS4) | |
else: | |
img_to_encoder = raw_img | |
img_to_encoder = self.normalize(img_to_encoder) | |
if self.plucker_embedding: | |
rays_o, rays_d = self.gen_rays(c) | |
rays_plucker = torch.cat( | |
[torch.cross(rays_o, rays_d, dim=-1), rays_d], | |
dim=-1).permute(2, 0, 1) # [h, w, 6] -> 6,h,w | |
img_to_encoder = torch.cat([img_to_encoder, rays_plucker], 0) | |
img = cv2.resize(raw_img, (self.reso, self.reso), | |
interpolation=cv2.INTER_LANCZOS4) | |
img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1 | |
if self.decode_encode_img_only: | |
depth_reso, fg_mask_reso = depth, depth | |
else: | |
depth_reso, fg_mask_reso = resize_depth_mask(depth, self.reso) | |
# return { | |
# # **sample, | |
# 'img_to_encoder': img_to_encoder, | |
# 'img': img, | |
# 'depth_mask': fg_mask_reso, | |
# # 'img_sr': img_sr, | |
# 'depth': depth_reso, | |
# 'c': c, | |
# 'bbox': bbox, | |
# 'caption': caption, | |
# 'ins': ins | |
# # ! no need to load img_sr for now | |
# } | |
# if len(data_sample) == 4: | |
return (img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox) | |
# else: | |
# return (img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox, data_sample[-2], data_sample[-1]) | |
def _post_process_sample_batch(self, data_sample): | |
# raw_img, depth, c, bbox, caption, ins = data_sample | |
raw_img, depth, c, bbox = data_sample | |
bbox = (bbox * (self.reso / 256)).astype( | |
np.uint8) # normalize bbox to the reso range | |
assert raw_img.shape[-2] == self.reso_encoder | |
# img_to_encoder = cv2.resize( | |
# raw_img, (self.reso_encoder, self.reso_encoder), | |
# interpolation=cv2.INTER_LANCZOS4) | |
# else: | |
# img_to_encoder = raw_img | |
raw_img = torch.from_numpy(raw_img).permute(0, 3, 1, | |
2) / 255.0 # [0,1] | |
img_to_encoder = self.normalize(raw_img) | |
if self.plucker_embedding: | |
rays_plucker = [] | |
for idx in range(c.shape[0]): | |
rays_o, rays_d = self.gen_rays(c[idx]) | |
rays_plucker.append( | |
torch.cat( | |
[torch.cross(rays_o, rays_d, dim=-1), rays_d], | |
dim=-1).permute(2, 0, 1)) # [h, w, 6] -> 6,h,w | |
rays_plucker = torch.stack(rays_plucker, 0) | |
img_to_encoder = torch.cat([img_to_encoder, rays_plucker], | |
1) # concat in C dim | |
if self.append_depth: | |
normalized_depth = torch.from_numpy(depth).clone().unsqueeze( | |
1) # min=0 | |
# normalized_depth -= torch.min(normalized_depth) # always 0 here | |
# normalized_depth /= torch.max(normalized_depth) | |
# normalized_depth = normalized_depth.unsqueeze(1) * 2 - 1 # normalize to [-1,1] | |
img_to_encoder = torch.cat([img_to_encoder, normalized_depth], | |
1) # concat in C dim | |
# img = cv2.resize(raw_img, (self.reso, self.reso), | |
# interpolation=cv2.INTER_LANCZOS4) | |
# img = torch.from_numpy(raw_img).permute(2, 0, 1) / 127.5 - 1 | |
# st() | |
if raw_img.shape[-1] != self.reso: | |
img = torch.nn.functional.interpolate( | |
input=raw_img, | |
size=(self.reso, self.reso), | |
mode='bilinear', | |
align_corners=False, | |
) * 2 - 1 # [-1,1] range | |
else: | |
img = raw_img * 2 - 1 | |
if self.decode_encode_img_only: | |
depth_reso, fg_mask_reso = depth, depth | |
else: | |
depth_reso, fg_mask_reso = resize_depth_mask_Tensor( | |
torch.from_numpy(depth), self.reso) | |
# if not self.gs_cam_format: # otherwise still playing with np format later | |
c = torch.from_numpy(c) | |
return (img_to_encoder, img, fg_mask_reso, depth_reso, c, | |
torch.from_numpy(bbox)) | |
def rand_sample_idx(self): | |
return random.randint(0, self.instance_data_length - 1) | |
def rand_pair(self): | |
return (self.rand_sample_idx() for _ in range(2)) | |
def paired_post_process(self, sample): | |
# repeat n times? | |
all_inp_list = [] | |
all_nv_list = [] | |
caption, ins = sample[-2:] | |
# expanded_return = [] | |
for _ in range(self.pair_per_instance): | |
cano_idx, nv_idx = self.rand_pair() | |
cano_sample = self._post_process_sample( | |
item[cano_idx] for item in sample[:-2]) | |
nv_sample = self._post_process_sample(item[nv_idx] | |
for item in sample[:-2]) | |
all_inp_list.extend(cano_sample) | |
all_nv_list.extend(nv_sample) | |
return (*all_inp_list, *all_nv_list, caption, ins) | |
# return [cano_sample, nv_sample, caption, ins] | |
# return (*cano_sample, *nv_sample, caption, ins) | |
def get_source_cw2wT(self, source_cameras_view_to_world): | |
return matrix_to_quaternion( | |
source_cameras_view_to_world[:3, :3].transpose(0, 1)) | |
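# Quaternion of the transposed (view -> world) rotation; presumably consumed by the | |
# Gaussian-splatting renderer to rotate per-pixel predictions into world space. | |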
def c_to_3dgs_format(self, pose): | |
# TODO, switch to torch version (batched later) | |
c2w = pose[:16].reshape(4, 4)  # 4x4 camera-to-world matrix | |
# ! load cam | |
w2c = np.linalg.inv(c2w) | |
R = np.transpose( | |
w2c[:3, :3] | |
) # R is stored transposed due to 'glm' in CUDA code | |
T = w2c[:3, 3] | |
fx = pose[16] | |
FovX = focal2fov(fx, 1) | |
FovY = focal2fov(fx, 1) | |
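# fx is assumed to be a focal length already normalized by the image size, so | |
# focal2fov is evaluated with pixels=1 for both axes. | |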
tanfovx = math.tan(FovX * 0.5) | |
tanfovy = math.tan(FovY * 0.5) | |
assert tanfovx == tanfovy | |
trans = np.array([0.0, 0.0, 0.0]) | |
scale = 1.0 | |
view_world_transform = torch.tensor( | |
getView2World(R, T, trans, scale)).transpose(0, 1) | |
world_view_transform = torch.tensor( | |
getWorld2View2(R, T, trans, scale)).transpose(0, 1) | |
projection_matrix = getProjectionMatrix(znear=self.znear, | |
zfar=self.zfar, | |
fovX=FovX, | |
fovY=FovY).transpose(0, 1) | |
full_proj_transform = (world_view_transform.unsqueeze(0).bmm( | |
projection_matrix.unsqueeze(0))).squeeze(0) | |
camera_center = world_view_transform.inverse()[3, :3] | |
# item.update(viewpoint_cam=[viewpoint_cam]) | |
c = {} | |
# | |
c["source_cv2wT_quat"] = self.get_source_cw2wT( | |
view_world_transform) | |
c.update( | |
# projection_matrix=projection_matrix, # K | |
cam_view=world_view_transform, # world_view_transform | |
cam_view_proj=full_proj_transform, # full_proj_transform | |
cam_pos=camera_center, | |
tanfov=tanfovx, # TODO, fix in the renderer | |
orig_pose=torch.from_numpy(pose), | |
orig_c2w=torch.from_numpy(c2w), | |
orig_w2c=torch.from_numpy(w2c), | |
# tanfovy=tanfovy, | |
) | |
return c # dict for gs rendering | |
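# Hedged note: the matrices above are stored row-major (transposed), following the 'glm' | |
# convention of the 3DGS CUDA rasterizer, so a homogeneous world point projects roughly as | |
# p_hom = p_world @ c['cam_view_proj'], and c['cam_pos'] equals c['cam_view'].inverse()[3, :3]. | |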
def paired_post_process_chunk(self, sample): | |
# repeat n times? | |
all_inp_list = [] | |
all_nv_list = [] | |
caption, ins = sample[-2:] | |
assert sample[0].shape[0] == 8 # random chunks | |
# expanded_return = [] | |
if self.duplicate_sample: | |
processed_sample = self._post_process_sample_batch( | |
item for item in sample[:-2]) | |
if self.orthog_duplicate: | |
indices = torch.cat([torch.randperm(8), | |
torch.randperm(8)]) # for now | |
else: | |
indices = torch.randperm(8) | |
shuffle_processed_sample = [] | |
for _, item in enumerate(processed_sample): | |
shuffle_processed_sample.append( | |
torch.index_select(item, dim=0, index=indices)) | |
processed_sample = shuffle_processed_sample | |
if not self.orthog_duplicate: | |
all_inp_list.extend(item[:4] for item in processed_sample) | |
all_nv_list.extend(item[4:] for item in processed_sample) | |
else: | |
all_inp_list.extend(item[:8] for item in processed_sample) | |
all_nv_list.extend(item[8:] for item in processed_sample) | |
return (*all_inp_list, *all_nv_list, caption, ins) | |
else: | |
processed_sample = self._post_process_sample_batch(  # skip shuffling to shorten processing time | |
item[:4] for item in sample[:-2]) | |
all_inp_list.extend(item for item in processed_sample) | |
all_nv_list.extend( | |
item for item in processed_sample) # ! placeholder | |
return (*all_inp_list, *all_nv_list, caption, ins) | |
# randomly shuffle the 8 views to avoid overfitting | |
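# With duplicate_sample, the shuffled views are split evenly into input and novel-view halves | |
# (4/4, or 8/8 with orthog_duplicate); otherwise the first 4 views are reused as a placeholder | |
# for both halves. | |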
def single_sample_create_dict(self, sample, prefix=''): | |
# if len(sample) == 1: | |
# sample = sample[0] | |
# assert len(sample) == 6 | |
img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample | |
if self.gs_cam_format: | |
# TODO, can optimize later after model converges | |
B, V, _ = c.shape # B 4 25 | |
c = rearrange(c, 'B V C -> (B V) C').cpu().numpy() | |
all_gs_c = [self.c_to_3dgs_format(pose) for pose in c] | |
c = { | |
k: | |
rearrange(torch.stack([gs_c[k] for gs_c in all_gs_c]), | |
'(B V) ... -> B V ...', | |
B=B, | |
V=V) if isinstance(all_gs_c[0][k], torch.Tensor) | |
else all_gs_c[0][k] | |
for k in all_gs_c[0].keys() | |
} | |
# c = collate_gs_c | |
return { | |
# **sample, | |
f'{prefix}img_to_encoder': img_to_encoder, | |
f'{prefix}img': img, | |
f'{prefix}depth_mask': fg_mask_reso, | |
f'{prefix}depth': depth_reso, | |
f'{prefix}c': c, | |
f'{prefix}bbox': bbox, | |
} | |
def single_instance_sample_create_dict(self, sample, prefix=''): | |
assert len(sample) == 42 | |
inp_sample_list = [[] for _ in range(6)] | |
for item in sample[:40]: | |
for item_idx in range(6): | |
inp_sample_list[item_idx].append(item[0][item_idx]) | |
inp_sample = self.single_sample_create_dict( | |
(torch.stack(item_list) for item_list in inp_sample_list), | |
prefix='') | |
return { | |
**inp_sample, # | |
'caption': sample[-2], | |
'ins': sample[-1] | |
} | |
def decode_zip(self, sample_pyd, shape=(256, 256)): | |
if isinstance(sample_pyd, tuple): | |
sample_pyd = sample_pyd[0] | |
assert isinstance(sample_pyd, dict) | |
raw_img = decompress_and_open_image_gzip( | |
sample_pyd['raw_img'], | |
is_img=True, | |
decompress=True, | |
decompress_fn=lz4.frame.decompress) | |
caption = sample_pyd['caption'].decode('utf-8') | |
ins = sample_pyd['ins'].decode('utf-8') | |
c = decompress_array(sample_pyd['c'], ( | |
self.chunk_size, | |
25, | |
), | |
np.float32, | |
decompress=True, | |
decompress_fn=lz4.frame.decompress) | |
bbox = decompress_array( | |
sample_pyd['bbox'], | |
( | |
self.chunk_size, | |
4, | |
), | |
np.float32, | |
# decompress=False) | |
decompress=True, | |
decompress_fn=lz4.frame.decompress) | |
if self.decode_encode_img_only: | |
depth = np.zeros(shape=(self.chunk_size, | |
*shape)) # save loading time | |
else: | |
depth = decompress_array(sample_pyd['depth'], | |
(self.chunk_size, *shape), | |
np.float32, | |
decompress=True, | |
decompress_fn=lz4.frame.decompress) | |
# return {'raw_img': raw_img, 'depth': depth, 'bbox': bbox, 'caption': caption, 'ins': ins, 'c': c} | |
# return raw_img, depth, c, bbox, caption, ins | |
# return raw_img, bbox, caption, ins | |
# return bbox, caption, ins | |
return raw_img, depth, c, bbox, caption, ins | |
# ! run single-instance pipeline first | |
# return raw_img[0], depth[0], c[0], bbox[0], caption, ins | |
def create_dict(self, sample): | |
# sample = [item[0] for item in sample] # wds wrap items in [] | |
# st() | |
cano_sample_list = [[] for _ in range(6)] | |
nv_sample_list = [[] for _ in range(6)] | |
# st() | |
# bs = (len(sample)-2) // 6 | |
for idx in range(0, self.pair_per_instance): | |
cano_sample = sample[6 * idx:6 * (idx + 1)] | |
nv_sample = sample[6 * self.pair_per_instance + | |
6 * idx:6 * self.pair_per_instance + 6 * | |
(idx + 1)] | |
for item_idx in range(6): | |
cano_sample_list[item_idx].append(cano_sample[item_idx]) | |
nv_sample_list[item_idx].append(nv_sample[item_idx]) | |
# ! cycle input/output view for more pairs | |
cano_sample_list[item_idx].append(nv_sample[item_idx]) | |
nv_sample_list[item_idx].append(cano_sample[item_idx]) | |
# if self.split_chunk_input: | |
# cano_sample = self.single_sample_create_dict( | |
# (torch.cat(item_list, 0) for item_list in cano_sample_list), | |
# prefix='') | |
# nv_sample = self.single_sample_create_dict( | |
# (torch.cat(item_list, 0) for item_list in nv_sample_list), | |
# prefix='nv_') | |
# else: | |
cano_sample = self.single_sample_create_dict( | |
(torch.cat(item_list, 0) for item_list in cano_sample_list), | |
prefix='') | |
nv_sample = self.single_sample_create_dict( | |
(torch.cat(item_list, 0) for item_list in nv_sample_list), | |
prefix='nv_') | |
return { | |
**cano_sample, | |
**nv_sample, 'caption': sample[-2], | |
'ins': sample[-1] | |
} | |
def prepare_mv_input(self, sample): | |
# sample = [item[0] for item in sample] # wds wrap items in [] | |
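# Reorders the collated batch for multi-view training: views are folded into the batch | |
# dimension (split_chunk_input) or instances are interleaved, and 'nv_' entries may | |
# additionally be rolled within each instance chunk so that input and novel-view targets | |
# come from different views of the same instance. | |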
bs = len(sample['caption']) # number of instances | |
chunk_size = sample['img'].shape[0] // bs | |
if self.split_chunk_input: | |
for k, v in sample.items(): | |
if isinstance(v, torch.Tensor): | |
sample[k] = rearrange(v, | |
"b f c ... -> (b f) c ...", | |
f=4 if not self.orthog_duplicate | |
else 8).contiguous() | |
# img = rearrange(sample['img'], "(b f) c h w -> b f c h w", f=4).contiguous() | |
# gt = rearrange(sample['nv_img'], "(b f) c h w -> b c (f h) w", f=4).contiguous() | |
# img = rearrange(sample['img'], "b f c h w -> b c (f h) w", f=4).contiguous() | |
# gt = rearrange(sample['nv_img'], "b f c h w -> b c (f h) w", f=4).contiguous() | |
# torchvision.utils.save_image(img, 'inp.jpg', normalize=True) | |
# torchvision.utils.save_image(gt, 'nv.jpg', normalize=True) | |
# ! shift nv | |
else: | |
for k, v in sample.items(): | |
if k not in ['ins', 'caption']: | |
rolled_idx = torch.LongTensor( | |
list( | |
itertools.chain.from_iterable( | |
list(range(i, sample['img'].shape[0], bs)) | |
for i in range(bs)))) | |
v = torch.index_select(v, dim=0, index=rolled_idx) | |
sample[k] = v | |
# img = sample['img'] | |
# gt = sample['nv_img'] | |
# torchvision.utils.save_image(img[0], 'inp.jpg', normalize=True) | |
# torchvision.utils.save_image(gt[0], 'nv.jpg', normalize=True) | |
for k, v in sample.items(): | |
if 'nv' in k: | |
rolled_idx = torch.LongTensor( | |
list( | |
itertools.chain.from_iterable( | |
list( | |
np.roll( | |
np.arange(i * chunk_size, (i + 1) * | |
chunk_size), 4) | |
for i in range(bs))))) | |
v = torch.index_select(v, dim=0, index=rolled_idx) | |
sample[k] = v | |
# torchvision.utils.save_image(sample['nv_img'], 'nv.png', normalize=True) | |
# torchvision.utils.save_image(sample['img'], 'inp.png', normalize=True) | |
return sample | |
post_process_cls = PostProcess( | |
reso, | |
reso_encoder, | |
imgnet_normalize=imgnet_normalize, | |
plucker_embedding=plucker_embedding, | |
decode_encode_img_only=decode_encode_img_only, | |
mv_input=mv_input, | |
split_chunk_input=split_chunk_input, | |
duplicate_sample=duplicate_sample, | |
append_depth=append_depth, | |
gs_cam_format=gs_cam_format, | |
orthog_duplicate=orthog_duplicate, | |
) | |
# ! add shuffling | |
if isinstance(file_path, list): # lst of shard urls | |
all_shards = [] | |
for url_path in file_path: | |
all_shards.extend(wds.shardlists.expand_source(url_path)) | |
logger.log('all_shards', all_shards) | |
else: | |
all_shards = file_path # to be expanded | |
if not load_instance: # during reconstruction training, load pair | |
if not split_chunk_input: | |
dataset = wds.DataPipeline( | |
wds.ResampledShards(all_shards), # url_shard | |
# at this point we have an iterator over all the shards | |
wds.shuffle(50), | |
wds.split_by_worker, # split shards across dataloader workers | |
wds.tarfile_to_samples(), | |
# add wds.split_by_node here if you are using multiple nodes | |
wds.shuffle( | |
1000 | |
), # shuffle in memory; a large buffer leverages RAM for more efficient loading | |
wds.decode(wds.autodecode.basichandlers), # TODO | |
wds.to_tuple( | |
"sample.pyd"), # extract the pyd from top level dict | |
wds.map(post_process_cls.decode_zip), | |
wds.map(post_process_cls.paired_post_process | |
), # create input-novelview paired samples | |
# wds.map(post_process_cls._post_process_sample), | |
# wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading | |
wds.batched( | |
16, | |
partial=True, | |
# collation_fn=collate | |
) # stream more data at once and rebatch later | |
) | |
else: | |
dataset = wds.DataPipeline( | |
wds.ResampledShards(all_shards), # url_shard | |
# at this point we have an iterator over all the shards | |
wds.shuffle(100), | |
wds.split_by_worker, # split shards across dataloader workers | |
wds.tarfile_to_samples(), | |
# add wds.split_by_node here if you are using multiple nodes | |
wds.shuffle( | |
# 7500 if not duplicate_sample else 2500 | |
# 7500 if not duplicate_sample else 5000 | |
# 1000, | |
250, | |
), # shuffle in memory; a large buffer leverages RAM for more efficient loading | |
wds.decode(wds.autodecode.basichandlers), # TODO | |
wds.to_tuple( | |
"sample.pyd"), # extract the pyd from top level dict | |
wds.map(post_process_cls.decode_zip), | |
wds.map(post_process_cls.paired_post_process_chunk | |
), # create input-novelview paired samples | |
# wds.map(post_process_cls._post_process_sample), | |
# wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading | |
wds.batched( | |
20, | |
partial=True, | |
# collation_fn=collate | |
) # stream more data at once and rebatch later | |
) | |
loader_shard = wds.WebLoader( | |
dataset, | |
num_workers=num_workers, | |
drop_last=False, | |
batch_size=None, | |
shuffle=False, | |
persistent_workers=num_workers | |
# > 0).unbatched().shuffle(1000).batched(batch_size).map( | |
> 0).unbatched().shuffle(250).batched(batch_size).map( | |
post_process_cls.create_dict) | |
if mv_input: | |
loader_shard = loader_shard.map(post_process_cls.prepare_mv_input) | |
else: # load single instance during test/eval | |
assert batch_size == 1 | |
dataset = wds.DataPipeline( | |
wds.ResampledShards(all_shards), # url_shard | |
# at this point we have an iterator over all the shards | |
wds.shuffle(50), | |
wds.split_by_worker, # split shards across dataloader workers | |
wds.tarfile_to_samples(), | |
# add wds.split_by_node here if you are using multiple nodes | |
wds.detshuffle( | |
100 | |
), # shuffle in memory; a large buffer leverages RAM for more efficient loading | |
wds.decode(wds.autodecode.basichandlers), # TODO | |
wds.to_tuple("sample.pyd"), # extract the pyd from top level dict | |
wds.map(post_process_cls.decode_zip), | |
# wds.map(post_process_cls.paired_post_process), # create input-novelview paired samples | |
wds.map(post_process_cls._post_process_batch_sample), | |
# wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading | |
wds.batched( | |
2, | |
partial=True, | |
# collation_fn=collate | |
) # stream more data at once and rebatch later | |
) | |
loader_shard = wds.WebLoader( | |
dataset, | |
num_workers=num_workers, | |
drop_last=False, | |
batch_size=None, | |
shuffle=False, | |
persistent_workers=num_workers | |
> 0).unbatched().shuffle(200).batched(batch_size).map( | |
post_process_cls.single_instance_sample_create_dict) | |
# persistent_workers=num_workers > 0).unbatched().batched(batch_size).map(post_process_cls.create_dict) | |
# 1000).batched(batch_size).map(post_process_cls.create_dict) | |
# .map(collate) | |
# .map(collate) | |
# .batched(batch_size) | |
# | |
# .unbatched().shuffle(1000).batched(batch_size).map(post_process) | |
# # https://github.com/webdataset/webdataset/issues/187 | |
# return next(iter(loader_shard)) | |
#return dataset | |
return loader_shard | |
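# Hedged usage sketch (shard pattern and arguments are illustrative, not from the original code): | |
# loader = load_wds_ResampledShard('objv-{000000..000009}.tar', batch_size=4, num_workers=1, | |
#                                  reso=128, reso_encoder=224, plucker_embedding=True, | |
#                                  mv_input=True, split_chunk_input=True) | |
# batch = next(iter(loader))  # dict with 'img_to_encoder', 'img', 'nv_img', 'c', 'caption', ... | |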
# test tar loading | |
def load_wds_diff_ResampledShard(file_path, | |
batch_size, | |
num_workers, | |
reso, | |
reso_encoder, | |
test=False, | |
preprocess=None, | |
imgnet_normalize=True, | |
plucker_embedding=False, | |
decode_encode_img_only=False, | |
mv_latent_dir='', | |
**kwargs): | |
# return raw_img, depth, c, bbox, sample_pyd['ins.pyd'], sample_pyd['fname.pyd'] | |
class PostProcess: | |
def __init__( | |
self, | |
reso, | |
reso_encoder, | |
imgnet_normalize, | |
plucker_embedding, | |
decode_encode_img_only, | |
mv_latent_dir, | |
) -> None: | |
self.plucker_embedding = plucker_embedding | |
self.mv_latent_dir = mv_latent_dir | |
self.decode_encode_img_only = decode_encode_img_only | |
transformations = [ | |
transforms.ToTensor(), # [0,1] range | |
] | |
if imgnet_normalize: | |
transformations.append( | |
transforms.Normalize((0.485, 0.456, 0.406), | |
(0.229, 0.224, 0.225)) # type: ignore | |
) | |
else: | |
transformations.append( | |
transforms.Normalize((0.5, 0.5, 0.5), | |
(0.5, 0.5, 0.5))) # type: ignore | |
self.normalize = transforms.Compose(transformations) | |
self.reso_encoder = reso_encoder | |
self.reso = reso | |
self.instance_data_length = 40 | |
# self.pair_per_instance = 1 # compat | |
self.pair_per_instance = 2 # check whether improves IO | |
# self.pair_per_instance = 3 # check whether improves IO | |
# self.pair_per_instance = 4 # check whether improves IO | |
def get_rays_kiui(self, c, opengl=True): | |
h, w = self.reso_encoder, self.reso_encoder | |
intrinsics, pose = c[16:], c[:16].reshape(4, 4) | |
# cx, cy, fx, fy = intrinsics[2], intrinsics[5] | |
fx = fy = 525 # pixel space | |
cx = cy = 256 # rendering default K | |
factor = self.reso / (cx * 2) # 128 / 512 | |
fx = fx * factor | |
fy = fy * factor | |
x, y = torch.meshgrid( | |
torch.arange(w, device=pose.device), | |
torch.arange(h, device=pose.device), | |
indexing="xy", | |
) | |
x = x.flatten() | |
y = y.flatten() | |
cx = w * 0.5 | |
cy = h * 0.5 | |
# focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy)) | |
camera_dirs = F.pad( | |
torch.stack( | |
[ | |
(x - cx + 0.5) / fx, | |
(y - cy + 0.5) / fy * (-1.0 if opengl else 1.0), | |
], | |
dim=-1, | |
), | |
(0, 1), | |
value=(-1.0 if opengl else 1.0), | |
) # [hw, 3] | |
rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3] | |
rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3] | |
rays_o = rays_o.view(h, w, 3) | |
rays_d = safe_normalize(rays_d).view(h, w, 3) | |
return rays_o, rays_d | |
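# Assumes c is already a torch.Tensor (pose.device is queried when building the pixel grid); | |
# the hard-coded fx=fy=525 / cx=cy=256 match a default 512x512 rendering K and are rescaled | |
# by self.reso / 512 before the principal point is reset to the image center. | |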
def gen_rays(self, c): | |
# Generate rays | |
intrinsics, c2w = c[16:], c[:16].reshape(4, 4) | |
self.h = self.reso_encoder | |
self.w = self.reso_encoder | |
yy, xx = torch.meshgrid( | |
torch.arange(self.h, dtype=torch.float32) + 0.5, | |
torch.arange(self.w, dtype=torch.float32) + 0.5, | |
indexing='ij') | |
# normalize to 0-1 pixel range | |
yy = yy / self.h | |
xx = xx / self.w | |
# K = np.array([f_x, 0, w / 2, 0, f_y, h / 2, 0, 0, 1]).reshape(3, 3) | |
cx, cy, fx, fy = intrinsics[2], intrinsics[5], intrinsics[ | |
0], intrinsics[4] | |
# cx *= self.w | |
# cy *= self.h | |
# f_x = f_y = fx * h / res_raw | |
c2w = torch.from_numpy(c2w).float() | |
xx = (xx - cx) / fx | |
yy = (yy - cy) / fy | |
zz = torch.ones_like(xx) | |
dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention | |
dirs /= torch.norm(dirs, dim=-1, keepdim=True) | |
dirs = dirs.reshape(-1, 3, 1) | |
del xx, yy, zz | |
# st() | |
dirs = (c2w[None, :3, :3] @ dirs)[..., 0] | |
origins = c2w[None, :3, 3].expand(self.h * self.w, -1).contiguous() | |
origins = origins.view(self.h, self.w, 3) | |
dirs = dirs.view(self.h, self.w, 3) | |
return origins, dirs | |
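# Returns world-space ray origins and unit directions, each [reso_encoder, reso_encoder, 3]; | |
# the intrinsics slice of c is assumed to hold cx, cy, fx, fy already normalized to the | |
# [0, 1] pixel range, matching the yy/xx normalization above. | |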
def _post_process_sample(self, data_sample): | |
# raw_img, depth, c, bbox, caption, ins = data_sample | |
raw_img, c, caption, ins = data_sample | |
# bbox = (bbox*(self.reso/256)).astype(np.uint8) # normalize bbox to the reso range | |
# if raw_img.shape[-2] != self.reso_encoder: | |
# img_to_encoder = cv2.resize( | |
# raw_img, (self.reso_encoder, self.reso_encoder), | |
# interpolation=cv2.INTER_LANCZOS4) | |
# else: | |
# img_to_encoder = raw_img | |
# img_to_encoder = self.normalize(img_to_encoder) | |
# if self.plucker_embedding: | |
# rays_o, rays_d = self.gen_rays(c) | |
# rays_plucker = torch.cat( | |
# [torch.cross(rays_o, rays_d, dim=-1), rays_d], | |
# dim=-1).permute(2, 0, 1) # [h, w, 6] -> 6,h,w | |
# img_to_encoder = torch.cat([img_to_encoder, rays_plucker], 0) | |
# img = cv2.resize(raw_img, (self.reso, self.reso), | |
# interpolation=cv2.INTER_LANCZOS4) | |
img = raw_img # 256x256 | |
img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1 | |
# load latent | |
latent_path = Path(self.mv_latent_dir, ins, 'latent.npy') | |
latent = np.load(latent_path) | |
# return (img_to_encoder, img, c, caption, ins) | |
return (latent, img, c, caption, ins) | |
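# The latent is assumed to be a pre-computed per-instance encoding cached at | |
# <mv_latent_dir>/<ins>/latent.npy, paired with a single rendered view for diffusion training. | |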
def rand_sample_idx(self): | |
return random.randint(0, self.instance_data_length - 1) | |
def rand_pair(self): | |
return (self.rand_sample_idx() for _ in range(2)) | |
def paired_post_process(self, sample): | |
# repeat n times? | |
all_inp_list = [] | |
all_nv_list = [] | |
caption, ins = sample[-2:] | |
# expanded_return = [] | |
for _ in range(self.pair_per_instance): | |
cano_idx, nv_idx = self.rand_pair() | |
cano_sample = self._post_process_sample( | |
item[cano_idx] for item in sample[:-2]) | |
nv_sample = self._post_process_sample(item[nv_idx] | |
for item in sample[:-2]) | |
all_inp_list.extend(cano_sample) | |
all_nv_list.extend(nv_sample) | |
return (*all_inp_list, *all_nv_list, caption, ins) | |
# return [cano_sample, nv_sample, caption, ins] | |
# return (*cano_sample, *nv_sample, caption, ins) | |
# def single_sample_create_dict(self, sample, prefix=''): | |
# # if len(sample) == 1: | |
# # sample = sample[0] | |
# # assert len(sample) == 6 | |
# img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample | |
# return { | |
# # **sample, | |
# f'{prefix}img_to_encoder': img_to_encoder, | |
# f'{prefix}img': img, | |
# f'{prefix}depth_mask': fg_mask_reso, | |
# f'{prefix}depth': depth_reso, | |
# f'{prefix}c': c, | |
# f'{prefix}bbox': bbox, | |
# } | |
def single_sample_create_dict(self, sample, prefix=''): | |
# if len(sample) == 1: | |
# sample = sample[0] | |
# assert len(sample) == 6 | |
# img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample | |
# img_to_encoder, img, c, caption, ins = sample | |
# img, c, caption, ins = sample | |
latent, img, c, caption, ins = sample | |
# load latent | |
return { | |
# **sample, | |
# 'img_to_encoder': img_to_encoder, | |
'latent': latent, | |
'img': img, | |
'c': c, | |
'caption': caption, | |
'ins': ins | |
} | |
def decode_zip(self, sample_pyd, shape=(256, 256)): | |
if isinstance(sample_pyd, tuple): | |
sample_pyd = sample_pyd[0] | |
assert isinstance(sample_pyd, dict) | |
raw_img = decompress_and_open_image_gzip( | |
sample_pyd['raw_img'], | |
is_img=True, | |
decompress=True, | |
decompress_fn=lz4.frame.decompress) | |
caption = sample_pyd['caption'].decode('utf-8') | |
ins = sample_pyd['ins'].decode('utf-8') | |
c = decompress_array(sample_pyd['c'], (25, ), | |
np.float32, | |
decompress=True, | |
decompress_fn=lz4.frame.decompress) | |
# bbox = decompress_array( | |
# sample_pyd['bbox'], | |
# ( | |
# 40, | |
# 4, | |
# ), | |
# np.float32, | |
# # decompress=False) | |
# decompress=True, | |
# decompress_fn=lz4.frame.decompress) | |
# if self.decode_encode_img_only: | |
# depth = np.zeros(shape=(40, *shape)) # save loading time | |
# else: | |
# depth = decompress_array(sample_pyd['depth'], (40, *shape), | |
# np.float32, | |
# decompress=True, | |
# decompress_fn=lz4.frame.decompress) | |
# return {'raw_img': raw_img, 'depth': depth, 'bbox': bbox, 'caption': caption, 'ins': ins, 'c': c} | |
# return raw_img, depth, c, bbox, caption, ins | |
# return raw_img, bbox, caption, ins | |
# return bbox, caption, ins | |
return raw_img, c, caption, ins | |
# ! run single-instance pipeline first | |
# return raw_img[0], depth[0], c[0], bbox[0], caption, ins | |
def create_dict(self, sample): | |
# sample = [item[0] for item in sample] # wds wrap items in [] | |
# cano_sample_list = [[] for _ in range(6)] | |
# nv_sample_list = [[] for _ in range(6)] | |
# for idx in range(0, self.pair_per_instance): | |
# cano_sample = sample[6*idx:6*(idx+1)] | |
# nv_sample = sample[6*self.pair_per_instance+6*idx:6*self.pair_per_instance+6*(idx+1)] | |
# for item_idx in range(6): | |
# cano_sample_list[item_idx].append(cano_sample[item_idx]) | |
# nv_sample_list[item_idx].append(nv_sample[item_idx]) | |
# # ! cycle input/output view for more pairs | |
# cano_sample_list[item_idx].append(nv_sample[item_idx]) | |
# nv_sample_list[item_idx].append(cano_sample[item_idx]) | |
cano_sample = self.single_sample_create_dict(sample, prefix='') | |
# nv_sample = self.single_sample_create_dict((torch.cat(item_list) for item_list in nv_sample_list) , prefix='nv_') | |
return cano_sample | |
# return { | |
# **cano_sample, | |
# # **nv_sample, | |
# 'caption': sample[-2], | |
# 'ins': sample[-1] | |
# } | |
post_process_cls = PostProcess( | |
reso, | |
reso_encoder, | |
imgnet_normalize=imgnet_normalize, | |
plucker_embedding=plucker_embedding, | |
decode_encode_img_only=decode_encode_img_only, | |
mv_latent_dir=mv_latent_dir, | |
) | |
if isinstance(file_path, list): # lst of shard urls | |
all_shards = [] | |
for url_path in file_path: | |
all_shards.extend(wds.shardlists.expand_source(url_path)) | |
logger.log('all_shards', all_shards) | |
else: | |
all_shards = file_path # to be expanded | |
dataset = wds.DataPipeline( | |
wds.ResampledShards(all_shards), # url_shard | |
# at this point we have an iterator over all the shards | |
wds.shuffle(50), | |
wds.split_by_worker, # split shards across dataloader workers | |
wds.tarfile_to_samples(), | |
# add wds.split_by_node here if you are using multiple nodes | |
wds.detshuffle( | |
15000 | |
), # shuffle in memory; a large buffer leverages RAM for more efficient loading | |
wds.decode(wds.autodecode.basichandlers), # TODO | |
wds.to_tuple("sample.pyd"), # extract the pyd from top level dict | |
wds.map(post_process_cls.decode_zip), | |
# wds.map(post_process_cls.paired_post_process), # create input-novelview paired samples | |
wds.map(post_process_cls._post_process_sample), | |
# wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading | |
wds.batched( | |
80, | |
partial=True, | |
# collation_fn=collate | |
) # stream more data at once and rebatch later | |
) | |
loader_shard = wds.WebLoader( | |
dataset, | |
num_workers=num_workers, | |
drop_last=False, | |
batch_size=None, | |
shuffle=False, | |
persistent_workers=num_workers | |
> 0).unbatched().shuffle(2500).batched(batch_size).map( | |
post_process_cls.create_dict) | |
#return dataset | |
return loader_shard | |
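# Hedged usage note: each batch yielded by this loader is a dict with 'latent', 'img', 'c', | |
# 'caption' and 'ins' (see single_sample_create_dict above); it is the diffusion-training | |
# counterpart of load_wds_ResampledShard, reading cached latents instead of encoder inputs. | |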
def load_wds_data( | |
file_path="", | |
reso=64, | |
reso_encoder=224, | |
batch_size=1, | |
num_workers=6, | |
plucker_embedding=False, | |
decode_encode_img_only=False, | |
load_wds_diff=False, | |
load_wds_latent=False, | |
load_instance=False, # for evaluation | |
mv_input=False, | |
split_chunk_input=False, | |
duplicate_sample=True, | |
mv_latent_dir='', | |
append_depth=False, | |
gs_cam_format=False, | |
orthog_duplicate=False, | |
**args): | |
if load_wds_diff: | |
assert num_workers == 0 # on aliyun, num_workers=0 performs much faster | |
wds_loader = load_wds_diff_ResampledShard( | |
file_path, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
reso=reso, | |
reso_encoder=reso_encoder, | |
plucker_embedding=plucker_embedding, | |
decode_encode_img_only=decode_encode_img_only, | |
mv_input=mv_input, | |
split_chunk_input=split_chunk_input, | |
append_depth=append_depth, | |
mv_latent_dir=mv_latent_dir, | |
gs_cam_format=gs_cam_format, | |
orthog_duplicate=orthog_duplicate, | |
) | |
elif load_wds_latent: | |
# for diffusion training, cache latent | |
wds_loader = load_wds_latent_ResampledShard( | |
file_path, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
reso=reso, | |
reso_encoder=reso_encoder, | |
plucker_embedding=plucker_embedding, | |
decode_encode_img_only=decode_encode_img_only, | |
mv_input=mv_input, | |
split_chunk_input=split_chunk_input, | |
) | |
# elif load_instance: | |
# wds_loader = load_wds_instance_ResampledShard( | |
# file_path, | |
# batch_size=batch_size, | |
# num_workers=num_workers, | |
# reso=reso, | |
# reso_encoder=reso_encoder, | |
# plucker_embedding=plucker_embedding, | |
# decode_encode_img_only=decode_encode_img_only | |
# ) | |
else: | |
wds_loader = load_wds_ResampledShard( | |
file_path, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
reso=reso, | |
reso_encoder=reso_encoder, | |
plucker_embedding=plucker_embedding, | |
decode_encode_img_only=decode_encode_img_only, | |
load_instance=load_instance, | |
mv_input=mv_input, | |
split_chunk_input=split_chunk_input, | |
duplicate_sample=duplicate_sample, | |
append_depth=append_depth, | |
gs_cam_format=gs_cam_format, | |
orthog_duplicate=orthog_duplicate, | |
) | |
while True: | |
yield from wds_loader | |
# yield from wds_loader | |
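# Hedged usage sketch (paths and sizes are illustrative): load_wds_data is a generator and | |
# yields batches indefinitely, e.g. | |
# data = load_wds_data('objv-{000000..000009}.tar', reso=128, reso_encoder=224, batch_size=4, | |
#                      num_workers=1, plucker_embedding=True, mv_input=True, | |
#                      split_chunk_input=True) | |
# batch = next(data) | |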
class PostProcess_forlatent: | |
def __init__( | |
self, | |
reso, | |
reso_encoder, | |
imgnet_normalize, | |
plucker_embedding, | |
decode_encode_img_only, | |
) -> None: | |
self.plucker_embedding = plucker_embedding | |
self.decode_encode_img_only = decode_encode_img_only | |
transformations = [ | |
transforms.ToTensor(), # [0,1] range | |
] | |
if imgnet_normalize: | |
transformations.append( | |
transforms.Normalize((0.485, 0.456, 0.406), | |
(0.229, 0.224, 0.225)) # type: ignore | |
) | |
else: | |
transformations.append( | |
transforms.Normalize((0.5, 0.5, 0.5), | |
(0.5, 0.5, 0.5))) # type: ignore | |
self.normalize = transforms.Compose(transformations) | |
self.reso_encoder = reso_encoder | |
self.reso = reso | |
self.instance_data_length = 40 | |
# self.pair_per_instance = 1 # compat | |
self.pair_per_instance = 2 # check whether improves IO | |
# self.pair_per_instance = 3 # check whether improves IO | |
# self.pair_per_instance = 4 # check whether improves IO | |
def _post_process_sample(self, data_sample): | |
# raw_img, depth, c, bbox, caption, ins = data_sample | |
raw_img, c, caption, ins = data_sample | |
# bbox = (bbox*(self.reso/256)).astype(np.uint8) # normalize bbox to the reso range | |
if raw_img.shape[-2] != self.reso_encoder: | |
img_to_encoder = cv2.resize( | |
raw_img, (self.reso_encoder, self.reso_encoder), | |
interpolation=cv2.INTER_LANCZOS4) | |
else: | |
img_to_encoder = raw_img | |
img_to_encoder = self.normalize(img_to_encoder) | |
if self.plucker_embedding: | |
rays_o, rays_d = self.gen_rays(c) | |
rays_plucker = torch.cat( | |
[torch.cross(rays_o, rays_d, dim=-1), rays_d], | |
dim=-1).permute(2, 0, 1) # [h, w, 6] -> 6,h,w | |
img_to_encoder = torch.cat([img_to_encoder, rays_plucker], 0) | |
img = cv2.resize(raw_img, (self.reso, self.reso), | |
interpolation=cv2.INTER_LANCZOS4) | |
img = torch.from_numpy(img).permute(2, 0, 1) / 127.5 - 1 | |
return (img_to_encoder, img, c, caption, ins) | |
def rand_sample_idx(self): | |
return random.randint(0, self.instance_data_length - 1) | |
def rand_pair(self): | |
return (self.rand_sample_idx() for _ in range(2)) | |
def paired_post_process(self, sample): | |
# repeat n times? | |
all_inp_list = [] | |
all_nv_list = [] | |
caption, ins = sample[-2:] | |
# expanded_return = [] | |
for _ in range(self.pair_per_instance): | |
cano_idx, nv_idx = self.rand_pair() | |
cano_sample = self._post_process_sample( | |
item[cano_idx] for item in sample[:-2]) | |
nv_sample = self._post_process_sample(item[nv_idx] | |
for item in sample[:-2]) | |
all_inp_list.extend(cano_sample) | |
all_nv_list.extend(nv_sample) | |
return (*all_inp_list, *all_nv_list, caption, ins) | |
# return [cano_sample, nv_sample, caption, ins] | |
# return (*cano_sample, *nv_sample, caption, ins) | |
# def single_sample_create_dict(self, sample, prefix=''): | |
# # if len(sample) == 1: | |
# # sample = sample[0] | |
# # assert len(sample) == 6 | |
# img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample | |
# return { | |
# # **sample, | |
# f'{prefix}img_to_encoder': img_to_encoder, | |
# f'{prefix}img': img, | |
# f'{prefix}depth_mask': fg_mask_reso, | |
# f'{prefix}depth': depth_reso, | |
# f'{prefix}c': c, | |
# f'{prefix}bbox': bbox, | |
# } | |
def single_sample_create_dict(self, sample, prefix=''): | |
# if len(sample) == 1: | |
# sample = sample[0] | |
# assert len(sample) == 6 | |
# img_to_encoder, img, fg_mask_reso, depth_reso, c, bbox = sample | |
img_to_encoder, img, c, caption, ins = sample | |
return { | |
# **sample, | |
'img_to_encoder': img_to_encoder, | |
'img': img, | |
'c': c, | |
'caption': caption, | |
'ins': ins | |
} | |
def decode_zip(self, sample_pyd, shape=(256, 256)): | |
if isinstance(sample_pyd, tuple): | |
sample_pyd = sample_pyd[0] | |
assert isinstance(sample_pyd, dict) | |
latent = sample_pyd['latent'] | |
caption = sample_pyd['caption'].decode('utf-8') | |
c = sample_pyd['c'] | |
# img = sample_pyd['img'] | |
# st() | |
return latent, caption, c | |
def create_dict(self, sample): | |
return { | |
# **sample, | |
'latent': sample[0], | |
'caption': sample[1], | |
'c': sample[2], | |
} | |
# test tar loading | |
def load_wds_latent_ResampledShard(file_path, | |
batch_size, | |
num_workers, | |
reso, | |
reso_encoder, | |
test=False, | |
preprocess=None, | |
imgnet_normalize=True, | |
plucker_embedding=False, | |
decode_encode_img_only=False, | |
**kwargs): | |
post_process_cls = PostProcess_forlatent( | |
reso, | |
reso_encoder, | |
imgnet_normalize=imgnet_normalize, | |
plucker_embedding=plucker_embedding, | |
decode_encode_img_only=decode_encode_img_only, | |
) | |
if isinstance(file_path, list): # lst of shard urls | |
all_shards = [] | |
for url_path in file_path: | |
all_shards.extend(wds.shardlists.expand_source(url_path)) | |
logger.log('all_shards', all_shards) | |
else: | |
all_shards = file_path # to be expanded | |
dataset = wds.DataPipeline( | |
wds.ResampledShards(all_shards), # url_shard | |
# at this point we have an iterator over all the shards | |
wds.shuffle(50), | |
wds.split_by_worker, # split shards across dataloader workers | |
wds.tarfile_to_samples(), | |
# add wds.split_by_node here if you are using multiple nodes | |
wds.detshuffle( | |
2500 | |
), # shuffle in memory; a large buffer leverages RAM for more efficient loading | |
wds.decode(wds.autodecode.basichandlers), # TODO | |
wds.to_tuple("sample.pyd"), # extract the pyd from top level dict | |
wds.map(post_process_cls.decode_zip), | |
# wds.map(post_process_cls._post_process_sample), | |
# wds.detshuffle(1000), # shuffles in the memory, leverage large RAM for more efficient loading | |
wds.batched( | |
150, | |
partial=True, | |
# collation_fn=collate | |
) # stream more data at once and rebatch later | |
) | |
loader_shard = wds.WebLoader( | |
dataset, | |
num_workers=num_workers, | |
drop_last=False, | |
batch_size=None, | |
shuffle=False, | |
persistent_workers=num_workers | |
> 0).unbatched().shuffle(1000).batched(batch_size).map( | |
post_process_cls.create_dict) | |
#return dataset | |
return loader_shard | |
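# Hedged, self-contained sanity check (new helper, not referenced elsewhere in this file): it | |
# verifies that the focal/FoV helpers round-trip and that a Plucker ray embedding built the | |
# same way as in the PostProcess classes above has 6 channels per pixel. | |
def _plucker_sanity_check(h=4, w=4, fov_deg=50.0): | |
    f = fov2focal(math.radians(fov_deg), w)  # FoV (rad) + image size -> focal in pixels | |
    assert abs(math.degrees(focal2fov(f, w)) - fov_deg) < 1e-5  # round trip | |
    yy, xx = torch.meshgrid( | |
        torch.arange(h, dtype=torch.float32) + 0.5, | |
        torch.arange(w, dtype=torch.float32) + 0.5, | |
        indexing='ij') | |
    # identity camera at the origin, OpenCV-style pinhole directions | |
    dirs = torch.stack(((xx - w / 2) / f, (yy - h / 2) / f, torch.ones_like(xx)), dim=-1) | |
    dirs = dirs / dirs.norm(dim=-1, keepdim=True) | |
    origins = torch.zeros_like(dirs) | |
    plucker = torch.cat([torch.cross(origins, dirs, dim=-1), dirs], dim=-1) | |
    assert plucker.shape == (h, w, 6) | |
    return plucker | |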