|
import torch |
|
import torch.nn.functional as F |
|
import copy |
|
|
|
def update_sample(bin_edges, target_bin_left, target_bin_right, depth_r, pred_label,
                  depth_num, min_depth, max_depth, uncertainty_range, mode='direct'):
    """Re-sample uniform depth bins inside a window around the current estimate.

    The new window starts at ``start`` (clamped below at ``min_depth``) and has
    width ``depth_range``.  It is split into ``depth_num`` equal intervals, and
    the resulting edges are clamped to ``[min_depth, max_depth]``.  Clamping can
    make some bins at either end degenerate (zero width).

    Args:
        bin_edges: 4-D tensor; only its ``(b, _, h, w)`` shape and device are used.
        target_bin_left: left edge of the target bin.  Only read when
            ``mode == 'target'``.
        target_bin_right: right edge of the target bin.  Only read when
            ``mode == 'target'``.
        depth_r: current depth estimate.  Presumably ``(b, 1, h, w)`` — TODO
            confirm with callers.
        pred_label: unused.  Kept for interface compatibility.
        depth_num: number of bins to produce.
        min_depth: lower clamp for the edges.
        max_depth: upper clamp for the edges.
        uncertainty_range: window width.  May be a tensor broadcastable to
            ``(b, 1, h, w)`` or a plain Python number.
        mode: how the window is placed.
            ``'direct'`` (default, the original behavior): window of width
            ``uncertainty_range`` centred on ``depth_r``.
            ``'target'``: window spanning the target bin, padded by
            ``uncertainty_range / 2`` on each side.

    Returns:
        ``(edges, centers)``, both detached.  ``edges`` has ``depth_num + 1``
        channels; ``centers`` has ``depth_num`` channels holding the midpoints
        of consecutive edges.

    Raises:
        ValueError: if ``mode`` is not ``'direct'`` or ``'target'``.
    """
    if mode not in ('direct', 'target'):
        raise ValueError(f"unknown mode {mode!r}; expected 'direct' or 'target'")

    with torch.no_grad():
        b, _, h, w = bin_edges.shape

        if mode == 'direct':
            depth_range = uncertainty_range
            depth_start_update = torch.clamp_min(depth_r - 0.5 * depth_range, min_depth)
        else:
            # Cover the whole target bin plus half the uncertainty on each side.
            depth_range = uncertainty_range + (target_bin_right - target_bin_left).abs()
            depth_start_update = torch.clamp_min(target_bin_left - 0.5 * uncertainty_range, min_depth)

        # Broadcast the start point to one channel per pixel.
        start = torch.ones([b, 1, h, w], device=bin_edges.device) * depth_start_update

        # as_tensor + expand (instead of .repeat) also accepts scalar ranges and
        # ranges that are only broadcastable to (b, 1, h, w).  No dtype is
        # forced, so torch.cat below promotes dtypes exactly as before.
        interval = torch.as_tensor(depth_range / depth_num, device=start.device)
        interval = interval.expand(b, depth_num, h, w)

        # Edges = start, start + step, ..., start + depth_num * step, then clamped.
        edges = torch.cumsum(torch.cat([start, interval], 1), 1).clamp(min_depth, max_depth)
        centers = 0.5 * (edges[:, :-1] + edges[:, 1:])

    return edges.detach(), centers.detach()
|
|
|
def get_label(gt_depth_img, bin_edges, depth_num):
    """Assign each ground-truth depth value the index of the bin containing it.

    Bin ``k`` is the half-open interval ``[bin_edges[:, k], bin_edges[:, k + 1])``
    for ``k`` in ``0 .. depth_num - 1``.  Only the first ``depth_num + 1`` edge
    channels are read.

    Values that fall in no bin also get label 0: anything below the first edge,
    at or above the last edge, or NaN.  These are indistinguishable from bin 0,
    so callers presumably mask them out separately (TODO confirm).

    Args:
        gt_depth_img: ground-truth depth tensor.  It must broadcast against
            ``bin_edges[:, k]`` to its own shape; presumably ``(b, h, w)`` when
            ``bin_edges`` is ``(b, depth_num + 1, h, w)`` — verify with callers.
        bin_edges: tensor of bin edges, indexed along dim 1.
        depth_num: number of bins.

    Returns:
        int64 tensor with the same shape and device as ``gt_depth_img``.
    """
    with torch.no_grad():
        labels = gt_depth_img.new_zeros(gt_depth_img.shape, dtype=torch.int64)

        for idx in range(depth_num):
            lower = bin_edges[:, idx]
            upper = bin_edges[:, idx + 1]
            inside = (gt_depth_img >= lower) & (gt_depth_img < upper)
            # A later bin overwrites an earlier one.  This only matters for
            # non-monotonic edges; with sorted edges the bins are disjoint.
            labels[inside] = idx

    return labels
|
|
|
|
|
|