Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import torchvision | |
import pickle | |
from typing import Any | |
import lmdb | |
import cv2 | |
import imageio | |
import numpy as np | |
from PIL import Image | |
import Imath | |
import OpenEXR | |
from pdb import set_trace as st | |
from pathlib import Path | |
from functools import partial | |
import io | |
import gzip | |
import random | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader, Dataset | |
from torchvision import transforms | |
from torch.utils.data.distributed import DistributedSampler | |
from pathlib import Path | |
from guided_diffusion import logger | |
def load_dataset( | |
file_path="", | |
reso=64, | |
reso_encoder=224, | |
batch_size=1, | |
# shuffle=True, | |
num_workers=6, | |
load_depth=False, | |
preprocess=None, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
trainer_name='input_rec', | |
use_lmdb=False, | |
infi_sampler=True | |
): | |
# st() | |
# dataset_cls = { | |
# 'input_rec': MultiViewDataset, | |
# 'nv': NovelViewDataset, | |
# }[trainer_name] | |
# st() | |
if use_lmdb: | |
logger.log('using LMDB dataset') | |
# dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later. | |
if 'nv' in trainer_name: | |
dataset_cls = LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
else: | |
dataset_cls = LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
# dataset = dataset_cls(file_path) | |
else: | |
if 'nv' in trainer_name: | |
dataset_cls = NovelViewDataset # 1.5-2iter/s | |
else: | |
dataset_cls = MultiViewDataset | |
dataset = dataset_cls(file_path, | |
reso, | |
reso_encoder, | |
test=False, | |
preprocess=preprocess, | |
load_depth=load_depth, | |
imgnet_normalize=imgnet_normalize, | |
dataset_size=dataset_size) | |
logger.log('dataset_cls: {}, dataset size: {}'.format( | |
trainer_name, len(dataset))) | |
loader = DataLoader(dataset, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
drop_last=False, | |
pin_memory=True, | |
persistent_workers=num_workers > 0, | |
shuffle=False) | |
return loader | |
def load_data( | |
file_path="", | |
reso=64, | |
reso_encoder=224, | |
batch_size=1, | |
# shuffle=True, | |
num_workers=6, | |
load_depth=False, | |
preprocess=None, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
trainer_name='input_rec', | |
use_lmdb=False, | |
infi_sampler=True | |
): | |
# st() | |
# dataset_cls = { | |
# 'input_rec': MultiViewDataset, | |
# 'nv': NovelViewDataset, | |
# }[trainer_name] | |
# st() | |
if use_lmdb: | |
logger.log('using LMDB dataset') | |
# dataset_cls = LMDBDataset_MV # 2.5-3iter/s, but unstable, drops to 1 later. | |
if 'nv' in trainer_name: | |
dataset_cls = LMDBDataset_NV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
else: | |
dataset_cls = LMDBDataset_MV_Compressed # 2.5-3iter/s, but unstable, drops to 1 later. | |
# dataset = dataset_cls(file_path) | |
else: | |
if 'nv' in trainer_name: | |
dataset_cls = NovelViewDataset # 1.5-2iter/s | |
else: | |
dataset_cls = MultiViewDataset | |
dataset = dataset_cls(file_path, | |
reso, | |
reso_encoder, | |
test=False, | |
preprocess=preprocess, | |
load_depth=load_depth, | |
imgnet_normalize=imgnet_normalize, | |
dataset_size=dataset_size) | |
logger.log('dataset_cls: {}, dataset size: {}'.format( | |
trainer_name, len(dataset))) | |
# st() | |
if infi_sampler: | |
train_sampler = DistributedSampler(dataset=dataset, | |
shuffle=True, | |
drop_last=True) | |
loader = DataLoader(dataset, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
drop_last=True, | |
pin_memory=True, | |
persistent_workers=num_workers > 0, | |
sampler=train_sampler) | |
while True: | |
yield from loader | |
else: | |
# loader = DataLoader(dataset, | |
# batch_size=batch_size, | |
# num_workers=num_workers, | |
# drop_last=False, | |
# pin_memory=True, | |
# persistent_workers=num_workers > 0, | |
# shuffle=False) | |
st() | |
return dataset | |
def load_eval_rays(file_path="", | |
reso=64, | |
reso_encoder=224, | |
imgnet_normalize=True): | |
dataset = MultiViewDataset(file_path, | |
reso, | |
reso_encoder, | |
imgnet_normalize=imgnet_normalize) | |
pose_list = dataset.single_pose_list | |
ray_list = [] | |
for pose_fname in pose_list: | |
# c2w = dataset.get_c2w(pose_fname).reshape(1,4,4) #[1, 4, 4] | |
# rays_o, rays_d = dataset.gen_rays(c2w) | |
# ray_list.append( | |
# [rays_o.unsqueeze(0), | |
# rays_d.unsqueeze(0), | |
# c2w.reshape(-1, 16)]) | |
c2w = dataset.get_c2w(pose_fname).reshape(16) #[1, 4, 4] | |
c = torch.cat([c2w, dataset.intrinsics], | |
dim=0).reshape(25) # 25, no '1' dim needed. | |
ray_list.append(c) | |
return ray_list | |
def load_eval_data(file_path="", | |
reso=64, | |
reso_encoder=224, | |
batch_size=1, | |
num_workers=1, | |
load_depth=False, | |
preprocess=None, | |
imgnet_normalize=True, | |
interval=1, **kwargs): | |
dataset = MultiViewDataset(file_path, | |
reso, | |
reso_encoder, | |
preprocess=preprocess, | |
load_depth=load_depth, | |
test=True, | |
imgnet_normalize=imgnet_normalize, | |
interval=interval) | |
print('eval dataset size: {}'.format(len(dataset))) | |
# train_sampler = DistributedSampler(dataset=dataset) | |
loader = DataLoader( | |
dataset, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
drop_last=False, | |
shuffle=False, | |
) | |
# sampler=train_sampler) | |
return loader | |
def load_memory_data(file_path="", | |
reso=64, | |
reso_encoder=224, | |
batch_size=1, | |
num_workers=1, | |
load_depth=True, | |
preprocess=None, | |
imgnet_normalize=True): | |
# load a single-instance into the memory to speed up training IO | |
dataset = MultiViewDataset(file_path, | |
reso, | |
reso_encoder, | |
preprocess=preprocess, | |
load_depth=True, | |
test=False, | |
overfitting=True, | |
imgnet_normalize=imgnet_normalize, | |
overfitting_bs=batch_size) | |
logger.log('!!!!!!! memory dataset size: {} !!!!!!'.format(len(dataset))) | |
# train_sampler = DistributedSampler(dataset=dataset) | |
loader = DataLoader( | |
dataset, | |
batch_size=len(dataset), | |
num_workers=num_workers, | |
drop_last=False, | |
shuffle=False, | |
) | |
all_data: dict = next(iter(loader)) | |
while True: | |
start_idx = np.random.randint(0, len(dataset) - batch_size + 1) | |
yield { | |
k: v[start_idx:start_idx + batch_size] | |
for k, v in all_data.items() | |
} | |
class MultiViewDataset(Dataset): | |
def __init__(self, | |
file_path, | |
reso, | |
reso_encoder, | |
preprocess=None, | |
classes=False, | |
load_depth=False, | |
test=False, | |
scene_scale=1, | |
overfitting=False, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
overfitting_bs=-1, | |
interval=1): | |
self.file_path = file_path | |
self.overfitting = overfitting | |
self.scene_scale = scene_scale | |
self.reso = reso | |
self.reso_encoder = reso_encoder | |
self.classes = False | |
self.load_depth = load_depth | |
self.preprocess = preprocess | |
assert not self.classes, "Not support class condition now." | |
# self.ins_list = os.listdir(self.file_path) | |
# if test: # TODO | |
dataset_name = Path(self.file_path).stem.split('_')[0] | |
self.dataset_name = dataset_name | |
if test: | |
# ins_list_file = Path(self.file_path).parent / f'{dataset_name}_test_list.txt' # ? in domain | |
if dataset_name == 'chair': | |
self.ins_list = sorted(os.listdir( | |
self.file_path))[1:2] # more diversity | |
else: | |
self.ins_list = sorted(os.listdir(self.file_path))[ | |
0:1] # the first 1 instance for evaluation reference. | |
else: | |
# self.ins_list = sorted(Path(self.file_path).glob('[0-8]*')) | |
# self.ins_list = Path(self.file_path).glob('*') | |
# self.ins_list = list(Path(self.file_path).glob('*'))[:dataset_size] | |
# ins_list_file = Path( | |
# self.file_path).parent / f'{dataset_name}s_train_list.txt' | |
# assert ins_list_file.exists(), 'add training list for ShapeNet' | |
# with open(ins_list_file, 'r') as f: | |
# self.ins_list = [name.strip() for name in f.readlines()] | |
# if dataset_name == 'chair': | |
ins_list_file = Path( | |
self.file_path).parent / f'{dataset_name}_train_list.txt' | |
# st() | |
assert ins_list_file.exists(), 'add training list for ShapeNet' | |
with open(ins_list_file, 'r') as f: | |
self.ins_list = [name.strip() | |
for name in f.readlines()][:dataset_size] | |
# else: | |
# self.ins_list = Path(self.file_path).glob('*') | |
if overfitting: | |
self.ins_list = self.ins_list[:1] | |
self.rgb_list = [] | |
self.pose_list = [] | |
self.depth_list = [] | |
self.data_ins_list = [] | |
self.instance_data_length = -1 | |
for ins in self.ins_list: | |
cur_rgb_path = os.path.join(self.file_path, ins, 'rgb') | |
cur_pose_path = os.path.join(self.file_path, ins, 'pose') | |
cur_all_fname = sorted([ | |
t.split('.')[0] for t in os.listdir(cur_rgb_path) | |
if 'depth' not in t | |
][::interval]) | |
if self.instance_data_length == -1: | |
self.instance_data_length = len(cur_all_fname) | |
else: | |
assert len(cur_all_fname) == self.instance_data_length | |
# ! check filtered data | |
# for idx in range(len(cur_all_fname)): | |
# fname = cur_all_fname[idx] | |
# if not Path(os.path.join(cur_rgb_path, fname + '.png') ).exists(): | |
# cur_all_fname.remove(fname) | |
# del cur_all_fname[idx] | |
if test: | |
mid_index = len(cur_all_fname) // 3 * 2 | |
cur_all_fname.insert(0, cur_all_fname[mid_index]) | |
self.pose_list += ([ | |
os.path.join(cur_pose_path, fname + '.txt') | |
for fname in cur_all_fname | |
]) | |
self.rgb_list += ([ | |
os.path.join(cur_rgb_path, fname + '.png') | |
for fname in cur_all_fname | |
]) | |
self.depth_list += ([ | |
os.path.join(cur_rgb_path, fname + '_depth0001.exr') | |
for fname in cur_all_fname | |
]) | |
self.data_ins_list += ([ins] * len(cur_all_fname)) | |
# validate overfitting on images | |
if overfitting: | |
# bs=9 | |
# self.pose_list = self.pose_list[::50//9+1] | |
# self.rgb_list = self.rgb_list[::50//9+1] | |
# self.depth_list = self.depth_list[::50//9+1] | |
# bs=6 | |
# self.pose_list = self.pose_list[::50//6+1] | |
# self.rgb_list = self.rgb_list[::50//6+1] | |
# self.depth_list = self.depth_list[::50//6+1] | |
# bs=3 | |
assert overfitting_bs != -1 | |
# bs=1 | |
# self.pose_list = self.pose_list[25:26] | |
# self.rgb_list = self.rgb_list[25:26] | |
# self.depth_list = self.depth_list[25:26] | |
# uniform pose sampling | |
self.pose_list = self.pose_list[::50//overfitting_bs+1] | |
self.rgb_list = self.rgb_list[::50//overfitting_bs+1] | |
self.depth_list = self.depth_list[::50//overfitting_bs+1] | |
# sequentially sampling pose | |
# self.pose_list = self.pose_list[25:25+overfitting_bs] | |
# self.rgb_list = self.rgb_list[25:25+overfitting_bs] | |
# self.depth_list = self.depth_list[25:25+overfitting_bs] | |
# duplicate the same pose | |
# self.pose_list = [self.pose_list[25]] * overfitting_bs | |
# self.rgb_list = [self.rgb_list[25]] * overfitting_bs | |
# self.depth_list = [self.depth_list[25]] * overfitting_bs | |
# self.pose_list = [self.pose_list[28]] * overfitting_bs | |
# self.rgb_list = [self.rgb_list[28]] * overfitting_bs | |
# self.depth_list = [self.depth_list[28]] * overfitting_bs | |
self.single_pose_list = [ | |
os.path.join(cur_pose_path, fname + '.txt') | |
for fname in cur_all_fname | |
] | |
# st() | |
# if imgnet_normalize: | |
transformations = [ | |
transforms.ToTensor(), # [0,1] range | |
] | |
if imgnet_normalize: | |
transformations.append( | |
transforms.Normalize((0.485, 0.456, 0.406), | |
(0.229, 0.224, 0.225)) # type: ignore | |
) | |
else: | |
transformations.append( | |
transforms.Normalize((0.5, 0.5, 0.5), | |
(0.5, 0.5, 0.5))) # type: ignore | |
self.normalize = transforms.Compose(transformations) | |
# self.normalize_normalrange = transforms.Compose([ | |
# transforms.ToTensor(),# [0,1] range | |
# transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), | |
# ]) | |
fx = fy = 525 | |
cx = cy = 256 # rendering default K | |
factor = self.reso / (cx * 2) # 128 / 512 | |
self.fx = fx * factor | |
self.fy = fy * factor | |
self.cx = cx * factor | |
self.cy = cy * factor | |
# ! fix scale for triplane ray_sampler(), here we adopt [0,1] uv range, not [0, w] img space range. | |
self.cx /= self.reso # 0.5 | |
self.cy /= self.reso # 0.5 | |
self.fx /= self.reso | |
self.fy /= self.reso | |
intrinsics = np.array([[self.fx, 0, self.cx], [0, self.fy, self.cy], | |
[0, 0, 1]]).reshape(9) | |
# self.intrinsics = torch.from_numpy(intrinsics).float() | |
self.intrinsics = intrinsics | |
def __len__(self): | |
return len(self.rgb_list) | |
def get_c2w(self, pose_fname): | |
with open(pose_fname, 'r') as f: | |
cam2world = f.readline().strip() | |
cam2world = [float(t) for t in cam2world.split(' ')] | |
c2w = torch.tensor(cam2world, dtype=torch.float32).reshape(4, 4) | |
return c2w | |
def gen_rays(self, c2w): | |
# Generate rays | |
self.h = self.reso | |
self.w = self.reso | |
yy, xx = torch.meshgrid( | |
torch.arange(self.h, dtype=torch.float32) + 0.5, | |
torch.arange(self.w, dtype=torch.float32) + 0.5, | |
indexing='ij') | |
xx = (xx - self.cx) / self.fx | |
yy = (yy - self.cy) / self.fy | |
zz = torch.ones_like(xx) | |
dirs = torch.stack((xx, yy, zz), dim=-1) # OpenCV convention | |
dirs /= torch.norm(dirs, dim=-1, keepdim=True) | |
dirs = dirs.reshape(1, -1, 3, 1) | |
del xx, yy, zz | |
dirs = (c2w[:, None, :3, :3] @ dirs)[..., 0] | |
origins = c2w[:, None, :3, 3].expand(-1, self.h * self.w, | |
-1).contiguous() | |
origins = origins.view(-1, 3) | |
dirs = dirs.view(-1, 3) | |
return origins, dirs | |
def read_depth(self, idx): | |
depth_path = self.depth_list[idx] | |
# image_path = os.path.join(depth_fname, self.image_names[index]) | |
exr = OpenEXR.InputFile(depth_path) | |
header = exr.header() | |
size = (header['dataWindow'].max.x - header['dataWindow'].min.x + 1, | |
header['dataWindow'].max.y - header['dataWindow'].min.y + 1) | |
FLOAT = Imath.PixelType(Imath.PixelType.FLOAT) | |
depth_str = exr.channel('B', FLOAT) | |
depth = np.frombuffer(depth_str, | |
dtype=np.float32).reshape(size[1], | |
size[0]) # H W | |
depth = np.nan_to_num(depth, posinf=0, neginf=0) | |
depth = depth.reshape(size) | |
def resize_depth_mask(depth_to_resize, resolution): | |
depth_resized = cv2.resize(depth_to_resize, | |
(resolution, resolution), | |
interpolation=cv2.INTER_LANCZOS4) | |
# interpolation=cv2.INTER_AREA) | |
return depth_resized > 0 # type: ignore | |
fg_mask_reso = resize_depth_mask(depth, self.reso) | |
fg_mask_sr = resize_depth_mask(depth, 128) | |
# depth = cv2.resize(depth, (self.reso, self.reso), | |
# interpolation=cv2.INTER_LANCZOS4) | |
# interpolation=cv2.INTER_AREA) | |
# depth_mask = depth > 0 | |
# depth = np.expand_dims(depth, axis=0).reshape(size) | |
# return torch.from_numpy(depth) | |
return torch.from_numpy(depth), torch.from_numpy( | |
fg_mask_reso), torch.from_numpy(fg_mask_sr) | |
def load_bbox(self, mask): | |
nonzero_value = torch.nonzero(mask) | |
height, width = nonzero_value.max(dim=0)[0] | |
top, left = nonzero_value.min(dim=0)[0] | |
bbox = torch.tensor([top, left, height, width], dtype=torch.float32) | |
return bbox | |
def __getitem__(self, idx): | |
rgb_fname = self.rgb_list[idx] | |
pose_fname = self.pose_list[idx] | |
raw_img = imageio.imread(rgb_fname) | |
if self.preprocess is None: | |
img_to_encoder = cv2.resize(raw_img, | |
(self.reso_encoder, self.reso_encoder), | |
interpolation=cv2.INTER_LANCZOS4) | |
# interpolation=cv2.INTER_AREA) | |
img_to_encoder = img_to_encoder[ | |
..., :3] #[3, reso_encoder, reso_encoder] | |
img_to_encoder = self.normalize(img_to_encoder) | |
else: | |
img_to_encoder = self.preprocess(Image.open(rgb_fname)) # clip | |
img = cv2.resize(raw_img, (self.reso, self.reso), | |
interpolation=cv2.INTER_LANCZOS4) | |
# interpolation=cv2.INTER_AREA) | |
# img_sr = cv2.resize(raw_img, (512, 512), interpolation=cv2.INTER_AREA) | |
# img_sr = cv2.resize(raw_img, (256, 256), interpolation=cv2.INTER_AREA) # just as refinement, since eg3d uses 64->128 final resolution | |
# img_sr = cv2.resize(raw_img, (128, 128), interpolation=cv2.INTER_AREA) # just as refinement, since eg3d uses 64->128 final resolution | |
img_sr = cv2.resize( | |
raw_img, (128, 128), interpolation=cv2.INTER_LANCZOS4 | |
) # just as refinement, since eg3d uses 64->128 final resolution | |
# img = torch.from_numpy(img)[..., :3].permute( | |
# 2, 0, 1) / 255.0 #[3, reso, reso] | |
img = torch.from_numpy(img)[..., :3].permute( | |
2, 0, 1 | |
) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range | |
img_sr = torch.from_numpy(img_sr)[..., :3].permute( | |
2, 0, 1 | |
) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range | |
# c2w = self.get_c2w(pose_fname).reshape(1, 4, 4) #[1, 4, 4] | |
# rays_o, rays_d = self.gen_rays(c2w) | |
# return img_to_encoder, img, rays_o, rays_d, c2w.reshape(-1) | |
c2w = self.get_c2w(pose_fname).reshape(16) #[1, 4, 4] -> [1, 16] | |
# c = np.concatenate([c2w, self.intrinsics], axis=0).reshape(25) # 25, no '1' dim needed. | |
c = torch.cat([c2w, torch.from_numpy(self.intrinsics)], | |
dim=0).reshape(25) # 25, no '1' dim needed. | |
ret_dict = { | |
# 'rgb_fname': rgb_fname, | |
'img_to_encoder': img_to_encoder, | |
'img': img, | |
'c': c, | |
'img_sr': img_sr, | |
# 'ins_name': self.data_ins_list[idx] | |
} | |
if self.load_depth: | |
depth, depth_mask, depth_mask_sr = self.read_depth(idx) | |
bbox = self.load_bbox(depth_mask) | |
ret_dict.update({ | |
'depth': depth, | |
'depth_mask': depth_mask, | |
'depth_mask_sr': depth_mask_sr, | |
'bbox': bbox | |
}) | |
# rays_o, rays_d = self.gen_rays(c2w) | |
# return img_to_encoder, img, c | |
return ret_dict | |
class MultiViewDatasetforLMDB(MultiViewDataset): | |
def __init__(self, | |
file_path, | |
reso, | |
reso_encoder, | |
preprocess=None, | |
classes=False, | |
load_depth=False, | |
test=False, | |
scene_scale=1, | |
overfitting=False, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
overfitting_bs=-1): | |
super().__init__(file_path, reso, reso_encoder, preprocess, classes, | |
load_depth, test, scene_scale, overfitting, | |
imgnet_normalize, dataset_size, overfitting_bs) | |
def __len__(self): | |
return super().__len__() | |
# return 100 # for speed debug | |
def __getitem__(self, idx): | |
# ret_dict = super().__getitem__(idx) | |
rgb_fname = self.rgb_list[idx] | |
pose_fname = self.pose_list[idx] | |
raw_img = imageio.imread(rgb_fname)[..., :3] | |
c2w = self.get_c2w(pose_fname).reshape(16) #[1, 4, 4] -> [1, 16] | |
# c = np.concatenate([c2w, self.intrinsics], axis=0).reshape(25) # 25, no '1' dim needed. | |
c = torch.cat([c2w, torch.from_numpy(self.intrinsics)], | |
dim=0).reshape(25) # 25, no '1' dim needed. | |
depth, depth_mask, depth_mask_sr = self.read_depth(idx) | |
bbox = self.load_bbox(depth_mask) | |
ret_dict = { | |
'raw_img': raw_img, | |
'c': c, | |
'depth': depth, | |
# 'depth_mask': depth_mask, # 64x64 here? | |
'bbox': bbox | |
} | |
return ret_dict | |
def load_data_dryrun( | |
file_path="", | |
reso=64, | |
reso_encoder=224, | |
batch_size=1, | |
# shuffle=True, | |
num_workers=6, | |
load_depth=False, | |
preprocess=None, | |
imgnet_normalize=True): | |
# st() | |
dataset = MultiViewDataset(file_path, | |
reso, | |
reso_encoder, | |
test=False, | |
preprocess=preprocess, | |
load_depth=load_depth, | |
imgnet_normalize=imgnet_normalize) | |
print('dataset size: {}'.format(len(dataset))) | |
# st() | |
# train_sampler = DistributedSampler(dataset=dataset) | |
loader = DataLoader( | |
dataset, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
# shuffle=shuffle, | |
drop_last=False, | |
) | |
# sampler=train_sampler) | |
return loader | |
class NovelViewDataset(MultiViewDataset): | |
"""novel view prediction version. | |
""" | |
def __init__(self, | |
file_path, | |
reso, | |
reso_encoder, | |
preprocess=None, | |
classes=False, | |
load_depth=False, | |
test=False, | |
scene_scale=1, | |
overfitting=False, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
overfitting_bs=-1): | |
super().__init__(file_path, reso, reso_encoder, preprocess, classes, | |
load_depth, test, scene_scale, overfitting, | |
imgnet_normalize, dataset_size, overfitting_bs) | |
def __getitem__(self, idx): | |
input_view = super().__getitem__( | |
idx) # get previous input view results | |
# get novel view of the same instance | |
novel_view = super().__getitem__( | |
(idx // self.instance_data_length) * self.instance_data_length + | |
random.randint(0, self.instance_data_length - 1)) | |
# assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance' | |
input_view.update({f'nv_{k}': v for k, v in novel_view.items()}) | |
return input_view | |
def load_data_for_lmdb( | |
file_path="", | |
reso=64, | |
reso_encoder=224, | |
batch_size=1, | |
# shuffle=True, | |
num_workers=6, | |
load_depth=False, | |
preprocess=None, | |
imgnet_normalize=True, | |
dataset_size=-1, | |
trainer_name='input_rec'): | |
# st() | |
# dataset_cls = { | |
# 'input_rec': MultiViewDataset, | |
# 'nv': NovelViewDataset, | |
# }[trainer_name] | |
# if 'nv' in trainer_name: | |
# dataset_cls = NovelViewDataset | |
# else: | |
# dataset_cls = MultiViewDataset | |
dataset_cls = MultiViewDatasetforLMDB | |
dataset = dataset_cls(file_path, | |
reso, | |
reso_encoder, | |
test=False, | |
preprocess=preprocess, | |
load_depth=load_depth, | |
imgnet_normalize=imgnet_normalize, | |
dataset_size=dataset_size) | |
logger.log('dataset_cls: {}, dataset size: {}'.format( | |
trainer_name, len(dataset))) | |
# train_sampler = DistributedSampler(dataset=dataset, shuffle=True, drop_last=True) | |
loader = DataLoader( | |
dataset, | |
shuffle=False, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
drop_last=False, | |
prefetch_factor=2, | |
# prefetch_factor=3, | |
pin_memory=True, | |
persistent_workers=True, | |
) | |
# sampler=train_sampler) | |
# while True: | |
# yield from loader | |
return loader, dataset.dataset_name, len(dataset) | |
class LMDBDataset(Dataset): | |
def __init__(self, lmdb_path): | |
self.env = lmdb.open( | |
lmdb_path, | |
readonly=True, | |
max_readers=32, | |
lock=False, | |
readahead=False, | |
meminit=False, | |
) | |
self.num_samples = self.env.stat()['entries'] | |
# self.start_idx = self.env.stat()['start_idx'] | |
# self.end_idx = self.env.stat()['end_idx'] | |
def __len__(self): | |
return self.num_samples | |
def __getitem__(self, idx): | |
with self.env.begin(write=False) as txn: | |
key = str(idx).encode('utf-8') | |
value = txn.get(key) | |
sample = pickle.loads(value) | |
return sample | |
def resize_depth_mask(depth_to_resize, resolution): | |
depth_resized = cv2.resize(depth_to_resize, (resolution, resolution), | |
interpolation=cv2.INTER_LANCZOS4) | |
# interpolation=cv2.INTER_AREA) | |
return depth_resized, depth_resized > 0 # type: ignore | |
class LMDBDataset_MV(LMDBDataset): | |
def __init__(self, | |
lmdb_path, | |
reso, | |
reso_encoder, | |
imgnet_normalize=True, | |
**kwargs): | |
super().__init__(lmdb_path) | |
self.reso_encoder = reso_encoder | |
self.reso = reso | |
transformations = [ | |
transforms.ToTensor(), # [0,1] range | |
] | |
if imgnet_normalize: | |
transformations.append( | |
transforms.Normalize((0.485, 0.456, 0.406), | |
(0.229, 0.224, 0.225)) # type: ignore | |
) | |
else: | |
transformations.append( | |
transforms.Normalize((0.5, 0.5, 0.5), | |
(0.5, 0.5, 0.5))) # type: ignore | |
self.normalize = transforms.Compose(transformations) | |
def _post_process_sample(self, raw_img, depth): | |
# if raw_img.shape[-1] == 4: # ! set bg to white | |
# alpha_mask = raw_img[..., -1:] > 0 | |
# raw_img = alpha_mask * raw_img[..., :3] + (1-alpha_mask) * np.ones_like(raw_img[..., :3]) * 255 | |
# raw_img = raw_img.astype(np.uint8) | |
# img_to_encoder = cv2.resize(sample.pop('raw_img'), | |
img_to_encoder = cv2.resize(raw_img, | |
(self.reso_encoder, self.reso_encoder), | |
interpolation=cv2.INTER_LANCZOS4) | |
# interpolation=cv2.INTER_AREA) | |
img_to_encoder = img_to_encoder[..., : | |
3] #[3, reso_encoder, reso_encoder] | |
img_to_encoder = self.normalize(img_to_encoder) | |
img = cv2.resize(raw_img, (self.reso, self.reso), | |
interpolation=cv2.INTER_LANCZOS4) | |
if img.shape[-1] == 4: | |
alpha_mask = img[..., -1:] > 0 | |
img = alpha_mask * img[..., :3] + (1-alpha_mask) * np.ones_like(img[..., :3]) * 255 | |
img = torch.from_numpy(img)[..., :3].permute( | |
2, 0, 1 | |
) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range | |
img_sr = torch.from_numpy(raw_img)[..., :3].permute( | |
2, 0, 1 | |
) / 127.5 - 1 #[3, reso, reso], normalize to [-1,1], follow triplane range | |
# depth | |
# fg_mask_reso = resize_depth_mask(sample['depth'], self.reso) | |
depth_reso, fg_mask_reso = resize_depth_mask(depth, self.reso) | |
return { | |
# **sample, | |
'img_to_encoder': img_to_encoder, | |
'img': img, | |
'depth_mask': fg_mask_reso, | |
'img_sr': img_sr, | |
'depth': depth_reso, | |
# ! no need to load img_sr for now | |
} | |
def __getitem__(self, idx): | |
sample = super().__getitem__(idx) | |
# do transformations online | |
return self._post_process_sample(sample['raw_img'], sample['depth']) | |
# return sample | |
def load_bytes(inp_bytes, dtype, shape): | |
return np.frombuffer(inp_bytes, dtype=dtype).reshape(shape).copy() | |
# Function to decompress an image using gzip and open with imageio | |
def decompress_and_open_image_gzip(compressed_data, is_img=False): | |
# Decompress the image data using gzip | |
decompressed_data = gzip.decompress(compressed_data) | |
# Read the decompressed image using imageio | |
if is_img: | |
image = imageio.v3.imread(io.BytesIO(decompressed_data)).copy() | |
return image | |
return decompressed_data | |
# Function to decompress an array using gzip | |
def decompress_array(compressed_data, shape, dtype): | |
# Decompress the array data using gzip | |
decompressed_data = gzip.decompress(compressed_data) | |
# Convert the decompressed data to a NumPy array | |
# arr = np.frombuffer(decompressed_data, dtype=dtype).reshape(shape) | |
return load_bytes(decompressed_data, dtype, shape) | |
class LMDBDataset_MV_Compressed(LMDBDataset_MV): | |
def __init__(self, | |
lmdb_path, | |
reso, | |
reso_encoder, | |
imgnet_normalize=True, | |
**kwargs): | |
super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, | |
**kwargs) | |
with self.env.begin(write=False) as txn: | |
self.length = int( | |
txn.get('length'.encode('utf-8')).decode('utf-8')) - 40 | |
self.load_image_fn = partial(decompress_and_open_image_gzip, | |
is_img=True) | |
def __len__(self): | |
return self.length | |
def _load_lmdb_data(self, idx): | |
with self.env.begin(write=False) as txn: | |
raw_img_key = f'{idx}-raw_img'.encode('utf-8') | |
raw_img = self.load_image_fn(txn.get(raw_img_key)) | |
depth_key = f'{idx}-depth'.encode('utf-8') | |
depth = decompress_array(txn.get(depth_key), (512,512), np.float32) | |
c_key = f'{idx}-c'.encode('utf-8') | |
c = decompress_array(txn.get(c_key), (25, ), np.float32) | |
bbox_key = f'{idx}-bbox'.encode('utf-8') | |
bbox = decompress_array(txn.get(bbox_key), (4, ), np.float32) | |
return raw_img, depth, c, bbox | |
def __getitem__(self, idx): | |
# sample = super(LMDBDataset).__getitem__(idx) | |
# do gzip uncompress online | |
raw_img, depth, c, bbox = self._load_lmdb_data(idx) | |
return { | |
**self._post_process_sample(raw_img, depth), 'c': c, | |
'bbox': bbox*(self.reso/64.0), | |
# 'depth': depth, | |
} | |
class LMDBDataset_NV_Compressed(LMDBDataset_MV_Compressed): | |
def __init__(self, lmdb_path, reso, reso_encoder, imgnet_normalize=True, **kwargs): | |
super().__init__(lmdb_path, reso, reso_encoder, imgnet_normalize, **kwargs) | |
self.instance_data_length = 50 # | |
def __getitem__(self, idx): | |
input_view = super().__getitem__( | |
idx) # get previous input view results | |
# get novel view of the same instance | |
try: | |
novel_view = super().__getitem__( | |
(idx // self.instance_data_length) * self.instance_data_length + | |
random.randint(0, self.instance_data_length - 1)) | |
except Exception as e: | |
raise NotImplementedError(idx) | |
assert input_view['ins_name'] == novel_view['ins_name'], 'should sample novel view from the same instance' | |
input_view.update({f'nv_{k}': v for k, v in novel_view.items()}) | |
return input_view |