umuthopeyildirim's picture
here we go
bd86ed9
import torch
import torch.nn.functional as F
import copy
def update_sample(bin_edges, target_bin_left, target_bin_right, depth_r, pred_label,
                  depth_num, min_depth, max_depth, uncertainty_range, *, mode='direct'):
    """Rebuild ``depth_num`` equal-width depth bins around a per-pixel estimate.

    A window of width ``depth_range`` is laid out per pixel, starting at
    ``depth_start_update`` (clamped below at ``min_depth``). It is split into
    ``depth_num`` equal intervals, and every edge is clamped to
    ``[min_depth, max_depth]``. Edges that would exceed ``max_depth``
    therefore collapse onto it.

    Args:
        bin_edges: Tensor whose ``shape`` ``(b, _, h, w)``, ``device`` and
            ``dtype`` define the output layout. Its values are not read.
        target_bin_left, target_bin_right: Bin boundaries, used only when
            ``mode == 'bin'``. They must broadcast to ``(b, 1, h, w)``.
        depth_r: Window centre, used when ``mode == 'direct'``. It must
            broadcast to ``(b, 1, h, w)``.
        pred_label: Unused; kept for call-site compatibility.
        depth_num: Number of bins to produce.
        min_depth, max_depth: Clamp range for all returned edges.
        uncertainty_range: Window width. Either a Python number or a tensor
            that broadcasts to ``(b, 1, h, w)``.
        mode: ``'direct'`` (default, the original behaviour) centres a window
            of width ``uncertainty_range`` on ``depth_r``. ``'bin'`` widens
            ``[target_bin_left, target_bin_right]`` by
            ``0.5 * uncertainty_range`` on the left and uses
            ``uncertainty_range + |right - left|`` as the window width.

    Returns:
        ``(bin_edges, curr_depth)``, both detached:
        ``bin_edges`` has shape ``(b, depth_num + 1, h, w)``, and
        ``curr_depth`` holds the ``depth_num`` bin midpoints,
        shape ``(b, depth_num, h, w)``.

    Raises:
        ValueError: if ``mode`` is not ``'direct'`` or ``'bin'``.
    """
    with torch.no_grad():
        b, _, h, w = bin_edges.shape
        if mode == 'direct':
            depth_range = uncertainty_range
            depth_start_update = torch.clamp_min(depth_r - 0.5 * depth_range, min_depth)
        elif mode == 'bin':
            depth_range = uncertainty_range + (target_bin_right - target_bin_left).abs()
            depth_start_update = torch.clamp_min(target_bin_left - 0.5 * uncertainty_range, min_depth)
        else:
            raise ValueError(f"unknown mode {mode!r}; expected 'direct' or 'bin'")

        interval = depth_range / depth_num
        # A scalar range yields a Python float; promote it so it can be broadcast.
        if not torch.is_tensor(interval):
            interval = torch.tensor(interval, dtype=bin_edges.dtype, device=bin_edges.device)
        # Broadcast to one channel per pixel before tiling, so scalar, 0-dim and
        # batch-1 ranges line up with the (b, 1, h, w) start plane in the cat below.
        interval = interval.expand(b, 1, h, w).repeat(1, depth_num, 1, 1)

        start = torch.ones([b, 1, h, w], device=bin_edges.device) * depth_start_update
        # A cumulative sum of [start, step, step, ...] gives the depth_num + 1 edges.
        bin_edges = torch.cumsum(torch.cat([start, interval], 1), 1).clamp(min_depth, max_depth)
        curr_depth = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])
    return bin_edges.detach(), curr_depth.detach()
def get_label(gt_depth_img, bin_edges, depth_num):
    """Label every ground-truth depth with the index of the bin that holds it.

    Bin ``k`` is the half-open interval
    ``[bin_edges[:, k], bin_edges[:, k + 1])`` for ``k`` in ``range(depth_num)``.
    Labels are assigned in increasing ``k``, so a later bin overwrites an
    earlier one where masks overlap.

    Values outside every bin keep the default label 0. This includes values
    below the first edge and values greater than or equal to the last edge.

    ``bin_edges[:, k]`` is compared element-wise against ``gt_depth_img``.
    The resulting mask must match ``gt_depth_img``'s shape for the boolean
    index assignment. NOTE(review): the expected layout of the inputs is not
    visible here; confirm against the caller.

    Returns:
        An ``int64`` tensor shaped like ``gt_depth_img``, on the same device.
    """
    with torch.no_grad():
        labels = torch.zeros_like(gt_depth_img, dtype=torch.int64)
        for idx in range(depth_num):
            lower = bin_edges[:, idx]
            upper = bin_edges[:, idx + 1]
            inside = (gt_depth_img >= lower) & (gt_depth_img < upper)
            labels[inside] = idx
        return labels