lemonaddie
committed on
Delete utils
Browse files- utils/batch_size.py +0 -63
- utils/colormap.py +0 -45
- utils/common.py +0 -42
- utils/dataset_configuration.py +0 -81
- utils/de_normalized.py +0 -33
- utils/depth2normal.py +0 -186
- utils/depth_ensemble.py +0 -115
- utils/image_util.py +0 -83
- utils/normal_ensemble.py +0 -22
- utils/seed_all.py +0 -33
- utils/surface_normal.py +0 -213
utils/batch_size.py
DELETED
@@ -1,63 +0,0 @@
|
|
1 |
-
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import math
|
5 |
-
|
6 |
-
|
7 |
-
# Search table for suggested max. inference batch size
|
8 |
-
# Lookup table of suggested maximum inference batch sizes.
# Each entry maps (resolution, total VRAM in GiB, dtype) to a batch size.
bs_search_table = [
    {"res": res, "total_vram": total_vram, "bs": bs, "dtype": dtype}
    for res, total_vram, bs, dtype in (
        # tested on A100-PCIE-80GB
        (768, 79, 35, torch.float32),
        (1024, 79, 20, torch.float32),
        # tested on A100-PCIE-40GB
        (768, 39, 15, torch.float32),
        (1024, 39, 8, torch.float32),
        (768, 39, 30, torch.float16),
        (1024, 39, 15, torch.float16),
        # tested on RTX3090, RTX4090
        (512, 23, 20, torch.float32),
        (768, 23, 7, torch.float32),
        (1024, 23, 3, torch.float32),
        (512, 23, 40, torch.float16),
        (768, 23, 18, torch.float16),
        (1024, 23, 10, torch.float16),
        # tested on GTX1080Ti
        (512, 10, 5, torch.float32),
        (768, 10, 2, torch.float32),
        (512, 10, 10, torch.float16),
        (768, 10, 5, torch.float16),
        (1024, 10, 3, torch.float16),
    )
]
|
31 |
-
|
32 |
-
|
33 |
-
def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
    """
    Pick an operating batch size from ``bs_search_table``.

    Args:
        ensemble_size (`int`):
            Number of predictions to be ensembled.
        input_res (`int`):
            Operating resolution of the input image.
        dtype (`torch.dtype`):
            Inference dtype; only table rows with this exact dtype are considered.

    Returns:
        `int`: Operating batch size. 1 when CUDA is unavailable or no table
        row matches.
    """
    if not torch.cuda.is_available():
        return 1

    # mem_get_info() returns (free, total) in bytes; convert total to GiB.
    _, total_bytes = torch.cuda.mem_get_info()
    total_vram = total_bytes / 1024.0**3

    # Smallest sufficient resolution first; for equal resolution, largest VRAM tier first.
    candidates = sorted(
        (row for row in bs_search_table if row["dtype"] == dtype),
        key=lambda row: (row["res"], -row["total_vram"]),
    )
    half_ensemble = math.ceil(ensemble_size / 2)

    for row in candidates:
        if input_res > row["res"] or total_vram < row["total_vram"]:
            continue
        bs = row["bs"]
        if bs > ensemble_size:
            # No point batching more than the ensemble contains.
            return ensemble_size
        if half_ensemble < bs < ensemble_size:
            # Split the ensemble into two balanced batches instead of a big and a tiny one.
            return half_ensemble
        return bs

    return 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/colormap.py
DELETED
@@ -1,45 +0,0 @@
|
|
1 |
-
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
-
|
3 |
-
import numpy as np
|
4 |
-
import cv2
|
5 |
-
|
6 |
-
def kitti_colormap(disparity, maxval=-1):
    """
    A utility function to reproduce KITTI fake colormap
    Arguments:
      - disparity: numpy float32 array of dimension HxW
      - maxval: maximum disparity value for normalization (if equal to -1, the maximum value in disparity will be used)

    Returns a numpy uint8 array of shape HxWx3.
    """
    if maxval < 0:
        maxval = np.max(disparity)

    # Colour anchors (first three columns are the colour components), the
    # per-segment interpolation slopes, and the cumulative segment boundaries.
    anchors = np.asarray([[0,0,0,114],[0,0,1,185],[1,0,0,114],[1,0,1,174],[0,1,0,114],[0,1,1,185],[1,1,0,114],[1,1,1,0]])
    slopes = np.asarray([8.771929824561404,5.405405405405405,8.771929824561404,5.747126436781609,8.771929824561404,5.405405405405405,8.771929824561404,0])
    edges = np.asarray([0,0.114,0.299,0.413,0.587,0.701,0.8859999999999999,0.9999999999999999])

    # Normalised disparity clamped to [0, 1].
    normalized = np.minimum(np.maximum(disparity / maxval, 0.), 1.)

    # Distance of every pixel to every boundary; boundaries strictly below the
    # value are masked out so argmax lands on the first boundary >= value.
    offsets = normalized[:, :, np.newaxis] - edges
    offsets = np.where(offsets > 0, -1000, offsets)
    index = np.argmax(offsets, axis=-1) - 1

    # Blend weight of the lower anchor inside the selected segment.
    w = 1 - (normalized - edges[index]) * slopes[index]

    lower = anchors[index]
    upper = anchors[index + 1]
    colored_disp = np.zeros([disparity.shape[0], disparity.shape[1], 3])
    # Anchor columns 0,1,2 are written to output channels 2,1,0 respectively.
    for out_channel, anchor_column in ((2, 0), (1, 1), (0, 2)):
        colored_disp[:, :, out_channel] = (
            w * lower[:, :, anchor_column] + (1. - w) * upper[:, :, anchor_column]
        )

    # Pixels with non-positive disparity are rendered black.
    positive = np.expand_dims((disparity > 0), -1)
    return (colored_disp * positive * 255).astype(np.uint8)
