|
import io |
|
import cv2 |
|
import numpy as np |
|
import h5py |
|
import torch |
|
from numpy.linalg import inv |
|
import re |
|
|
|
|
|
try: |
|
|
|
from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT |
|
except Exception: |
|
MEGADEPTH_CLIENT = SCANNET_CLIENT = None |
|
|
|
|
|
|
|
|
|
def load_array_from_s3(
    path,
    client,
    cv_type,
    use_h5py=False,
):
    """Fetch an object via ``client.Get`` and decode it into a numpy array.

    Args:
        path: object key / URI passed verbatim to ``client.Get``.
        client: storage client exposing a ``Get(path)`` method that returns
            the raw bytes of the object.
        cv_type: OpenCV imread flag (e.g. ``cv2.IMREAD_GRAYSCALE``) used when
            decoding an image; ignored when ``use_h5py`` is True.
        use_h5py (bool): if True, parse the bytes as an HDF5 file and return
            its ``/depth`` dataset instead of decoding an image.

    Returns:
        np.ndarray: the decoded image or the ``/depth`` dataset.

    Raises:
        Re-raises any error from decoding (after printing the failing path).
        AssertionError: if OpenCV could not decode the bytes (imdecode -> None).
    """
    byte_str = client.Get(path)
    try:
        if not use_h5py:
            # np.frombuffer is the supported, zero-copy replacement for the
            # deprecated binary mode of np.fromstring.
            raw_array = np.frombuffer(byte_str, np.uint8)
            data = cv2.imdecode(raw_array, cv_type)
        else:
            # Context manager guarantees the HDF5 handle is closed.
            with h5py.File(io.BytesIO(byte_str), "r") as f:
                data = np.array(f["/depth"])
    except Exception:
        print(f"==> Data loading failure: {path}")
        raise

    # cv2.imdecode signals failure by returning None rather than raising.
    assert data is not None, f"Failed to decode data from {path}"
    return data
|
|
|
|
|
def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
    """Read an image (local path or ``s3://`` URI) and return it in grayscale.

    Args:
        path: local file path or ``s3://`` URI (anything ``str()``-able).
        augment_fn (callable, optional): applied to the RGB image before the
            grayscale conversion. When given, the image is loaded in colour.
        client: storage client used for ``s3://`` paths.

    Returns:
        np.ndarray: the grayscale image as returned by OpenCV.
    """
    # Augmentations operate on colour images, so load in colour when needed.
    cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None else cv2.IMREAD_COLOR
    path = str(path)
    if path.startswith("s3://"):
        image = load_array_from_s3(path, client, cv_type)
    else:
        image = cv2.imread(path, cv_type)

    if augment_fn is not None:
        # The image was already loaded with IMREAD_COLOR (BGR order) above.
        # Re-reading it from local disk here would discard an s3-loaded image
        # (cv2.imread returns None for s3:// URIs), so reuse it directly.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = augment_fn(image)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return image
|
|
|
|
|
def get_resized_wh(w, h, resize=None):
    """Return (w, h) rescaled so the longer side equals ``resize``.

    Args:
        w, h: original width and height.
        resize (optional): target length of the longer side; ``None`` means
            no resizing.

    Returns:
        tuple: (new_width, new_height), rounded to ints when resizing.
    """
    if resize is None:
        return w, h
    factor = resize / max(h, w)
    return int(round(w * factor)), int(round(h * factor))
