# This file is copied from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
# so that users do not need to install detectron2 just for these functions
# Copyright (c) Facebook, Inc. and its affiliates.
from typing import List
import torch
from torch.nn import functional as F

def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Efficient version of torch.cat that avoids a copy if there is only a single element in the list.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)
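
# A small usage sketch (assumed example, not part of the original file): a
# single-tensor list is returned as-is (no copy); longer lists fall through to a
# regular torch.cat.
if __name__ == "__main__":
    t = torch.ones(2, 3)
    assert cat([t]) is t  # single element: the same tensor object comes back
    assert cat([t, t], dim=0).shape == (4, 3)  # standard concatenation otherwise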

def calculate_uncertainty(sem_seg_logits):
    """
    For each location of the prediction `sem_seg_logits`, we estimate the uncertainty as the
    difference between the top first and top second predicted logits.

    Args:
        sem_seg_logits (Tensor): A tensor of shape (N, C, ...), where N is the minibatch size and
            C is the number of foreground classes. The values are logits.

    Returns:
        scores (Tensor): A tensor of shape (N, 1, ...) that contains uncertainty scores with
            the most uncertain locations having the highest uncertainty score.
    """
    if sem_seg_logits.shape[1] == 2:
        # binary segmentation: uncertainty is the negated magnitude of the foreground logit
        return -(torch.abs(sem_seg_logits[:, 1:2]))
    top2_scores = torch.topk(sem_seg_logits, k=2, dim=1)[0]
    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)
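
# A hedged illustration (toy values, not from the original file): scores are closest
# to zero where the top two class logits are nearly tied, i.e. where the model is
# least sure.
if __name__ == "__main__":
    logits = torch.tensor([[[9.0, 0.1], [0.2, 5.0], [0.0, 4.9]]])  # (N=1, C=3, P=2)
    # point 0: top-2 logits are 9.0 and 0.2 -> score -8.8 (confident)
    # point 1: top-2 logits are 5.0 and 4.9 -> score -0.1 (uncertain)
    print(calculate_uncertainty(logits))  # tensor([[[-8.8000, -0.1000]]])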

def point_sample(input, point_coords, **kwargs):
    """
    A wrapper around :func:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
    Unlike :func:`torch.nn.functional.grid_sample`, it assumes `point_coords` to lie inside the
    [0, 1] x [0, 1] square.

    Args:
        input (Tensor): A tensor of shape (N, C, H, W) that contains the feature map on an H x W grid.
        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that contains
            [0, 1] x [0, 1] normalized point coordinates.

    Returns:
        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
            features for points in `point_coords`. The features are obtained via bilinear
            interpolation from `input` the same way as :func:`torch.nn.functional.grid_sample`.
    """
    add_dim = False
    if point_coords.dim() == 3:
        # (N, P, 2) -> (N, P, 1, 2) so that grid_sample sees a 4D grid
        add_dim = True
        point_coords = point_coords.unsqueeze(2)
    # grid_sample expects coordinates in [-1, 1], so rescale from [0, 1]
    output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
    if add_dim:
        output = output.squeeze(3)
    return output
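
# A minimal sanity check (assumed example, not from the original file): with
# align_corners=False, the pixel centers of a 2x2 map sit at 0.25 and 0.75 in
# [0, 1] normalized (x, y) coordinates, so sampling there returns the raw pixel values.
if __name__ == "__main__":
    feats = torch.arange(4.0).view(1, 1, 2, 2)  # (N, C, H, W)
    centers = torch.tensor(
        [[[0.25, 0.25], [0.75, 0.25], [0.25, 0.75], [0.75, 0.75]]]
    )  # (N, P, 2)
    print(point_sample(feats, centers, align_corners=False))  # tensor([[[0., 1., 2., 3.]]])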

def get_uncertain_point_coords_with_randomness(
    coarse_logits, uncertainty_func, num_points, oversample_ratio, importance_sample_ratio
):
    """
    Sample points in the [0, 1] x [0, 1] coordinate space based on their uncertainty. The
    uncertainties are calculated for each point using `uncertainty_func`, a function that takes a
    point's logit prediction as input.

    See the PointRend paper for details.

    Args:
        coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
            class-specific or class-agnostic prediction.
        uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
            contains logit predictions for P points and returns their uncertainties as a Tensor of
            shape (N, 1, P).
        num_points (int): The number of points P to sample.
        oversample_ratio (int): Oversampling parameter.
        importance_sample_ratio (float): Ratio of points that are sampled via importance sampling.

    Returns:
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
            sampled points.
    """
    assert oversample_ratio >= 1
    assert importance_sample_ratio <= 1 and importance_sample_ratio >= 0
    num_boxes = coarse_logits.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    point_coords = torch.rand(num_boxes, num_sampled, 2, device=coarse_logits.device)
    point_logits = point_sample(coarse_logits, point_coords, align_corners=False)
    # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
    # Calculating uncertainties of the coarse predictions first and sampling them for points leads
    # to incorrect results.
    # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
    # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
    # However, if we calculate uncertainties for the coarse predictions first,
    # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
    point_uncertainties = uncertainty_func(point_logits)
    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    # Keep the `num_uncertain_points` most uncertain of the oversampled candidates ...
    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    shift = num_sampled * torch.arange(num_boxes, dtype=torch.long, device=coarse_logits.device)
    idx += shift[:, None]
    point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
        num_boxes, num_uncertain_points, 2
    )
    # ... and top up with uniformly random points for the remainder.
    if num_random_points > 0:
        point_coords = cat(
            [
                point_coords,
                torch.rand(num_boxes, num_random_points, 2, device=coarse_logits.device),
            ],
            dim=1,
        )
    return point_coords
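
# An end-to-end sketch (shapes and ratios assumed for illustration): oversample 3x,
# keep the 75% most uncertain candidates, and fill the remainder with random points.
if __name__ == "__main__":
    coarse = torch.randn(2, 5, 7, 7)  # (N, C, Hmask, Wmask) coarse predictions
    pts = get_uncertain_point_coords_with_randomness(
        coarse,
        calculate_uncertainty,
        num_points=16,
        oversample_ratio=3,
        importance_sample_ratio=0.75,
    )
    print(pts.shape)  # torch.Size([2, 16, 2]); coordinates lie in [0, 1] x [0, 1]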