|
36 |
-
|
37 |
-
def read_16bit_gt(path):
    """
    A utility function to read KITTI 16bit gt
    Arguments:
      - path: filepath
    Returns a numpy float32 array of shape HxW.

    Raises FileNotFoundError when the image cannot be read.
    """
    # Flag -1 loads the image unchanged, preserving its bit depth.
    raw = cv2.imread(path, -1)
    if raw is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read 16bit ground truth image: {path}")
    # Stored values are divided by 256 to obtain the float ground truth.
    return raw.astype(np.float32) / 256.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/common.py
DELETED
@@ -1,42 +0,0 @@
|
|
1 |
-
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
-
|
3 |
-
import json
|
4 |
-
import yaml
|
5 |
-
import logging
|
6 |
-
import os
|
7 |
-
import numpy as np
|
8 |
-
import sys
|
9 |
-
|
10 |
-
def load_loss_scheme(loss_config):
    """Parse the YAML file at ``loss_config`` with ``yaml.safe_load`` and return the result."""
    with open(loss_config, 'r') as config_file:
        return yaml.safe_load(config_file)
|
14 |
-
|
15 |
-
|
16 |
-
# Module-level logging setup: configures the ROOT logger on import.
DEBUG =0
logger = logging.getLogger()

# DEBUG toggles between verbose and normal root log level.
logger.setLevel(logging.DEBUG if DEBUG else logging.INFO)

# Stream handler with a timestamp / source-location prefix.
formatter = logging.Formatter('%(asctime)s [%(filename)s:%(lineno)d] %(levelname)s %(message)s')
strhdlr = logging.StreamHandler()
strhdlr.setFormatter(formatter)
logger.addHandler(strhdlr)
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
def count_parameters(model):
    """Return the total number of elements over all parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
37 |
-
|
38 |
-
def check_path(path):
    """Create directory ``path`` (including parents) if nothing exists at that location yet."""
    if os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/dataset_configuration.py
DELETED
@@ -1,81 +0,0 @@
|
|
1 |
-
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import torch.nn as nn
|
5 |
-
import torch.nn.functional as F
|
6 |
-
import numpy as np
|
7 |
-
import sys
|
8 |
-
sys.path.append("..")
|
9 |
-
|
10 |
-
from dataloader.mix_loader import MixDataset
|
11 |
-
from torch.utils.data import DataLoader
|
12 |
-
from dataloader import transforms
|
13 |
-
import os
|
14 |
-
|
15 |
-
|
16 |
-
# Get Dataset Here
|
17 |
-
def prepare_dataset(data_dir=None,
                    batch_size=1,
                    test_batch=1,
                    datathread=4,
                    logger=None):
    """
    Build the training DataLoader over a ``MixDataset`` and a small config dict.

    Args:
        data_dir: forwarded to ``MixDataset(data_dir=...)``.
        batch_size: training batch size.
        test_batch: accepted but not used by this function.
        datathread: number of loader workers; overridden by the ``datathread``
            environment variable when that is set.
        logger: optional logger used to report the worker count.

    Returns:
        ``(train_loader, dataset_config_dict)`` where the dict holds
        ``'num_batches_per_epoch'`` and ``'img_size'`` as ``(height, width)``.
    """
    train_dataset = MixDataset(data_dir=data_dir)
    img_height, img_width = train_dataset.get_img_size()

    # The environment variable takes precedence over the argument.
    env_threads = os.environ.get('datathread')
    if env_threads is not None:
        datathread = int(env_threads)

    if logger is not None:
        logger.info("Use %d processes to load data..." % datathread)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=datathread,
        pin_memory=True,
    )

    dataset_config_dict = {
        'num_batches_per_epoch': len(train_loader),
        'img_size': (img_height, img_width),
    }
    return train_loader, dataset_config_dict
|
47 |
-
|
48 |
-
def depth_scale_shift_normalization(depth):
    """
    Normalise each sample of ``depth`` to [-1, 1] using its 2nd and 98th percentiles.

    The percentiles are computed per batch element over channel 0 only, then
    applied to the whole tensor; values outside the range are clipped.
    Indexing implies a 4-D input (batch, channel, H, W).
    """
    bsz = depth.shape[0]
    flat = depth[:, 0, :, :].reshape(bsz, -1).cpu().numpy()

    def _percentile(q):
        # Per-sample percentile, moved back to depth's device/dtype and
        # reshaped so it broadcasts against the 4-D input.
        values = np.percentile(a=flat, q=q, axis=1)
        return torch.from_numpy(values).to(depth)[..., None, None, None]

    low = _percentile(2)
    high = _percentile(98)

    # Map [low, high] -> [0, 1] (epsilon guards a zero range), then to [-1, 1].
    unit_range = (depth - low) / (high - low + 1e-5)
    return torch.clip((unit_range - 0.5) * 2, -1., 1.)
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
def resize_max_res_tensor(input_tensor, mode, recom_resolution=768):
    """
    Rescale a 3-channel (B, 3, H, W) tensor so its longer side equals ``recom_resolution``.

    ``mode == 'normal'`` uses nearest-neighbour interpolation; every other mode
    uses bilinear. For ``mode == 'depth'`` the resized values are additionally
    divided by the scale factor.
    """
    assert input_tensor.shape[1]==3
    original_H, original_W = input_tensor.shape[2:]
    downscale_factor = min(recom_resolution/original_H, recom_resolution/original_W)

    if mode == 'normal':
        interp_kwargs = {'mode': 'nearest'}
    else:
        interp_kwargs = {'mode': 'bilinear', 'align_corners': False}

    resized = F.interpolate(input_tensor, scale_factor=downscale_factor, **interp_kwargs)

    # NOTE(review): depth values are rescaled by 1/factor here; presumably
    # intentional for the caller's depth convention -- confirm.
    return resized / downscale_factor if mode == 'depth' else resized
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/de_normalized.py
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
-
|
3 |
-
import numpy as np
|
4 |
-
from scipy.optimize import least_squares
|
5 |
-
import torch
|
6 |
-
|
7 |
-
def align_scale_shift(pred, target, clip_max):
    """
    Least-squares fit of ``target ~= scale * pred + shift`` over valid pixels.

    Valid pixels satisfy ``0 < target < clip_max``. With 10 or fewer valid
    pixels the identity transform ``(1, 0)`` is returned.
    """
    valid = (target > 0) & (target < clip_max)
    if not valid.sum() > 10:
        return 1, 0

    # A degree-1 polyfit returns (slope, intercept).
    scale, shift = np.polyfit(pred[valid], target[valid], deg=1)
    return scale, shift
|
16 |
-
|
17 |
-
def align_scale(pred: torch.tensor, target: torch.tensor):
    """
    Scale ``pred`` so its median matches the median of ``target`` on valid pixels.

    Valid pixels are those with ``target > 0``. With 10 or fewer valid pixels
    the scale is 1. Returns ``(pred * scale, scale)``.
    """
    valid = target > 0
    scale = 1
    if torch.sum(valid) > 10:
        # Epsilon guards against a zero median in the denominator.
        scale = torch.median(target[valid]) / (torch.median(pred[valid]) + 1e-8)
    return pred * scale, scale
|
25 |
-
|
26 |
-
def align_shift(pred: torch.Tensor, target: torch.Tensor):
    """
    Shift ``pred`` so its median matches the median of ``target`` on valid pixels.

    Valid pixels are those with ``target > 0``. With 10 or fewer valid pixels
    the shift is 0. Returns ``(pred + shift, shift)``.
    """
    valid = target > 0
    if torch.sum(valid) > 10:
        # Plain difference of medians. The former "+ 1e-8" was a division
        # guard copied from align_scale and only biased the shift.
        shift = torch.median(target[valid]) - torch.median(pred[valid])
    else:
        shift = 0
    return pred + shift, shift
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/depth2normal.py
DELETED
@@ -1,186 +0,0 @@
|
|
1 |
-
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
-
|
3 |
-
import pickle
|
4 |
-
import os
|
5 |
-
import h5py
|
6 |
-
import numpy as np
|
7 |
-
import cv2
|
8 |
-
import torch
|
9 |
-
import torch.nn as nn
|
10 |
-
import glob
|
11 |
-
|
12 |
-
|
13 |
-
def init_image_coor(height, width, device=None):
    """
    Build pixel-coordinate grids centred on the image middle.

    Args:
        height: image height in pixels.
        width: image width in pixels.
        device: device for the returned tensors. Defaults to CUDA when it is
            available (the historical behaviour) and falls back to CPU
            otherwise instead of crashing.

    Returns:
        ``(u_u0, v_v0)``: two float32 tensors of shape (1, height, width)
        holding ``column - width / 2`` and ``row - height / 2``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cols = torch.arange(0, width, dtype=torch.float32, device=device)
    rows = torch.arange(0, height, dtype=torch.float32, device=device)

    # repeat() (not expand()) so callers get ordinary contiguous tensors.
    u_u0 = (cols - width / 2.0).reshape(1, 1, width).repeat(1, height, 1)
    v_v0 = (rows - height / 2.0).reshape(1, height, 1).repeat(1, 1, width)
    return u_u0, v_v0
