"""
Author: Siyuan Li
Licensed: Apache-2.0 License
"""
from typing import List, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor

from mmdet.models.trackers.base_tracker import BaseTracker
from mmdet.registry import MODELS
from mmdet.structures import TrackDataSample
from mmdet.structures.bbox import bbox_overlaps
from mmengine.structures import InstanceData
@MODELS.register_module()
class MasaTaoTracker(BaseTracker):
    """Tracker for MASA on the TAO benchmark.

    Matches per-frame detections against an in-memory tracklet buffer using
    appearance embeddings (the average of a bi-directional softmax score and
    a cosine similarity), optionally soft-gated by the spatial distance
    between box centers.

    Args:
        init_score_thr (float): The cls_score threshold to
            initialize a new tracklet. Defaults to 0.8.
        obj_score_thr (float): The cls_score threshold to
            update a tracked tracklet. Defaults to 0.5.
        match_score_thr (float): The match threshold. Defaults to 0.5.
        memo_tracklet_frames (int): The most frames in a tracklet memory.
            Defaults to 10.
        memo_momentum (float): The momentum value for embeds updating.
            Defaults to 0.8.
        distractor_score_thr (float): The score threshold to consider an
            object as a distractor. Defaults to 0.5.
        distractor_nms_thr (float): The NMS threshold for filtering out
            distractors. Defaults to 0.3.
        with_cats (bool): Whether to track with the same category.
            Defaults to True.
        max_distance (float): Maximum center distance for considering
            matches. -1 disables the distance gate. Defaults to -1.
        fps (int): Frames per second of the input video. Used for
            calculating the growth factor of the distance gate.
            Defaults to 1.
    """

    def __init__(
        self,
        init_score_thr: float = 0.8,
        obj_score_thr: float = 0.5,
        match_score_thr: float = 0.5,
        memo_tracklet_frames: int = 10,
        memo_momentum: float = 0.8,
        distractor_score_thr: float = 0.5,
        distractor_nms_thr: float = 0.3,
        with_cats: bool = True,
        max_distance: float = -1,
        fps: int = 1,
        **kwargs
    ):
        super().__init__(**kwargs)
        assert 0 <= memo_momentum <= 1.0
        assert memo_tracklet_frames >= 0
        self.init_score_thr = init_score_thr
        self.obj_score_thr = obj_score_thr
        self.match_score_thr = match_score_thr
        self.memo_tracklet_frames = memo_tracklet_frames
        self.memo_momentum = memo_momentum
        self.distractor_score_thr = distractor_score_thr
        self.distractor_nms_thr = distractor_nms_thr
        self.with_cats = with_cats
        self.num_tracks = 0
        self.tracks = dict()
        self.backdrops = []
        self.max_distance = max_distance  # Maximum distance for considering matches
        self.fps = fps
        # Both gate parameters scale with fps: the gate radius grows
        # exponentially with the frame gap (see compute_distance_mask),
        # and the soft decay beyond the gate gets sharper at higher fps.
        self.growth_factor = self.fps / 6  # Growth factor for the distance mask
        self.distance_smoothing_factor = 100 / self.fps

    def reset(self):
        """Reset the buffer of the tracker."""
        self.num_tracks = 0
        self.tracks = dict()
        self.backdrops = []

    def update(
        self,
        ids: Tensor,
        bboxes: Tensor,
        embeds: Tensor,
        labels: Tensor,
        scores: Tensor,
        frame_id: int,
    ) -> None:
        """Update the tracklet memory with the current frame's matches.

        Args:
            ids (Tensor): of shape (N, ). Entries of -1 mark unmatched
                detections and are skipped.
            bboxes (Tensor): of shape (N, 5).
            embeds (Tensor): of shape (N, 256).
            labels (Tensor): of shape (N, ).
            scores (Tensor): of shape (N, ).
            frame_id (int): The id of current frame, 0-index.
        """
        tracklet_inds = ids > -1
        for track_id, bbox, embed, label, score in zip(
            ids[tracklet_inds],
            bboxes[tracklet_inds],
            embeds[tracklet_inds],
            labels[tracklet_inds],
            scores[tracklet_inds],
        ):
            track_id = int(track_id)
            # update the tracked ones and initialize new tracks
            if track_id in self.tracks:
                # Existing track: exponential moving average of the
                # appearance embedding, hard overwrite of everything else.
                track = self.tracks[track_id]
                track["bbox"] = bbox
                track["embed"] = (
                    1 - self.memo_momentum
                ) * track["embed"] + self.memo_momentum * embed
                track["last_frame"] = frame_id
                track["label"] = label
                track["score"] = score
            else:
                self.tracks[track_id] = dict(
                    bbox=bbox,
                    embed=embed,
                    label=label,
                    score=score,
                    last_frame=frame_id,
                )
        # pop memo: drop tracks unseen for memo_tracklet_frames frames
        invalid_ids = [
            k
            for k, v in self.tracks.items()
            if frame_id - v["last_frame"] >= self.memo_tracklet_frames
        ]
        for invalid_id in invalid_ids:
            self.tracks.pop(invalid_id)

    @property
    def memo(self) -> Tuple[Tensor, ...]:
        """Get tracks memory.

        NOTE: this must be a property. ``track`` reads it as ``self.memo``
        (attribute access, no call); as a plain method that expression
        would yield a bound method and the tuple-unpack would fail.

        Returns:
            tuple: ``(bboxes, labels, embeds, ids, frame_ids)`` stacked
            over all buffered tracks.
        """
        memo_embeds = []
        memo_ids = []
        memo_bboxes = []
        memo_labels = []
        memo_frame_ids = []
        # get tracks
        for k, v in self.tracks.items():
            memo_bboxes.append(v["bbox"][None, :])
            memo_embeds.append(v["embed"][None, :])
            memo_ids.append(k)
            memo_labels.append(v["label"].view(1, 1))
            memo_frame_ids.append(v["last_frame"])
        memo_ids = torch.tensor(memo_ids, dtype=torch.long).view(1, -1)
        memo_bboxes = torch.cat(memo_bboxes, dim=0)
        memo_embeds = torch.cat(memo_embeds, dim=0)
        memo_labels = torch.cat(memo_labels, dim=0).squeeze(1)
        memo_frame_ids = torch.tensor(memo_frame_ids, dtype=torch.long).view(1, -1)
        return (
            memo_bboxes,
            memo_labels,
            memo_embeds,
            memo_ids.squeeze(0),
            memo_frame_ids.squeeze(0),
        )

    def compute_distance_mask(self, bboxes1, bboxes2, frame_ids1, frame_ids2):
        """Compute a soft gating mask from pairwise center distances.

        The allowed radius (``max_distance``) grows exponentially with the
        frame gap between the detection and the memorized track; pairs
        inside the radius get weight 1, pairs outside decay exponentially.

        Args:
            bboxes1 (Tensor): (N, 4+) boxes in ``(x1, y1, x2, y2, ...)``.
            bboxes2 (Tensor): (M, 4+) boxes in the same format.
            frame_ids1 (Tensor): (N, ) frame ids of ``bboxes1``.
            frame_ids2 (Tensor): (M, ) frame ids of ``bboxes2``.

        Returns:
            Tensor: (N, M) weights in (0, 1].
        """
        centers1 = (bboxes1[:, :2] + bboxes1[:, 2:]) / 2.0
        centers2 = (bboxes2[:, :2] + bboxes2[:, 2:]) / 2.0
        distances = torch.cdist(centers1, centers2)
        frame_id_diff = torch.abs(frame_ids1[:, None] - frame_ids2[None, :]).to(
            distances.device
        )
        # Define a scaling factor for the distance based on frame difference
        # (exponential growth).
        scaling_factor = torch.exp(frame_id_diff.float() / self.growth_factor)
        # Apply the scaling factor to max_distance
        adaptive_max_distance = self.max_distance * scaling_factor
        # Create a piecewise function for soft gating
        soft_distance_mask = torch.where(
            distances <= adaptive_max_distance,
            torch.ones_like(distances),
            torch.exp(
                -(distances - adaptive_max_distance) / self.distance_smoothing_factor
            ),
        )
        return soft_distance_mask

    def track(
        self,
        model: torch.nn.Module,
        img: torch.Tensor,
        feats: List[torch.Tensor],
        data_sample: TrackDataSample,
        rescale=True,
        with_segm=False,
        **kwargs
    ) -> InstanceData:
        """Tracking forward function.

        Args:
            model (nn.Module): MOT model.
            img (Tensor): of shape (T, C, H, W) encoding input image.
                Typically these should be mean centered and std scaled.
                The T denotes the number of key images and usually is 1.
            feats (list[Tensor]): Multi level feature maps of `img`.
            data_sample (:obj:`TrackDataSample`): The data sample.
                It includes information such as `pred_instances`.
            rescale (bool, optional): If True, the bounding boxes should be
                rescaled to fit the original scale of the image. Defaults to
                True.
            with_segm (bool, optional): If True, also carry indices into the
                pre-sort detection order so masks can be looked up by the
                caller. Defaults to False.

        Returns:
            :obj:`InstanceData`: Tracking results of the input images.
            Each InstanceData usually contains ``bboxes``, ``labels``,
            ``scores`` and ``instances_id``.
        """
        metainfo = data_sample.metainfo
        bboxes = data_sample.pred_instances.bboxes
        labels = data_sample.pred_instances.labels
        scores = data_sample.pred_instances.scores
        frame_id = metainfo.get("frame_id", -1)

        # create pred_track_instances
        pred_track_instances = InstanceData()

        # return zero bboxes if there is no track targets
        if bboxes.shape[0] == 0:
            ids = torch.zeros_like(labels)
            pred_track_instances = data_sample.pred_instances.clone()
            pred_track_instances.instances_id = ids
            pred_track_instances.mask_inds = torch.zeros_like(labels)
            return pred_track_instances

        # get track feats on the (possibly) rescaled boxes
        rescaled_bboxes = bboxes.clone()
        if rescale:
            scale_factor = rescaled_bboxes.new_tensor(metainfo["scale_factor"]).repeat(
                (1, 2)
            )
            rescaled_bboxes = rescaled_bboxes * scale_factor
        track_feats = model.track_head.predict(feats, [rescaled_bboxes])

        # sort according to the object_score
        _, inds = scores.sort(descending=True)
        bboxes = bboxes[inds]
        scores = scores[inds]
        labels = labels[inds]
        embeds = track_feats[inds, :]
        if with_segm:
            # Remember each detection's index in the original order so the
            # caller can gather the corresponding masks.
            mask_inds = torch.arange(bboxes.size(0)).to(embeds.device)
            mask_inds = mask_inds[inds]
        else:
            mask_inds = []

        bboxes, labels, scores, embeds, mask_inds = self.remove_distractor(
            bboxes,
            labels,
            scores,
            track_feats=embeds,
            mask_inds=mask_inds,
            nms="inter",
            distractor_score_thr=self.distractor_score_thr,
            distractor_nms_thr=self.distractor_nms_thr,
        )

        # init ids container
        ids = torch.full((bboxes.size(0),), -1, dtype=torch.long)

        # match if buffer is not empty
        if bboxes.size(0) > 0 and not self.empty:
            (
                memo_bboxes,
                memo_labels,
                memo_embeds,
                memo_ids,
                memo_frame_ids,
            ) = self.memo

            # Bi-directional softmax over the detection/track affinity,
            # averaged with plain cosine similarity.
            feats = torch.mm(embeds, memo_embeds.t())
            d2t_scores = feats.softmax(dim=1)
            t2d_scores = feats.softmax(dim=0)
            match_scores_bisoftmax = (d2t_scores + t2d_scores) / 2
            match_scores_cosine = torch.mm(
                F.normalize(embeds, p=2, dim=1),
                F.normalize(memo_embeds, p=2, dim=1).t(),
            )
            match_scores = (match_scores_bisoftmax + match_scores_cosine) / 2

            if self.max_distance != -1:
                # Compute the mask based on spatial proximity
                current_frame_ids = torch.full(
                    (bboxes.size(0),), frame_id, dtype=torch.long
                )
                distance_mask = self.compute_distance_mask(
                    bboxes, memo_bboxes, current_frame_ids, memo_frame_ids
                )
                # Apply the mask to the match scores
                match_scores = match_scores * distance_mask

            # track according to match_scores, greedily in detection-score
            # order (rows were sorted above)
            for i in range(bboxes.size(0)):
                conf, memo_ind = torch.max(match_scores[i, :], dim=0)
                track_id = memo_ids[memo_ind]
                if conf > self.match_score_thr:
                    if track_id > -1:
                        # keep bboxes with high object score
                        # and remove background bboxes
                        if scores[i] > self.obj_score_thr:
                            ids[i] = track_id
                            # Claim this memory slot: no other detection may
                            # match the same track in this frame.
                            match_scores[:i, memo_ind] = 0
                            match_scores[i + 1 :, memo_ind] = 0

        # initialize new tracks
        new_inds = (ids == -1) & (scores > self.init_score_thr).cpu()
        # int() keeps num_tracks a plain Python int instead of a 0-d tensor.
        num_news = int(new_inds.sum())
        ids[new_inds] = torch.arange(
            self.num_tracks, self.num_tracks + num_news, dtype=torch.long
        )
        self.num_tracks += num_news

        self.update(ids, bboxes, embeds, labels, scores, frame_id)
        tracklet_inds = ids > -1

        # update pred_track_instances
        pred_track_instances.bboxes = bboxes[tracklet_inds]
        pred_track_instances.labels = labels[tracklet_inds]
        pred_track_instances.scores = scores[tracklet_inds]
        pred_track_instances.instances_id = ids[tracklet_inds]
        if with_segm:
            pred_track_instances.mask_inds = mask_inds[tracklet_inds]
        return pred_track_instances

    def remove_distractor(
        self,
        bboxes,
        labels,
        scores,
        track_feats,
        mask_inds=None,
        distractor_score_thr=0.5,
        distractor_nms_thr=0.3,
        nms="inter",
    ):
        """Suppress low-score detections that overlap higher-ranked ones.

        Assumes rows are already sorted by descending score: a low-score box
        is dropped when it overlaps any box ranked above it by more than
        ``distractor_nms_thr`` (class-agnostic for ``nms='inter'``,
        same-class only for ``nms='intra'``).

        Args:
            bboxes (Tensor): (N, 4+) boxes sorted by descending score.
            labels (Tensor): (N, ) class labels.
            scores (Tensor): (N, ) detection scores.
            track_feats (Tensor | None): (N, C) embeddings, filtered along
                with the boxes when given.
            mask_inds (Tensor | list | None): optional per-box indices
                filtered alongside; ``None`` means none were provided.
            distractor_score_thr (float): boxes below this score are
                candidates for suppression. Defaults to 0.5.
            distractor_nms_thr (float): IoU threshold for suppression.
                Defaults to 0.3.
            nms (str): 'inter' (class-agnostic) or 'intra' (within-class).

        Returns:
            tuple: filtered ``(bboxes, labels, scores, track_feats,
            mask_inds)``.

        Raises:
            NotImplementedError: if ``nms`` is neither 'inter' nor 'intra'.
        """
        # Avoid the mutable-default pitfall while keeping the old behavior.
        if mask_inds is None:
            mask_inds = []
        # all objects is valid here
        valid_inds = labels > -1
        # nms
        low_inds = torch.nonzero(scores < distractor_score_thr, as_tuple=False).squeeze(
            1
        )
        if nms == "inter":
            ious = bbox_overlaps(bboxes[low_inds, :], bboxes[:, :])
        elif nms == "intra":
            cat_same = labels[low_inds].view(-1, 1) == labels.view(1, -1)
            ious = bbox_overlaps(bboxes[low_inds, :], bboxes)
            ious *= cat_same.to(ious.device)
        else:
            raise NotImplementedError
        for i, ind in enumerate(low_inds):
            # Only boxes ranked above `ind` (higher score) can suppress it.
            if (ious[i, :ind] > distractor_nms_thr).any():
                valid_inds[ind] = False
        bboxes = bboxes[valid_inds]
        labels = labels[valid_inds]
        scores = scores[valid_inds]
        if track_feats is not None:
            track_feats = track_feats[valid_inds]
        if len(mask_inds) > 0:
            mask_inds = mask_inds[valid_inds]
        return bboxes, labels, scores, track_feats, mask_inds