|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
def postprocess(out, pred_mask, depth_mode, conf_mode):
    """Turn a raw prediction-head tensor into a dict of outputs.

    Args:
        out: 4-D tensor that is permuted to channels-last with
            ``permute(0, 2, 3, 1)``; presumably laid out as (B, C, H, W) —
            TODO confirm against the head that produces it.
        pred_mask: Passed through unchanged under the ``'pred_mask'`` key.
        depth_mode: ``(kind, vmin, vmax)`` tuple forwarded to
            ``reg_dense_depth`` for channels 0..2.
        conf_mode: ``(kind, vmin, vmax)`` tuple forwarded to
            ``reg_dense_conf`` for channel 3, or ``None`` to skip confidence.

    Returns:
        dict with ``'pts3d'`` and ``'pred_mask'``, plus ``'conf'`` when
        ``conf_mode`` is not ``None``.
    """
    # Move the channel axis last so the per-pixel values sit on dim -1.
    channels_last = out.permute(0, 2, 3, 1)

    result = {
        'pts3d': reg_dense_depth(channels_last[..., 0:3], mode=depth_mode),
        'pred_mask': pred_mask,
    }

    # Confidence is optional; it lives in the fourth channel when present.
    if conf_mode is not None:
        result['conf'] = reg_dense_conf(channels_last[..., 3], mode=conf_mode)

    return result
|
|
|
|
|
def reg_dense_depth(xyz, mode):
    """Convert raw prediction-head output into 3D points.

    Args:
        xyz: Tensor whose last dimension holds the raw (x, y, z) regression
            output.
        mode: ``(kind, vmin, vmax)`` tuple. ``kind`` selects the mapping:

            * ``'linear'``: values are returned unchanged when the bounds
              are ``(-inf, +inf)``; otherwise they are clipped elementwise
              to ``[vmin, vmax]``.
            * ``'square'``: the direction of each vector is kept and its
              distance ``d`` to the origin is remapped to ``d ** 2``.
            * ``'exp'``: the direction is kept and ``d`` is remapped to
              ``expm1(d)``.

            Bounds other than ``(-inf, +inf)`` are only supported for
            ``'linear'``.

    Returns:
        Tensor with the same shape as ``xyz``.

    Raises:
        ValueError: if ``kind`` is unknown, or if finite bounds are given for
            ``'square'`` / ``'exp'`` (those mappings cannot honour them).
    """
    mode, vmin, vmax = mode
    no_bounds = vmin == -float('inf') and vmax == float('inf')

    if mode == 'linear':
        if no_bounds:
            return xyz  # unbounded: pass through untouched
        return xyz.clip(min=vmin, max=vmax)

    if mode not in ('square', 'exp'):
        raise ValueError(f'bad {mode=}')
    if not no_bounds:
        # The radial mappings have no clipping step; refuse rather than
        # silently ignoring the requested bounds. (Explicit raise instead of
        # `assert`, which is stripped under `python -O`.)
        raise ValueError(f'bounds are not supported for {mode=} (got {vmin=}, {vmax=})')

    # Split each vector into its distance to the origin and a unit direction;
    # the clip avoids dividing by zero for the zero vector (which maps to 0).
    d = xyz.norm(dim=-1, keepdim=True)
    direction = xyz / d.clip(min=1e-8)

    if mode == 'square':
        return direction * d.square()
    return direction * torch.expm1(d)
|
|
|
|
|
def reg_dense_conf(x, mode):
    """Map raw prediction-head values to confidence scores (elementwise).

    Args:
        x: Raw confidence channel (a torch tensor).
        mode: ``(kind, vmin, vmax)`` tuple. ``kind`` selects the activation:

            * ``'exp'``: ``vmin + exp(x)``, with ``exp(x)`` capped at
              ``vmax - vmin``.
            * ``'sigmoid'``: ``sigmoid(x)`` rescaled to lie between ``vmin``
              and ``vmax``.

    Returns:
        The transformed confidence tensor.

    Raises:
        ValueError: if ``kind`` is neither ``'exp'`` nor ``'sigmoid'``.
    """
    mode, lo, hi = mode

    if mode == 'sigmoid':
        return (hi - lo) * torch.sigmoid(x) + lo

    if mode == 'exp':
        capped = x.exp().clip(max=hi - lo)
        return lo + capped

    raise ValueError(f'bad {mode=}')
|
|