|
28 |
-
|
29 |
-
|
30 |
-
def depth_to_xyz(depth, focal_length):
    """
    Back-project a depth map to 3-D points using centred pixel coordinates.

    Args:
        depth: 4-D tensor unpacked as (b, c, h, w).
        focal_length: indexable; element 0 divides the x term and element 1
            the y term.  # presumably (fx, fy) -- confirm against the caller

    Returns:
        Tensor of shape [b, h, w, c] with x, y, z concatenated on the last axis.
    """
    _, _, height, width = depth.shape
    u_u0, v_v0 = init_image_coor(height, width)

    points = torch.cat(
        [
            u_u0 * depth / focal_length[0],
            v_v0 * depth / focal_length[1],
            depth,
        ],
        1,
    )
    return points.permute(0, 2, 3, 1)  # [b, h, w, c]
|
38 |
-
|
39 |
-
|
40 |
-
def get_surface_normal(xyz, patch_size=5):
    """
    Estimate per-pixel surface normals by a local least-squares plane fit.

    For every pixel the points in a ``patch_size`` x ``patch_size``
    neighbourhood are accumulated into a 3x3 system ``(A^T A) n = A^T 1``,
    which is solved tile by tile. The result is normalised to unit length and
    flipped wherever its dot product with the point is positive.

    Args:
        xyz: point map of shape [1, h, w, 3]. All intermediate tensors are
            created on ``xyz.device`` with ``xyz.dtype`` (previously they were
            forced onto CUDA).
        patch_size: side length of the neighbourhood.

    Returns:
        Tensor of shape [h, w, 3, 1].
    """
    device, dtype = xyz.device, xyz.dtype
    pad = int(patch_size / 2)

    # Split into three [1, 1, h, w] coordinate maps for conv2d.
    x, y, z = (torch.unsqueeze(comp, 0) for comp in torch.unbind(xyz, dim=3))

    # An all-ones kernel turns conv2d into a zero-padded sliding-window sum.
    patch_weight = torch.ones((1, 1, patch_size, patch_size), device=device, dtype=dtype)

    def _patch_sum(values):
        return nn.functional.conv2d(values, weight=patch_weight, padding=pad)

    xx_patch = _patch_sum(x * x)
    yy_patch = _patch_sum(y * y)
    zz_patch = _patch_sum(z * z)
    xy_patch = _patch_sum(x * y)
    xz_patch = _patch_sum(x * z)
    yz_patch = _patch_sum(y * z)

    # Symmetric 3x3 normal-equation matrix per pixel: [h, w, 3, 3].
    ATA = torch.stack([xx_patch, xy_patch, xz_patch, xy_patch, yy_patch, yz_patch, xz_patch, yz_patch, zz_patch],
                      dim=4)
    ATA = torch.squeeze(ATA)
    ATA = torch.reshape(ATA, (ATA.size(0), ATA.size(1), 3, 3))
    # Small diagonal term keeps the systems solvable.
    eps_identity = 1e-6 * torch.eye(3, device=ATA.device, dtype=ATA.dtype)[None, None, :, :].repeat([ATA.size(0), ATA.size(1), 1, 1])
    ATA = ATA + eps_identity

    # Right-hand side per pixel: [h, w, 3, 1].
    AT1 = torch.stack([_patch_sum(x), _patch_sum(y), _patch_sum(z)], dim=4)
    AT1 = torch.squeeze(AT1)
    AT1 = torch.unsqueeze(AT1, 3)

    # Solve in a 4x4 grid of tiles, each extended by an overlap margin.
    patch_num = 4
    patch_x = int(AT1.size(1) / patch_num)
    patch_y = int(AT1.size(0) / patch_num)
    # NOTE(review): rows/columns beyond 4 * patch_y / 4 * patch_x (sizes not
    # divisible by 4) are never overwritten and keep this random
    # initialisation, as in the original code.
    n_img = torch.randn(AT1.shape, device=device, dtype=dtype)
    overlap = patch_size // 2 + 1
    for px in range(int(patch_num)):
        for py in range(int(patch_num)):
            # Margins are only added on sides that have a neighbouring tile.
            left_flg = 0 if px == 0 else 1
            right_flg = 0 if px == patch_num - 1 else 1
            top_flg = 0 if py == 0 else 1
            btm_flg = 0 if py == patch_num - 1 else 1

            rows = slice(py * patch_y - top_flg * overlap, (py + 1) * patch_y + btm_flg * overlap)
            cols = slice(px * patch_x - left_flg * overlap, (px + 1) * patch_x + right_flg * overlap)
            n_img_tmp = torch.linalg.solve(ATA[rows, cols], AT1[rows, cols])

            # Drop the margins again and write the tile core into the output.
            n_img_tmp_select = n_img_tmp[top_flg * overlap:patch_y + top_flg * overlap,
                                         left_flg * overlap:patch_x + left_flg * overlap, :, :]
            n_img[py * patch_y:py * patch_y + patch_y, px * patch_x:px * patch_x + patch_x, :, :] = n_img_tmp_select

    n_img_L2 = torch.sqrt(torch.sum(n_img ** 2, dim=2, keepdim=True))
    n_img_norm = n_img / n_img_L2

    # re-orient normals consistently
    orient_mask = torch.sum(torch.squeeze(n_img_norm) * torch.squeeze(xyz), dim=2) > 0
    n_img_norm[orient_mask] *= -1
    return n_img_norm