|
|
|
|
|
def get_divisible_wh(w, h, df=None):
    """Round (w, h) down to the nearest multiples of ``df``.

    Args:
        w, h: width and height.
        df (optional): divisor; ``None`` leaves the sizes untouched.

    Returns:
        tuple: (width, height), as ints when ``df`` is given.
    """
    if df is None:
        return w, h
    return tuple(int(side // df * df) for side in (w, h))
|
|
|
|
|
def pad_bottom_right(inp, pad_size, ret_mask=False):
    """Zero-pad the last two axes of ``inp`` on the bottom/right to a square.

    Args:
        inp (np.ndarray): 2-D (H, W) or 3-D (C, H, W) array.
        pad_size (int): side length of the padded square; must be >= H and W.
        ret_mask (bool): also return a boolean mask that is True over the
            original (un-padded) region.

    Returns:
        tuple: (padded, mask). ``padded`` has shape ``inp.shape[:-2] +
        (pad_size, pad_size)`` and ``inp``'s dtype; ``mask`` has the same
        shape with dtype bool, or is None when ``ret_mask`` is False.

    Raises:
        AssertionError: if ``pad_size`` is not an int or is too small.
        NotImplementedError: if ``inp`` is neither 2-D nor 3-D.
    """
    assert isinstance(pad_size, int) and pad_size >= max(
        inp.shape[-2:]
    ), f"{pad_size} < {max(inp.shape[-2:])}"
    if inp.ndim not in (2, 3):
        raise NotImplementedError()

    rows, cols = inp.shape[-2:]
    target_shape = inp.shape[:-2] + (pad_size, pad_size)

    padded = np.zeros(target_shape, dtype=inp.dtype)
    padded[..., :rows, :cols] = inp
    if not ret_mask:
        return padded, None

    valid = np.zeros(target_shape, dtype=bool)
    valid[..., :rows, :cols] = True
    return padded, valid
|
|
|
|
|
|
|
|
|
|
|
def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
    """
    Load a MegaDepth image as a grayscale float tensor.

    Args:
        path: local path or ``s3://`` URI of the image.
        resize (int, optional): the longer edge of resized images. None for no resize.
        df (int, optional): round the resized width/height down to multiples of df.
        padding (bool): If set to 'True', zero-pad resized images to squared size.
        augment_fn (callable, optional): augments images with pre-defined visual effects
    Returns:
        image (torch.tensor): (1, h, w), pixel values divided by 255
        mask (torch.tensor or None): bool, (max(h, w), max(h, w)) when padding,
            True over the valid image region; None otherwise
        scale (torch.tensor): [w/w_new, h/h_new]
    """
    img = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)

    orig_h, orig_w = img.shape[:2]
    w_new, h_new = get_divisible_wh(*get_resized_wh(orig_w, orig_h, resize), df)

    img = cv2.resize(img, (w_new, h_new))
    # Factors that map coordinates in the resized image back to the original.
    scale = torch.tensor([orig_w / w_new, orig_h / h_new], dtype=torch.float)

    mask = None
    if padding:
        img, valid = pad_bottom_right(img, max(h_new, w_new), ret_mask=True)
        mask = torch.from_numpy(valid)

    img = torch.from_numpy(img).float()[None] / 255
    return img, mask, scale
|
|
|
|
|
def read_megadepth_depth(path, pad_to=None):
    """Load a MegaDepth depth map stored in an HDF5 file.

    Args:
        path: local path or ``s3://`` URI of the HDF5 file; the ``depth``
            dataset inside it is read.
        pad_to (int, optional): if given, zero-pad the map on the bottom/right
            to a ``pad_to`` x ``pad_to`` square.

    Returns:
        torch.Tensor: the depth map as float32.
    """
    if str(path).startswith("s3://"):
        # Pass a str so the client receives the same form the check used.
        depth = load_array_from_s3(str(path), MEGADEPTH_CLIENT, None, use_h5py=True)
    else:
        # Context manager closes the HDF5 handle instead of leaking it.
        with h5py.File(path, "r") as f:
            depth = np.array(f["depth"])

    if pad_to is not None:
        depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
    return torch.from_numpy(depth).float()
|
|
|
|
|
|
|
|
|
|
|
def read_scannet_gray(path, resize=(640, 480), augment_fn=None):
    """
    Load a ScanNet image as a grayscale float tensor.

    Args:
        path: local path or ``s3://`` URI of the image (s3 reads use
            imread_gray's default client).
        resize (tuple): align image to depthmap, in (w, h); passed to cv2.resize.
        augment_fn (callable, optional): augments images with pre-defined visual effects
    Returns:
        image (torch.tensor): (1, h, w), pixel values divided by 255.
            No mask or scale is returned.
    """
    image = imread_gray(path, augment_fn)
    image = cv2.resize(image, resize)
    return torch.from_numpy(image).float()[None] / 255
|
|
|
|
|
def read_scannet_depth(path):
    """Load a ScanNet depth image and return it as a float32 tensor.

    Raw pixel values are read unchanged and divided by 1000 (presumably
    millimetres -> metres; confirm against the dataset spec).

    Args:
        path: local path or ``s3://`` URI of the depth image.

    Returns:
        torch.Tensor: depth values divided by 1000, as float32.
    """
    location = str(path)
    if location.startswith("s3://"):
        raw = load_array_from_s3(location, SCANNET_CLIENT, cv2.IMREAD_UNCHANGED)
    else:
        raw = cv2.imread(location, cv2.IMREAD_UNCHANGED)
    return torch.from_numpy(raw / 1000).float()
|
|
|
|
|
def read_scannet_pose(path):
    """Load a space-delimited Camera2World pose and invert it to World2Camera.

    Args:
        path: anything ``np.loadtxt`` accepts (file path, file object, or
            iterable of lines).

    Returns:
        pose_w2c (np.ndarray): (4, 4)
    """
    return inv(np.loadtxt(path, delimiter=" "))
|
|
|
|
|
def read_scannet_intrinsic(path):
    """Load a space-delimited intrinsic matrix and drop its last row and column.

    For the (presumably 4x4) ScanNet intrinsic file this yields the 3x3
    camera matrix.
    """
    full = np.loadtxt(path, delimiter=" ")
    return full[:-1, :-1]
|
|
|
|
|
def read_gl3d_gray(path, resize):
    """Load a GL3D image as a square grayscale float tensor.

    Args:
        path: local image path (``str``-able, e.g. ``pathlib.Path``).
        resize: side length of the square output; converted with ``int()``.

    Returns:
        torch.Tensor: (1, resize, resize), pixel values divided by 255.

    Raises:
        OSError: if OpenCV cannot read the image.
    """
    # cv2.imread expects a str path, consistent with the other readers here.
    img = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None; fail with a clear message
        # instead of an opaque error from cv2.resize.
        raise OSError(f"Failed to read image: {path}")
    side = int(resize)
    img = cv2.resize(img, (side, side))
    return torch.from_numpy(img).float()[None] / 255
|
|
|
|
|
def read_gl3d_depth(file_path):
    """Read a PFM (Portable Float Map) depth file into a float32 tensor.

    Header layout parsed below: a magic line ("PF" = 3-channel, "Pf" = single
    channel), a "width height" line, then a scale line whose sign selects the
    byte order (negative -> little-endian, otherwise big-endian). Pixel rows
    are stored bottom-to-top, so they are flipped vertically.

    Args:
        file_path: path to the .pfm file.

    Returns:
        torch.Tensor: float32, shape (height, width) or (height, width, 3).

    Raises:
        ValueError: if the header is not a valid PFM header (ValueError is a
            subclass of Exception, so existing ``except Exception`` handlers
            still catch it).
    """
    with open(file_path, "rb") as fin:
        header = fin.readline().decode("UTF-8").rstrip()
        if header == "PF":
            color = True
        elif header == "Pf":
            color = False
        else:
            raise ValueError("Not a PFM file.")

        dim_match = re.match(r"^(\d+)\s(\d+)\s$", fin.readline().decode("UTF-8"))
        if not dim_match:
            raise ValueError("Malformed PFM header.")
        width, height = map(int, dim_match.groups())

        scale = float(fin.readline().decode("UTF-8").rstrip())
        data_type = "<f" if scale < 0 else ">f"
        data_string = fin.read()

    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    data = np.frombuffer(data_string, data_type)
    shape = (height, width, 3) if color else (height, width)
    data = np.flip(np.reshape(data, shape), 0)
    # Copy into a native-endian, C-contiguous, writable float32 array:
    # torch.from_numpy rejects non-native byte order (big-endian PFMs) and
    # negative strides (from np.flip).
    return torch.from_numpy(np.array(data, dtype=np.float32, order="C")).float()
|
|