import torch
import torch.nn.functional as F
import copy


def update_sample(bin_edges, target_bin_left, target_bin_right, depth_r, pred_label,
                  depth_num, min_depth, max_depth, uncertainty_range, *, mode='direct'):
    """Build ``depth_num`` uniformly spaced depth bins inside a search window.

    Args:
        bin_edges: Current bin-edge tensor. Only its shape ``(b, _, h, w)`` and
            its device are read; its values are ignored.
        target_bin_left: Left edge of the target bin (only read when
            ``mode == 'bin'``).
        target_bin_right: Right edge of the target bin (only read when
            ``mode == 'bin'``).
        depth_r: Depth estimate the window is centred on (only read when
            ``mode == 'direct'``).
        pred_label: Not used; kept for interface compatibility.
        depth_num: Number of bins to generate.
        min_depth: Lower bound. The window start is raised to at least this
            value, and every returned edge is clamped to it.
        max_depth: Upper bound every returned edge is clamped to. Trailing
            bins may therefore collapse to zero width.
        uncertainty_range: Window width (``'direct'``) or padding (``'bin'``).
            It is divided and then ``.repeat``-ed along dim 1, so it must be a
            tensor. NOTE(review): presumably shaped ``(b, 1, h, w)`` like the
            start tensor built below — confirm against callers.
        mode: ``'direct'`` (default) centres a window of width
            ``uncertainty_range`` on ``depth_r``. ``'bin'`` starts the window
            ``uncertainty_range / 2`` below ``target_bin_left`` and gives it a
            width of ``uncertainty_range + |target_bin_right - target_bin_left|``.
            That is, the target bin padded by ``uncertainty_range / 2`` on
            each side when ``right >= left``.

    Returns:
        ``(bin_edges, curr_depth)``, both detached. ``bin_edges`` stacks the
        window start followed by the cumulative edges along dim 1.
        ``curr_depth`` holds the midpoints of consecutive edges.

    Raises:
        ValueError: If ``mode`` is not ``'direct'`` or ``'bin'``.
    """
    # Previously ``mode`` was a hard-coded local, which left the 'bin' branch
    # unreachable; validating here avoids silently picking the wrong strategy.
    if mode not in ('direct', 'bin'):
        raise ValueError(f"mode must be 'direct' or 'bin', got {mode!r}")

    with torch.no_grad():
        b, _, h, w = bin_edges.shape
        if mode == 'direct':
            depth_range = uncertainty_range
            depth_start_update = torch.clamp_min(depth_r - 0.5 * depth_range, min_depth)
        else:
            depth_range = uncertainty_range + (target_bin_right - target_bin_left).abs()
            depth_start_update = torch.clamp_min(
                target_bin_left - 0.5 * uncertainty_range, min_depth)

        # Clamping only the start (not the width) shifts the window upward
        # instead of shrinking it when it would extend below min_depth.
        interval = depth_range / depth_num
        interval = interval.repeat(1, depth_num, 1, 1)
        start = torch.ones([b, 1, h, w], device=bin_edges.device) * depth_start_update
        # cumsum over [start, step, step, ...] yields start, start+step, ...
        new_edges = torch.cumsum(torch.cat([start, interval], 1), 1).clamp(min_depth, max_depth)
        curr_depth = 0.5 * (new_edges[:, :-1] + new_edges[:, 1:])
    return new_edges.detach(), curr_depth.detach()


def get_label(gt_depth_img, bin_edges, depth_num):
    """Label every ground-truth depth value with the index of its bin.

    Bin ``i`` is the half-open interval ``[bin_edges[:, i], bin_edges[:, i + 1])``
    for ``i`` in ``range(depth_num)``. If zero-width bins overlap, later
    indices overwrite earlier ones.

    Values that fall in no bin keep label 0, which makes them
    indistinguishable from bin 0. This covers values below the first edge, at
    or above the last edge, and NaN.

    NOTE(review): ``bin_edges[:, i]`` drops dim 1, so ``gt_depth_img``
    presumably has shape ``(b, h, w)`` when ``bin_edges`` is
    ``(b, depth_num + 1, h, w)``. A ``(b, 1, h, w)`` ground truth would
    broadcast the mask to ``(b, b, h, w)`` and only index correctly for
    ``b == 1`` — confirm against callers.

    Returns:
        An int64 tensor with the same shape and device as ``gt_depth_img``.
    """
    with torch.no_grad():
        gt_label = torch.zeros(gt_depth_img.size(), dtype=torch.int64,
                               device=gt_depth_img.device)
        for i in range(depth_num):
            lower = bin_edges[:, i]
            upper = bin_edges[:, i + 1]
            bin_mask = torch.logical_and(torch.ge(gt_depth_img, lower),
                                         torch.lt(gt_depth_img, upper))
            gt_label[bin_mask] = i
    return gt_label