|
101 |
-
|
102 |
-
def get_surface_normalv2(xyz, patch_size=5):
    """
    Estimate normals from cross products of horizontal/vertical point differences.

    xyz: xyz coordinates, shape [b, h, w, c]
    patch: [p1, p2, p3,
            p4, p5, p6,
            p7, p8, p9]
    Two normals are computed per pixel -- one from the outer neighbours
    (p4p6 x p2p8) and one from a tighter inner pair -- then normalised,
    summed, renormalised and flipped wherever the dot product with the point
    is positive.
    return: normal [h, w, 3, b]
    """
    b, h, w, c = xyz.shape
    half_patch = patch_size // 2
    last = patch_size - 1

    # Zero-padded copy so shifted views stay in bounds.
    xyz_pad = torch.zeros((b, h + last, w + last, c), dtype=xyz.dtype, device=xyz.device)
    xyz_pad[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz

    mid_rows = slice(half_patch, half_patch + h)
    mid_cols = slice(half_patch, half_patch + w)

    # Outer neighbours: p4 - p6 and p2 - p8.
    xyz_horizon = xyz_pad[:, mid_rows, :w, :] - xyz_pad[:, mid_rows, -w:, :]
    xyz_vertical = xyz_pad[:, :h, mid_cols, :] - xyz_pad[:, -h:, mid_cols, :]

    # Inner neighbours: same right/bottom views, left/top shifted by one.
    xyz_horizon_in = xyz_pad[:, mid_rows, 1:w + 1, :] - xyz_pad[:, mid_rows, last:last + w, :]
    xyz_vertical_in = xyz_pad[:, 1:h + 1, mid_cols, :] - xyz_pad[:, last:last + h, mid_cols, :]

    def _reorient(normals):
        # re-orient normals consistently (in place)
        flip = torch.sum(normals * xyz, dim=3) > 0
        normals[flip] *= -1
        return normals

    def _unit(normals):
        length = torch.sqrt(torch.sum(normals ** 2, dim=3, keepdim=True))
        return normals / (length + 1e-8)

    n_inner = _unit(_reorient(torch.cross(xyz_horizon_in, xyz_vertical_in, dim=3)))
    n_outer = _unit(_reorient(torch.cross(xyz_horizon, xyz_vertical, dim=3)))

    # average 2 norms
    n_mean = _reorient(_unit(n_inner + n_outer))
    return n_mean.permute((1, 2, 3, 0))  # [h, w, c, b]
|
165 |
-
|
166 |
-
def surface_normal_from_depth(depth, focal_length, valid_mask=None):
    """
    Compute a surface-normal map for every depth map in a batch.

    Args:
        depth: depth map, [b, c, h, w].
        focal_length: tensor; it is expanded with three trailing singleton
            axes before being handed to ``depth_to_xyz``.
        valid_mask: optional boolean tensor; it is tiled 3x along dim 1 and
            normals at invalid positions are set to 0.
            # assumes shape [b, 1, h, w] -- TODO confirm

    Returns:
        Tensor of shape [b, 3, h, w].
    """
    b, _, _, _ = depth.shape
    focal_length = focal_length[:, None, None, None]

    # Light 3x3 box blur of the depth before back-projection.
    depth_filter = nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1)
    xyz = depth_to_xyz(depth_filter, focal_length)

    # get_surface_normal works on one [1, h, w, 3] point map at a time.
    sn_batch = [get_surface_normal(xyz[i, :][None, :, :, :]) for i in range(b)]
    sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1))  # [b, c, h, w]

    # Identity check: "!= None" on a tensor relies on comparison fallbacks.
    if valid_mask is not None:
        mask_invalid = (~valid_mask).repeat(1, 3, 1, 1)
        sn_batch[mask_invalid] = 0.0

    return sn_batch
