File size: 7,486 Bytes
b213d84 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
import copy
import numpy as np
from typing import Dict
import torch
from scipy.optimize import linear_sum_assignment
from detectron2.config import configurable
from detectron2.structures import Boxes, Instances
from ..config.config import CfgNode as CfgNode_
from .base_tracker import BaseTracker
class BaseHungarianTracker(BaseTracker):
    """
    A base class for all Hungarian trackers.

    Detections of the current frame are associated with the detections kept
    from the previous frame by solving a linear assignment problem
    (``scipy.optimize.linear_sum_assignment``) over a cost matrix. Subclasses
    provide the cost matrix via ``build_cost_matrix`` and the constructor
    arguments via ``from_config``.
    """

    @configurable
    def __init__(
        self,
        video_height: int,
        video_width: int,
        max_num_instances: int = 200,
        max_lost_frame_count: int = 0,
        min_box_rel_dim: float = 0.02,
        min_instance_period: int = 1,
        **kwargs
    ):
        """
        Args:
            video_height: height of the video frame
            video_width: width of the video frame
            max_num_instances: maximum number of IDs allowed to be tracked
                (stored on the instance; not read by the methods of this
                base class)
            max_lost_frame_count: maximum number of frames an ID can lose
                tracking; once an unmatched instance has reached this count
                it is considered lost forever and is no longer carried over.
                With the default of 0, unmatched instances are never
                carried over.
            min_box_rel_dim: a fraction of the frame size; an unmatched bbox
                whose width or height is smaller than this fraction of the
                video width or height is removed from tracking
            min_instance_period: an instance will be shown after this number
                of periods since it first showed up in the video; an
                unmatched instance whose ID_period does not exceed this value
                is not carried over
        """
        super().__init__(**kwargs)
        self._video_height = video_height
        self._video_width = video_width
        self._max_num_instances = max_num_instances
        self._max_lost_frame_count = max_lost_frame_count
        self._min_box_rel_dim = min_box_rel_dim
        self._min_instance_period = min_instance_period

    @classmethod
    def from_config(cls, cfg: CfgNode_) -> Dict:
        """
        Translate a config node into constructor keyword arguments.

        Must be overridden by subclasses; the base implementation always
        raises NotImplementedError.
        """
        raise NotImplementedError("Calling HungarianTracker::from_config")

    def build_cost_matrix(self, instances: Instances, prev_instances: Instances) -> np.ndarray:
        """
        Build the assignment cost matrix between current and previous instances.

        Must be overridden by subclasses; the base implementation always
        raises NotImplementedError. ``update`` passes the result to
        ``linear_sum_assignment`` and uses the returned row indices to index
        ``instances`` and the column indices to index ``prev_instances``.

        Args:
            instances: D2 Instances, for predictions of the current frame
            prev_instances: D2 Instances, for predictions of the previous frame
        Return:
            the cost matrix as a numpy array
        """
        # The message previously named a non-existent method ("build_matrix").
        raise NotImplementedError("Calling HungarianTracker::build_cost_matrix")

    def update(self, instances: Instances) -> Instances:
        """
        Assign tracking IDs to the predictions of the current frame.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with ID, ID_period and lost_frame_count fields set.
            From the second frame on, the result is the concatenation of the
            current instances and the previous-frame instances that were not
            matched but are still carried over. A deep copy of the result is
            kept as the previous instances for the next call.
        Raises:
            NotImplementedError: if ``instances`` has a "pred_keypoints" field
        """
        if instances.has("pred_keypoints"):
            raise NotImplementedError("Need to add support for keypoints")
        instances = self._initialize_extra_fields(instances)
        if self._prev_instances is not None:
            self._untracked_prev_idx = set(range(len(self._prev_instances)))
            cost_matrix = self.build_cost_matrix(instances, self._prev_instances)
            # Row indices refer to `instances`, column indices to `_prev_instances`.
            matched_idx, matched_prev_idx = linear_sum_assignment(cost_matrix)
            instances = self._process_matched_idx(instances, matched_idx, matched_prev_idx)
            instances = self._process_unmatched_idx(instances, matched_idx)
            instances = self._process_unmatched_prev_idx(instances, matched_prev_idx)
        # Deep copy so later changes to the returned object do not alter the
        # state used for matching the next frame.
        self._prev_instances = copy.deepcopy(instances)
        return instances

    def _initialize_extra_fields(self, instances: Instances) -> Instances:
        """
        If input instances don't have ID, ID_period, lost_frame_count fields,
        this method is used to initialize these fields.

        On the first frame (no previous instances stored) every instance gets
        a fresh consecutive ID, an ID_period of 1 and a lost_frame_count of 0.

        Args:
            instances: D2 Instances, for predictions of the current frame
        Return:
            D2 Instances with extra fields added
        """
        num_instances = len(instances)
        for field_name in ("ID", "ID_period", "lost_frame_count"):
            if not instances.has(field_name):
                instances.set(field_name, [None] * num_instances)
        if self._prev_instances is None:
            # Start numbering at the running counter so IDs stay unique even
            # if the counter is non-zero here; identical to numbering from 0
            # when the counter is 0.
            first_id = self._id_count
            instances.ID = list(range(first_id, first_id + num_instances))
            self._id_count += num_instances
            instances.ID_period = [1] * num_instances
            instances.lost_frame_count = [0] * num_instances
        return instances

    def _process_matched_idx(
        self, instances: Instances, matched_idx: np.ndarray, matched_prev_idx: np.ndarray
    ) -> Instances:
        """
        Propagate tracking state from matched previous instances.

        Each matched current instance inherits the ID of its previous-frame
        partner, gets that partner's ID_period plus one, and has its
        lost_frame_count reset to 0.

        Args:
            instances: D2 Instances, for predictions of the current frame
            matched_idx: indices into ``instances`` that were matched
            matched_prev_idx: indices into the previous instances, pairwise
                aligned with ``matched_idx``
        Return:
            D2 Instances, with the matched entries updated in place
        """
        # zip() would silently truncate on a length mismatch, so keep the check.
        assert matched_idx.size == matched_prev_idx.size
        prev = self._prev_instances
        for cur_idx, prev_idx in zip(matched_idx, matched_prev_idx):
            instances.ID[cur_idx] = prev.ID[prev_idx]
            instances.ID_period[cur_idx] = prev.ID_period[prev_idx] + 1
            instances.lost_frame_count[cur_idx] = 0
        return instances

    def _process_unmatched_idx(self, instances: Instances, matched_idx: np.ndarray) -> Instances:
        """
        Start a new track for every current instance that was not matched.

        Each such instance gets the next unused ID, an ID_period of 1 and a
        lost_frame_count of 0.

        Args:
            instances: D2 Instances, for predictions of the current frame
            matched_idx: indices into ``instances`` that were matched
        Return:
            D2 Instances, with the unmatched entries updated in place
        """
        untracked_idx = set(range(len(instances))).difference(set(matched_idx))
        # Sorted so that new IDs are handed out in index order rather than in
        # set iteration order.
        for idx in sorted(untracked_idx):
            instances.ID[idx] = self._id_count
            self._id_count += 1
            instances.ID_period[idx] = 1
            instances.lost_frame_count[idx] = 0
        return instances

    def _process_unmatched_prev_idx(
        self, instances: Instances, matched_prev_idx: np.ndarray
    ) -> Instances:
        """
        Carry over previous-frame instances that found no match in this frame.

        An unmatched previous instance is dropped when any of these holds:
          * its box width relative to the video width is below min_box_rel_dim
          * its box height relative to the video height is below min_box_rel_dim
          * its lost_frame_count has reached max_lost_frame_count
          * its ID_period does not exceed min_instance_period
        Every other unmatched previous instance is appended to the output with
        its lost_frame_count increased by one.

        Args:
            instances: D2 Instances, for predictions of the current frame
            matched_prev_idx: indices into the previous instances that were
                matched
        Return:
            D2 Instances, the current instances followed by the carried-over
            previous instances
        """
        prev = self._prev_instances
        # Masks are handled only when the current frame carries them.
        has_masks = instances.has("pred_masks")

        # Fields start out as plain lists and are converted to tensors below.
        untracked_instances = Instances(
            image_size=instances.image_size,
            pred_boxes=[],
            pred_masks=[],
            pred_classes=[],
            scores=[],
            ID=[],
            ID_period=[],
            lost_frame_count=[],
        )
        prev_bboxes = list(prev.pred_boxes)
        prev_classes = list(prev.pred_classes)
        prev_scores = list(prev.scores)
        prev_ID_period = prev.ID_period
        prev_masks = list(prev.pred_masks) if has_masks else None

        untracked_prev_idx = set(range(len(prev))).difference(set(matched_prev_idx))
        # Sorted so carried-over instances are appended in index order rather
        # than in set iteration order.
        for idx in sorted(untracked_prev_idx):
            x_left, y_top, x_right, y_bot = prev_bboxes[idx]
            if (
                (1.0 * (x_right - x_left) / self._video_width < self._min_box_rel_dim)
                or (1.0 * (y_bot - y_top) / self._video_height < self._min_box_rel_dim)
                or prev.lost_frame_count[idx] >= self._max_lost_frame_count
                or prev_ID_period[idx] <= self._min_instance_period
            ):
                continue
            # NOTE(review): Tensor.numpy() requires CPU tensors; this path
            # presumably expects the instances to live on the CPU -- confirm.
            untracked_instances.pred_boxes.append(list(prev_bboxes[idx].numpy()))
            untracked_instances.pred_classes.append(int(prev_classes[idx]))
            untracked_instances.scores.append(float(prev_scores[idx]))
            untracked_instances.ID.append(prev.ID[idx])
            untracked_instances.ID_period.append(prev.ID_period[idx])
            untracked_instances.lost_frame_count.append(prev.lost_frame_count[idx] + 1)
            if has_masks:
                untracked_instances.pred_masks.append(prev_masks[idx].numpy().astype(np.uint8))

        untracked_instances.pred_boxes = Boxes(torch.FloatTensor(untracked_instances.pred_boxes))
        untracked_instances.pred_classes = torch.IntTensor(untracked_instances.pred_classes)
        untracked_instances.scores = torch.FloatTensor(untracked_instances.scores)
        if has_masks:
            untracked_instances.pred_masks = torch.IntTensor(untracked_instances.pred_masks)
        else:
            # Drop the placeholder so both operands of cat() have the same fields.
            untracked_instances.remove("pred_masks")

        return Instances.cat(
            [
                instances,
                untracked_instances,
            ]
        )
|