"""
Author: Siyuan Li
Licensed: Apache-2.0 License
"""
from typing import List, Tuple
import torch
import torch.nn.functional as F
from mmdet.models.trackers.base_tracker import BaseTracker
from mmdet.registry import MODELS
from mmdet.structures import TrackDataSample
from mmdet.structures.bbox import bbox_overlaps
from mmengine.structures import InstanceData
from torch import Tensor


@MODELS.register_module()
class MasaTaoTracker(BaseTracker):
"""Tracker for MASA on TAO benchmark.
Args:
init_score_thr (float): The cls_score threshold to
initialize a new tracklet. Defaults to 0.8.
obj_score_thr (float): The cls_score threshold to
update a tracked tracklet. Defaults to 0.5.
match_score_thr (float): The match threshold. Defaults to 0.5.
        memo_tracklet_frames (int): The maximum number of frames a tracklet
            is kept in memory without being updated. Defaults to 10.
        memo_momentum (float): The momentum value for embeds updating.
            Defaults to 0.8.
        distractor_score_thr (float): The score threshold below which an
            object is considered a potential distractor. Defaults to 0.5.
        distractor_nms_thr (float): The NMS (IoU) threshold for filtering
            out distractors. Defaults to 0.3.
        with_cats (bool): Whether to track objects of the same category
            only. Defaults to True.
        max_distance (float): Maximum center distance (in pixels) for
            considering matches; -1 disables distance gating. Defaults to -1.
        fps (int): Frames per second of the input video. Used to derive the
            growth and smoothing factors of the adaptive distance gate.
            Defaults to 1.
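
    Example:
        A minimal per-frame usage sketch (illustrative only; ``model``,
        ``img``, ``feats`` and ``video_samples`` are placeholders assumed
        to come from an mmdet MOT pipeline and are not defined here)::

            tracker = MasaTaoTracker(init_score_thr=0.8, obj_score_thr=0.5)
            tracker.reset()
            for img, feats, data_sample in video_samples:
                pred = tracker.track(model, img, feats, data_sample)
                ids = pred.instances_id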
"""

    def __init__(
self,
init_score_thr: float = 0.8,
obj_score_thr: float = 0.5,
match_score_thr: float = 0.5,
memo_tracklet_frames: int = 10,
memo_momentum: float = 0.8,
distractor_score_thr: float = 0.5,
        distractor_nms_thr: float = 0.3,
with_cats: bool = True,
max_distance: float = -1,
        fps: int = 1,
**kwargs
):
super().__init__(**kwargs)
assert 0 <= memo_momentum <= 1.0
assert memo_tracklet_frames >= 0
self.init_score_thr = init_score_thr
self.obj_score_thr = obj_score_thr
self.match_score_thr = match_score_thr
self.memo_tracklet_frames = memo_tracklet_frames
self.memo_momentum = memo_momentum
self.distractor_score_thr = distractor_score_thr
self.distractor_nms_thr = distractor_nms_thr
self.with_cats = with_cats
self.num_tracks = 0
self.tracks = dict()
self.backdrops = []
self.max_distance = max_distance # Maximum distance for considering matches
self.fps = fps
self.growth_factor = self.fps / 6 # Growth factor for the distance mask
self.distance_smoothing_factor = 100 / self.fps
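        # Illustrative arithmetic (hypothetical fps, not from the original
        # source): with fps=30, growth_factor = 30 / 6 = 5, so the distance
        # gate radius grows by exp(1 / 5) ~= 1.22x per skipped frame, and
        # distance_smoothing_factor = 100 / 30 ~= 3.33 pixels.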

    def reset(self):
"""Reset the buffer of the tracker."""
self.num_tracks = 0
self.tracks = dict()
self.backdrops = []

    def update(
self,
ids: Tensor,
bboxes: Tensor,
embeds: Tensor,
labels: Tensor,
scores: Tensor,
frame_id: int,
) -> None:
"""Tracking forward function.
Args:
ids (Tensor): of shape(N, ).
bboxes (Tensor): of shape (N, 5).
embeds (Tensor): of shape (N, 256).
labels (Tensor): of shape (N, ).
scores (Tensor): of shape (N, ).
frame_id (int): The id of current frame, 0-index.
"""
tracklet_inds = ids > -1
for id, bbox, embed, label, score in zip(
ids[tracklet_inds],
bboxes[tracklet_inds],
embeds[tracklet_inds],
labels[tracklet_inds],
scores[tracklet_inds],
):
id = int(id)
# update the tracked ones and initialize new tracks
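            # a known id gets an exponential-moving-average embed update:
            #     embed_new = (1 - memo_momentum) * embed_old + memo_momentum * embed_det
            # e.g. with the default memo_momentum = 0.8, the new detection
            # contributes 80% and the history 20% (illustrative arithmetic)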
if id in self.tracks.keys():
self.tracks[id]["bbox"] = bbox
self.tracks[id]["embed"] = (1 - self.memo_momentum) * self.tracks[id][
"embed"
] + self.memo_momentum * embed
self.tracks[id]["last_frame"] = frame_id
self.tracks[id]["label"] = label
self.tracks[id]["score"] = score
else:
self.tracks[id] = dict(
bbox=bbox,
embed=embed,
label=label,
score=score,
last_frame=frame_id,
)
        # drop tracklets that have not been updated for
        # `memo_tracklet_frames` frames (e.g., with the default of 10, a
        # track last seen at frame 5 is removed once frame_id reaches 15)
invalid_ids = []
for k, v in self.tracks.items():
if frame_id - v["last_frame"] >= self.memo_tracklet_frames:
invalid_ids.append(k)
for invalid_id in invalid_ids:
self.tracks.pop(invalid_id)

    @property
def memo(self) -> Tuple[Tensor, ...]:
"""Get tracks memory."""
memo_embeds = []
memo_ids = []
memo_bboxes = []
memo_labels = []
memo_frame_ids = []
# get tracks
for k, v in self.tracks.items():
memo_bboxes.append(v["bbox"][None, :])
memo_embeds.append(v["embed"][None, :])
memo_ids.append(k)
memo_labels.append(v["label"].view(1, 1))
memo_frame_ids.append(v["last_frame"])
memo_ids = torch.tensor(memo_ids, dtype=torch.long).view(1, -1)
memo_bboxes = torch.cat(memo_bboxes, dim=0)
memo_embeds = torch.cat(memo_embeds, dim=0)
memo_labels = torch.cat(memo_labels, dim=0).squeeze(1)
memo_frame_ids = torch.tensor(memo_frame_ids, dtype=torch.long).view(1, -1)
return (
memo_bboxes,
memo_labels,
memo_embeds,
memo_ids.squeeze(0),
memo_frame_ids.squeeze(0),
)

    def compute_distance_mask(self, bboxes1, bboxes2, frame_ids1, frame_ids2):
        """Compute a soft gating mask from pairwise center distances and
        frame-id gaps, using piecewise soft-weighting."""
centers1 = (bboxes1[:, :2] + bboxes1[:, 2:]) / 2.0
centers2 = (bboxes2[:, :2] + bboxes2[:, 2:]) / 2.0
distances = torch.cdist(centers1, centers2)
frame_id_diff = torch.abs(frame_ids1[:, None] - frame_ids2[None, :]).to(
distances.device
)
# Define a scaling factor for the distance based on frame difference (exponential growth)
scaling_factor = torch.exp(frame_id_diff.float() / self.growth_factor)
# Apply the scaling factor to max_distance
adaptive_max_distance = self.max_distance * scaling_factor
# Create a piecewise function for soft gating
soft_distance_mask = torch.where(
distances <= adaptive_max_distance,
torch.ones_like(distances),
torch.exp(
-(distances - adaptive_max_distance) / self.distance_smoothing_factor
),
)
return soft_distance_mask

    def track(
self,
model: torch.nn.Module,
img: torch.Tensor,
feats: List[torch.Tensor],
data_sample: TrackDataSample,
rescale=True,
with_segm=False,
**kwargs
) -> InstanceData:
"""Tracking forward function.
Args:
model (nn.Module): MOT model.
img (Tensor): of shape (T, C, H, W) encoding input image.
Typically these should be mean centered and std scaled.
The T denotes the number of key images and usually is 1.
feats (list[Tensor]): Multi level feature maps of `img`.
data_sample (:obj:`TrackDataSample`): The data sample.
It includes information such as `pred_instances`.
rescale (bool, optional): If True, the bounding boxes should be
rescaled to fit the original scale of the image. Defaults to
True.
Returns:
:obj:`InstanceData`: Tracking results of the input images.
Each InstanceData usually contains ``bboxes``, ``labels``,
``scores`` and ``instances_id``.
"""
metainfo = data_sample.metainfo
bboxes = data_sample.pred_instances.bboxes
labels = data_sample.pred_instances.labels
scores = data_sample.pred_instances.scores
frame_id = metainfo.get("frame_id", -1)
# create pred_track_instances
pred_track_instances = InstanceData()
        # return empty results if there are no track targets
if bboxes.shape[0] == 0:
ids = torch.zeros_like(labels)
pred_track_instances = data_sample.pred_instances.clone()
pred_track_instances.instances_id = ids
pred_track_instances.mask_inds = torch.zeros_like(labels)
return pred_track_instances
# get track feats
rescaled_bboxes = bboxes.clone()
if rescale:
scale_factor = rescaled_bboxes.new_tensor(metainfo["scale_factor"]).repeat(
(1, 2)
)
rescaled_bboxes = rescaled_bboxes * scale_factor
track_feats = model.track_head.predict(feats, [rescaled_bboxes])
# sort according to the object_score
_, inds = scores.sort(descending=True)
bboxes = bboxes[inds]
scores = scores[inds]
labels = labels[inds]
embeds = track_feats[inds, :]
if with_segm:
mask_inds = torch.arange(bboxes.size(0)).to(embeds.device)
mask_inds = mask_inds[inds]
else:
mask_inds = []
bboxes, labels, scores, embeds, mask_inds = self.remove_distractor(
bboxes,
labels,
scores,
track_feats=embeds,
mask_inds=mask_inds,
nms="inter",
distractor_score_thr=self.distractor_score_thr,
distractor_nms_thr=self.distractor_nms_thr,
)
# init ids container
ids = torch.full((bboxes.size(0),), -1, dtype=torch.long)
# match if buffer is not empty
if bboxes.size(0) > 0 and not self.empty:
(
memo_bboxes,
memo_labels,
memo_embeds,
memo_ids,
memo_frame_ids,
) = self.memo
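            # match-score fusion, as implemented below: a bi-directional
            # softmax over the detection-to-memory affinity matrix, averaged
            # with the cosine similarity of the L2-normalized embeddings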
            similarity = torch.mm(embeds, memo_embeds.t())
            d2t_scores = similarity.softmax(dim=1)
            t2d_scores = similarity.softmax(dim=0)
            match_scores_bisoftmax = (d2t_scores + t2d_scores) / 2
match_scores_cosine = torch.mm(
F.normalize(embeds, p=2, dim=1),
F.normalize(memo_embeds, p=2, dim=1).t(),
)
match_scores = (match_scores_bisoftmax + match_scores_cosine) / 2
if self.max_distance != -1:
# Compute the mask based on spatial proximity
current_frame_ids = torch.full(
(bboxes.size(0),), frame_id, dtype=torch.long
)
distance_mask = self.compute_distance_mask(
bboxes, memo_bboxes, current_frame_ids, memo_frame_ids
)
# Apply the mask to the match scores
match_scores = match_scores * distance_mask
            # greedy assignment in descending score order: each detection takes
            # its best-matching memo track above match_score_thr, then that
            # track's column is zeroed so it cannot be matched again
for i in range(bboxes.size(0)):
conf, memo_ind = torch.max(match_scores[i, :], dim=0)
id = memo_ids[memo_ind]
if conf > self.match_score_thr:
if id > -1:
# keep bboxes with high object score
# and remove background bboxes
if scores[i] > self.obj_score_thr:
ids[i] = id
match_scores[:i, memo_ind] = 0
match_scores[i + 1 :, memo_ind] = 0
# initialize new tracks
new_inds = (ids == -1) & (scores > self.init_score_thr).cpu()
        num_news = int(new_inds.sum())
ids[new_inds] = torch.arange(
self.num_tracks, self.num_tracks + num_news, dtype=torch.long
)
self.num_tracks += num_news
self.update(ids, bboxes, embeds, labels, scores, frame_id)
tracklet_inds = ids > -1
# update pred_track_instances
pred_track_instances.bboxes = bboxes[tracklet_inds]
pred_track_instances.labels = labels[tracklet_inds]
pred_track_instances.scores = scores[tracklet_inds]
pred_track_instances.instances_id = ids[tracklet_inds]
if with_segm:
pred_track_instances.mask_inds = mask_inds[tracklet_inds]
return pred_track_instances

    def remove_distractor(
        self,
        bboxes,
        labels,
        scores,
        track_feats,
        mask_inds=None,
        distractor_score_thr=0.5,
        distractor_nms_thr=0.3,
        nms="inter",
    ):
        """Filter out distractors: low-scoring boxes that heavily overlap
        higher-scoring ones (inputs are assumed to be sorted by score in
        descending order, as done in :meth:`track`)."""
        if mask_inds is None:
            mask_inds = []
        # all objects are valid here
        valid_inds = labels > -1
# nms
low_inds = torch.nonzero(scores < distractor_score_thr, as_tuple=False).squeeze(
1
)
if nms == "inter":
ious = bbox_overlaps(bboxes[low_inds, :], bboxes[:, :])
elif nms == "intra":
cat_same = labels[low_inds].view(-1, 1) == labels.view(1, -1)
ious = bbox_overlaps(bboxes[low_inds, :], bboxes)
ious *= cat_same.to(ious.device)
else:
raise NotImplementedError
for i, ind in enumerate(low_inds):
if (ious[i, :ind] > distractor_nms_thr).any():
valid_inds[ind] = False
bboxes = bboxes[valid_inds]
labels = labels[valid_inds]
scores = scores[valid_inds]
if track_feats is not None:
track_feats = track_feats[valid_inds]
if len(mask_inds) > 0:
mask_inds = mask_inds[valid_inds]
return bboxes, labels, scores, track_feats, mask_inds
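

# Minimal smoke test of the adaptive distance gate (illustrative addition,
# not part of the original module): it builds a tracker directly instead of
# via MODELS.build and only exercises `compute_distance_mask` on dummy boxes.
if __name__ == "__main__":
    tracker = MasaTaoTracker(max_distance=100.0, fps=30)
    # one current detection vs. two memo boxes, one frame apart
    det_boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
    memo_boxes = torch.tensor(
        [[0.0, 0.0, 10.0, 10.0], [500.0, 500.0, 510.0, 510.0]]
    )
    det_frames = torch.tensor([5])
    memo_frames = torch.tensor([4, 4])
    mask = tracker.compute_distance_mask(
        det_boxes, memo_boxes, det_frames, memo_frames
    )
    # expected: ~1.0 for the nearby pair, ~0.0 for the distant one
    print(mask)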