|
186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/depth_ensemble.py
DELETED
@@ -1,115 +0,0 @@
|
|
1 |
-
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
-
|
3 |
-
import numpy as np
|
4 |
-
import torch
|
5 |
-
|
6 |
-
from scipy.optimize import minimize
|
7 |
-
|
8 |
-
def inter_distances(tensors: torch.Tensor):
    """
    To calculate the distance between each two depth maps.

    Returns the element-wise differences ``tensors[i] - tensors[j]`` for every
    index pair ``i < j``, concatenated along dim 0.
    """
    index_pairs = torch.combinations(torch.arange(tensors.shape[0]))
    # Slicing (i : i + 1) keeps the leading axis so results can be concatenated.
    return torch.concat(
        [tensors[i : i + 1] - tensors[j : j + 1] for i, j in index_pairs],
        dim=0,
    )
|
19 |
-
|
20 |
-
|
21 |
-
def ensemble_depths(input_images:torch.Tensor,
                    regularizer_strength: float =0.02,
                    max_iter: int =2,
                    tol:float =1e-3,
                    reduction: str='median',
                    max_res: int=None):
    """
    To ensemble multiple affine-invariant depth images (up to scale and shift),
    by aligning them through an estimated per-image scale and shift.

    Args:
        input_images: stack of depth predictions, shape [N, H, W].
        regularizer_strength: weight of the term pulling the ensembled
            prediction's min/max towards 0/1.
        max_iter: maximum number of BFGS iterations.
        tol: optimiser tolerance.
        reduction: ``'mean'`` or ``'median'``.
        max_res: if given and smaller than the image, the optimisation runs on
            a nearest-neighbour downscaled copy. The final prediction is
            always computed at the original resolution.

    Returns:
        ``(aligned_images, uncertainty)``, both [H, W]. ``aligned_images`` is
        rescaled to [0, 1]; ``uncertainty`` is the standard deviation (mean)
        or the median absolute deviation (median), in the same scale.

    Raises:
        ValueError: if ``reduction`` is neither ``'mean'`` nor ``'median'``.
    """
    # Fail early instead of from inside the optimiser callback.
    if reduction not in ("mean", "median"):
        raise ValueError(f"Unknown reduction method: {reduction!r}")

    device = input_images.device
    dtype = input_images.dtype
    np_dtype = np.float32

    original_input = input_images.clone()
    n_img = input_images.shape[0]
    ori_shape = input_images.shape

    if max_res is not None:
        scale_factor = min(max_res / ori_shape[-2], max_res / ori_shape[-1])
        if scale_factor < 1:
            # A temporary channel axis makes interpolate resize both spatial
            # dimensions. (The former code passed a tensor to
            # torch.from_numpy and always crashed here.)
            input_images = torch.nn.functional.interpolate(
                input_images.unsqueeze(1), scale_factor=scale_factor, mode="nearest"
            ).squeeze(1)

    # init guess: map each image's own [min, max] onto [0, 1]
    flat = input_images.reshape((n_img, -1)).cpu().numpy()
    _min = np.min(flat, axis=1)
    _max = np.max(flat, axis=1)
    s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
    t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))

    # Parameter vector: all scales followed by all shifts.
    x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)

    input_images = input_images.to(device)

    def _unpack(params):
        # Split the flat parameter vector into scale and shift tensors.
        half = int(len(params) / 2)
        scales = torch.from_numpy(params[:half]).to(dtype=dtype).to(device)
        shifts = torch.from_numpy(params[half:]).to(dtype=dtype).to(device)
        return scales, shifts

    # objective function
    def closure(x):
        s, t = _unpack(x)

        transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
        dists = inter_distances(transformed_arrays)
        sqrt_dist = torch.sqrt(torch.mean(dists**2))

        if "mean" == reduction:
            pred = torch.mean(transformed_arrays, dim=0)
        else:
            pred = torch.median(transformed_arrays, dim=0).values

        # Penalise an ensembled prediction whose range drifts away from [0, 1].
        near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
        far_err = torch.sqrt((1 - torch.max(pred)) ** 2)

        err = sqrt_dist + (near_err + far_err) * regularizer_strength
        return err.detach().cpu().numpy().astype(np_dtype)

    res = minimize(
        closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False}
    )
    s, t = _unpack(res.x)

    # Prediction (always on the full-resolution input)
    transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)

    if "mean" == reduction:
        aligned_images = torch.mean(transformed_arrays, dim=0)
        uncertainty = torch.std(transformed_arrays, dim=0)
    else:
        aligned_images = torch.median(transformed_arrays, dim=0).values
        # MAD (median absolute deviation) as uncertainty indicator
        abs_dev = torch.abs(transformed_arrays - aligned_images)
        uncertainty = torch.median(abs_dev, dim=0).values

    # Scale and shift to [0, 1]
    _min = torch.min(aligned_images)
    _max = torch.max(aligned_images)
    aligned_images = (aligned_images - _min) / (_max - _min)
    uncertainty /= _max - _min

    return aligned_images, uncertainty
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/image_util.py
DELETED
@@ -1,83 +0,0 @@
|
|
1 |
-
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
-
|
3 |
-
import matplotlib
|
4 |
-
import numpy as np
|
5 |
-
import torch
|
6 |
-
from PIL import Image
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
    """
    Resize image so its longer edge equals ``max_edge_resolution`` while keeping aspect ratio.

    Args:
        img (`Image.Image`):
            Image to be resized.
        max_edge_resolution (`int`):
            Maximum edge length (pixel).
    Returns:
        `Image.Image`: Resized image.
    """
    original_width, original_height = img.size

    # The smaller of the two ratios keeps both edges within the limit.
    downscale_factor = min(
        max_edge_resolution / original_width, max_edge_resolution / original_height
    )

    # int() truncates, so edges are rounded down.
    new_size = (
        int(original_width * downscale_factor),
        int(original_height * downscale_factor),
    )
    return img.resize(new_size)
|
34 |
-
|
35 |
-
|
36 |
-
def colorize_depth_maps(
    depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
):
    """
    Colorize depth maps.

    Args:
        depth_map: ``torch.Tensor`` (any device) or ``np.ndarray`` with at
            least 2 dimensions; singleton axes are squeezed and a 2-D map is
            treated as a batch of one.
        min_depth, max_depth: value range mapped onto the colormap; values
            outside it are clipped.
        cmap: name of a matplotlib colormap.
        valid_mask: optional boolean tensor/array; colours at invalid pixels
            are set to 0.

    Returns:
        Colours in [0, 1] with shape [B, 3, H, W]; a float tensor when
        ``depth_map`` is a tensor, otherwise a numpy array.

    Raises:
        TypeError: if ``depth_map`` is neither a tensor nor an ndarray.
    """
    assert len(depth_map.shape) >= 2, "Invalid dimension"

    is_tensor = isinstance(depth_map, torch.Tensor)
    if is_tensor:
        # .cpu() is required before .numpy() for tensors on an accelerator.
        depth = depth_map.detach().clone().squeeze().cpu().numpy()
    elif isinstance(depth_map, np.ndarray):
        depth = depth_map.copy().squeeze()
    else:
        raise TypeError(
            f"depth_map must be a torch.Tensor or np.ndarray, got {type(depth_map).__name__}"
        )
    # reshape to [ (B,) H, W ]
    if depth.ndim < 3:
        depth = depth[np.newaxis, :, :]

    # colorize
    cm = matplotlib.colormaps[cmap]
    depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
    img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]  # value from 0 to 1
    img_colored_np = np.rollaxis(img_colored_np, 3, 1)  # [B, H, W, 3] -> [B, 3, H, W]

    if valid_mask is not None:
        # Convert based on the mask's own type, not the depth map's.
        if isinstance(valid_mask, torch.Tensor):
            valid_mask = valid_mask.detach().cpu().numpy()
        valid_mask = valid_mask.squeeze()  # [H, W] or [B, H, W]
        if valid_mask.ndim < 3:
            valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
        else:
            valid_mask = valid_mask[:, np.newaxis, :, :]
        valid_mask = np.repeat(valid_mask, 3, axis=1)
        img_colored_np[~valid_mask] = 0

    if is_tensor:
        return torch.from_numpy(img_colored_np).float()
    return img_colored_np
|
75 |
-
|
76 |
-
|
77 |
-
def chw2hwc(chw):
    """
    Move the leading (channel) axis of a 3-D tensor/array to the end.

    Args:
        chw: ``torch.Tensor`` or ``np.ndarray`` with exactly 3 dimensions.

    Returns:
        The same data with axes ordered (1, 2, 0), of the same container type.

    Raises:
        TypeError: if ``chw`` is neither a tensor nor an ndarray (previously
            this surfaced as an UnboundLocalError).
    """
    assert 3 == len(chw.shape)
    if isinstance(chw, torch.Tensor):
        return torch.permute(chw, (1, 2, 0))
    if isinstance(chw, np.ndarray):
        return np.moveaxis(chw, 0, -1)
    raise TypeError(
        f"chw must be a torch.Tensor or np.ndarray, got {type(chw).__name__}"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/normal_ensemble.py
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
-
|
3 |
-
import numpy as np
|
4 |
-
import torch
|
5 |
-
|
6 |
-
def ensemble_normals(input_images:torch.Tensor):
    """
    Pick the most representative normal map from a stack of predictions.

    The predictions are normalised to (approximately) unit length, converted
    to spherical angles, and the angles are averaged over the stack to form a
    reference normal map. The prediction with the smallest summed angular
    error to that reference is returned.

    Args:
        input_images: tensor of shape [bsz, d, h, w]; the code writes channels
            0..2, so d is expected to be 3.

    Returns:
        The selected normalised prediction, shape [d, h, w].
    """
    bsz, d, h, w = input_images.shape
    normal_preds = input_images / (torch.norm(input_images, p=2, dim=1).unsqueeze(1)+1e-5)

    # Spherical angles averaged over the ensemble axis.
    phi = torch.atan2(normal_preds[:,1,:,:], normal_preds[:,0,:,:]).mean(dim=0)
    theta = torch.atan2(torch.norm(normal_preds[:,:2,:,:], p=2, dim=1), normal_preds[:,2,:,:]).mean(dim=0)

    # Reference normal rebuilt from the mean angles.
    normal_pred = torch.zeros((d,h,w)).to(normal_preds)
    normal_pred[0,:,:] = torch.sin(theta) * torch.cos(phi)
    normal_pred[1,:,:] = torch.sin(theta) * torch.sin(phi)
    normal_pred[2,:,:] = torch.cos(theta)

    # Clamp before acos: rounding can push the cosine slightly outside
    # [-1, 1], which would yield NaN and corrupt the argmin below.
    cos_sim = torch.cosine_similarity(normal_pred[None], normal_preds, dim=1).clamp(-1.0, 1.0)
    angle_error = torch.acos(cos_sim)
    normal_idx = torch.argmin(angle_error.reshape(bsz,-1).sum(-1))

    return normal_preds[normal_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/seed_all.py
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
# --------------------------------------------------------------------------
|
15 |
-
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
16 |
-
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
17 |
-
# More information about the method can be found at https://marigoldmonodepth.github.io
|
18 |
-
# --------------------------------------------------------------------------
|
19 |
-
|
20 |
-
|
21 |
-
import numpy as np
|
22 |
-
import random
|
23 |
-
import torch
|
24 |
-
|
25 |
-
|
26 |
-
def seed_all(seed: int = 0):
    """
    Seed every random number generator used by this code base.

    The same value is handed, in this order, to Python's ``random`` module,
    NumPy's global generator, ``torch.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed: Value passed to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/surface_normal.py
DELETED
@@ -1,213 +0,0 @@
|
|
1 |
-
# A reimplemented version in public environments by Xiao Fu and Mu Hu
|
2 |
-
|
3 |
-
import torch
|
4 |
-
import numpy as np
|
5 |
-
import torch.nn as nn
|
6 |
-
|
7 |
-
|
8 |
-
def init_image_coor(height, width, device=None):
    """
    Build per-pixel image coordinates measured from the image centre.

    Args:
        height: Number of image rows.
        width: Number of image columns.
        device: Device on which the grids are created. When ``None``
            (the default) CUDA is used if ``torch.cuda.is_available()``,
            otherwise the CPU, so the function no longer crashes on
            machines without a GPU.

    Returns:
        Tuple ``(u_u0, v_v0)`` of float32 tensors shaped
        ``[1, height, width]`` where ``u_u0[0, i, j] = j - width / 2.0``
        and ``v_v0[0, i, j] = i - height / 2.0``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    col_index = torch.arange(width, dtype=torch.float32, device=device)
    row_index = torch.arange(height, dtype=torch.float32, device=device)

    # Offsets are taken relative to (width / 2, height / 2), then tiled so
    # that both grids are dense and contiguous like the original arrays.
    u_u0 = (col_index - width / 2.0).view(1, 1, width).repeat(1, height, 1)
    v_v0 = (row_index - height / 2.0).view(1, height, 1).repeat(1, 1, width)
    return u_u0, v_v0
|
23 |
-
|
24 |
-
|
25 |
-
def depth_to_xyz(depth, focal_length):
    """
    Back-project a depth map into per-pixel 3D points.

    Pixel offsets from the image centre (as returned by
    ``init_image_coor``) are scaled by ``depth / focal_length`` to obtain
    x and y; z is the depth itself.

    Args:
        depth: Depth map shaped [b, c, h, w].
        focal_length: Focal length; must broadcast against ``depth``.

    Returns:
        Tensor shaped [b, h, w, 3 * c] holding the x, y and z channels
        concatenated along the last axis (``[b, h, w, 3]`` when c == 1).
    """
    _, _, h, w = depth.shape
    u_u0, v_v0 = init_image_coor(h, w)
    # The coordinate grids come back on whatever device init_image_coor
    # picked; move them next to `depth` so the products below cannot fail
    # with a device mismatch.
    u_u0 = u_u0.to(depth.device)
    v_v0 = v_v0.to(depth.device)
    x = u_u0 * depth / focal_length
    y = v_v0 * depth / focal_length
    z = depth
    pw = torch.cat([x, y, z], 1).permute(0, 2, 3, 1)  # [b, h, w, c]
    return pw
|
33 |
-
|
34 |
-
|
35 |
-
def get_surface_normal(xyz, patch_size=3):
    """
    Estimate per-pixel surface normals by least-squares plane fitting.

    For every pixel, the points ``p`` inside the surrounding
    ``patch_size x patch_size`` window (zero padded at the borders) are
    fitted with a plane ``n . p = 1`` by solving the regularised normal
    equations ``(sum(p p^T) + 1e-6 * I) n = sum(p)``. The solution is
    scaled to unit length and flipped wherever ``n . p > 0`` for the
    pixel's own point, so the returned normals satisfy ``n . p <= 0``.

    Args:
        xyz: Point map shaped [1, h, w, 3].
        patch_size: Side length of the square fitting window (default 3).

    Returns:
        Tensor shaped [h, w, 3, 1] with one normal per pixel.
    """
    pad = int(patch_size / 2)

    # Each coordinate becomes a [1, 1, h, w] image so conv2d can sum patches.
    x, y, z = (comp.unsqueeze(0) for comp in torch.unbind(xyz, dim=3))
    # Follow the input's dtype/device instead of forcing float32 on CUDA.
    patch_weight = torch.ones(
        (1, 1, patch_size, patch_size), dtype=xyz.dtype, device=xyz.device
    )

    def patch_sum(values):
        # Window sum around every pixel; explicit [0, 0] indexing (rather
        # than squeeze) keeps h == 1 or w == 1 inputs two-dimensional.
        return nn.functional.conv2d(values, weight=patch_weight, padding=pad)[0, 0]

    xx, yy, zz = patch_sum(x * x), patch_sum(y * y), patch_sum(z * z)
    xy, xz, yz = patch_sum(x * y), patch_sum(x * z), patch_sum(y * z)

    ATA = torch.stack([xx, xy, xz, xy, yy, yz, xz, yz, zz], dim=-1)
    rows, cols = ATA.shape[:2]
    ATA = ATA.reshape(rows, cols, 3, 3)
    # Small ridge term keeps the 3x3 systems invertible.
    ATA = ATA + 1e-6 * torch.eye(3, device=ATA.device, dtype=ATA.dtype)

    AT1 = torch.stack([patch_sum(x), patch_sum(y), patch_sum(z)], dim=-1)
    AT1 = AT1.unsqueeze(-1)  # [h, w, 3, 1]

    # Solve in a 4 x 4 grid of tiles to bound the size of each batched
    # solve. The tile edges cover the whole image, including rows/columns
    # left over when h or w is not a multiple of 4. Every pixel's system is
    # independent, so no overlap between tiles is needed.
    patch_num = 4
    row_edges = [rows * i // patch_num for i in range(patch_num + 1)]
    col_edges = [cols * i // patch_num for i in range(patch_num + 1)]
    n_img = torch.zeros_like(AT1)
    for top, bottom in zip(row_edges[:-1], row_edges[1:]):
        for left, right in zip(col_edges[:-1], col_edges[1:]):
            if top == bottom or left == right:
                continue
            # torch.solve(B, A) was removed from PyTorch; torch.linalg.solve
            # takes the coefficient matrix first.
            n_img[top:bottom, left:right] = torch.linalg.solve(
                ATA[top:bottom, left:right], AT1[top:bottom, left:right]
            )

    n_img_L2 = torch.sqrt(torch.sum(n_img ** 2, dim=2, keepdim=True))
    # Same epsilon as get_surface_normalv2: avoids NaN for zero-length fits.
    n_img_norm = n_img / (n_img_L2 + 1e-8)

    # re-orient normals consistently
    orient_mask = torch.sum(n_img_norm.squeeze(-1) * xyz[0], dim=2) > 0
    n_img_norm[orient_mask] *= -1
    return n_img_norm
|
95 |
-
|
96 |
-
def get_surface_normalv2(xyz, patch_size=3):
    """
    Estimate surface normals from cross products of neighbour differences.

    ``xyz`` is zero padded by ``patch_size // 2`` on every spatial side and
    two normals are computed per pixel:

    * an "outer" normal from the difference of the outermost horizontal
      neighbours crossed with the difference of the outermost vertical
      neighbours of the window;
    * an "inner" normal from the same construction using the padded map
      at offsets ``1`` and ``patch_size - 1`` (for the default
      ``patch_size=3`` these are the pixel itself and its right / bottom
      neighbour).

    Each normal is flipped where its dot product with ``xyz`` is positive
    and scaled to unit length (with a 1e-8 guard). The two are summed,
    re-normalised and flipped once more by the same rule.

    xyz: point map, [b, h, w, c]
    return: normal [h, w, c, b]
    """
    b, h, w, c = xyz.shape
    half_patch = patch_size // 2
    padded = torch.zeros(
        (b, h + patch_size - 1, w + patch_size - 1, c),
        dtype=xyz.dtype,
        device=xyz.device,
    )
    padded[:, half_patch:-half_patch, half_patch:-half_patch, :] = xyz

    centre_rows = slice(half_patch, half_patch + h)
    centre_cols = slice(half_patch, half_patch + w)
    far = patch_size - 1

    def oriented_unit(horizontal, vertical):
        # Cross product, flipped where it has a positive dot product with
        # xyz, then scaled to unit length.
        normal = torch.cross(horizontal, vertical, dim=3)
        flip = torch.sum(normal * xyz, dim=3) > 0
        normal[flip] *= -1
        length = torch.sqrt(torch.sum(normal ** 2, dim=3, keepdim=True))
        return normal / (length + 1e-8)

    # Differences across the full window.
    outer_horizontal = padded[:, centre_rows, :w, :] - padded[:, centre_rows, -w:, :]
    outer_vertical = padded[:, :h, centre_cols, :] - padded[:, -h:, centre_cols, :]

    # Differences between the windows at offsets 1 and patch_size - 1.
    inner_horizontal = (
        padded[:, centre_rows, 1:w + 1, :] - padded[:, centre_rows, far:far + w, :]
    )
    inner_vertical = (
        padded[:, 1:h + 1, centre_cols, :] - padded[:, far:far + h, centre_cols, :]
    )

    # Sum of the two unit normals, re-normalised.
    blended = oriented_unit(inner_horizontal, inner_vertical) + oriented_unit(
        outer_horizontal, outer_vertical
    )
    blended_length = torch.sqrt(torch.sum(blended ** 2, dim=3, keepdim=True))
    averaged = blended / (blended_length + 1e-8)

    # Apply the same orientation rule to the blended result.
    averaged[torch.sum(averaged * xyz, dim=3) > 0] *= -1
    return averaged.permute((1, 2, 3, 0))  # [h, w, c, b]
|
159 |
-
|
160 |
-
def surface_normal_from_depth(depth, focal_length, valid_mask=None):
    """
    Compute a surface-normal map from a batch of depth maps.

    The depth is smoothed with two 3x3 average-pooling passes, converted
    to 3D points with ``depth_to_xyz`` and turned into normals one sample
    at a time with ``get_surface_normalv2``.

    Args:
        depth: Depth map, [b, c, h, w].
        focal_length: Per-sample focal lengths; indexed as
            ``focal_length[:, None, None, None]``, so a 1-D tensor of
            length b is expected.
        valid_mask: Optional boolean mask; normals where it is False are
            set to 0. It is repeated 3 times along dim 1, so a
            [b, 1, h, w] mask is presumably expected -- confirm against
            callers. ``None`` (the default) leaves all normals untouched.

    Returns:
        Normal map, [b, c, h, w].
    """
    batch_size = depth.shape[0]
    focal_length = focal_length[:, None, None, None]
    depth_filter = nn.functional.avg_pool2d(depth, kernel_size=3, stride=1, padding=1)
    depth_filter = nn.functional.avg_pool2d(depth_filter, kernel_size=3, stride=1, padding=1)
    xyz = depth_to_xyz(depth_filter, focal_length)

    # One sample at a time; each result is [h, w, c, 1].
    sn_batch = [get_surface_normalv2(xyz[i:i + 1]) for i in range(batch_size)]
    sn_batch = torch.cat(sn_batch, dim=3).permute((3, 2, 0, 1))  # [b, c, h, w]

    # The mask is optional: the previous code applied `~valid_mask`
    # unconditionally and raised a TypeError for the default None.
    if valid_mask is not None:
        mask_invalid = (~valid_mask).repeat(1, 3, 1, 1)
        sn_batch[mask_invalid] = 0.0

    return sn_batch
|
177 |
-
|
178 |
-
|
179 |
-
def vis_normal(normal):
    """
    Convert a surface-normal map into an 8-bit image for visualisation.

    Every vector is divided by its L2 norm along axis 2 (plus 1e-8), then
    mapped through ``value * 127 + 128`` and cast to ``uint8``. The input
    array is not modified.

    @para normal: surface normal, [h, w, 3], numpy.array
    @return: uint8 array with the same shape as ``normal``
    """
    magnitude = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True))
    unit = normal / (magnitude + 1e-8)
    return (unit * 127 + 128).astype(np.uint8)
|
190 |
-
|
191 |
-
def vis_normal2(normals):
    '''
    Remap a normal map in place for display and return it.

    Channels 0 and 1 are mapped through ``(value + 1) * 0.5`` and channel 2
    is replaced by its absolute value. Pixels whose original vector length
    is below 1e-5 are set to 0 in all three channels. The input array
    itself is modified and returned.
    '''
    # Vector length must be measured before any channel is overwritten.
    degenerate = np.sqrt(np.sum(normals**2, axis=2)) < 1e-5

    for channel_index in (0, 1):
        channel = normals[:, :, channel_index]  # view into `normals`
        channel += 1.0
        channel *= 0.5
        channel[degenerate] = 0.0

    magnitude_z = np.abs(normals[:, :, 2])
    magnitude_z[degenerate] = 0.0
    normals[:, :, 2] = magnitude_z
    return normals
|
211 |
-
|
212 |
-
if __name__ == '__main__':
    # Imported only when the module is executed as a script.
    import cv2
    import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|