|
|
|
|
|
from functools import partial |
|
from pathlib import Path |
|
|
|
import torch |
|
|
|
from ultralytics.utils import IterableSimpleNamespace, yaml_load |
|
from ultralytics.utils.checks import check_yaml |
|
from .bot_sort import BOTSORT |
|
from .byte_tracker import BYTETracker |
|
|
|
|
|
# Maps the `tracker_type` value read from the tracker config to the class that implements it.
TRACKER_MAP = {
    "bytetrack": BYTETracker,
    "botsort": BOTSORT,
}
|
|
|
|
|
def on_predict_start(predictor: object, persist: bool = False) -> None:
    """
    Initialize trackers for object tracking during prediction.

    Loads the tracker config referenced by `predictor.args.tracker`, checks its `tracker_type` against the keys of
    `TRACKER_MAP`, and stores the created tracker instances on `predictor.trackers`. One tracker per batch slot is
    created when `predictor.dataset.mode` is "stream"; otherwise a single tracker is created. `predictor.vid_path`
    is reset to a list of `None` with one entry per batch slot.

    Args:
        predictor (object): The predictor object to initialize trackers for.
        persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.

    Raises:
        AssertionError: If the tracker_type is not a key of TRACKER_MAP (currently 'bytetrack' or 'botsort').
    """
    # Existing trackers are left untouched when the caller asked to persist them.
    if hasattr(predictor, "trackers") and persist:
        return

    cfg_path = check_yaml(predictor.args.tracker)
    cfg = IterableSimpleNamespace(**yaml_load(cfg_path))

    # Validate against TRACKER_MAP so the supported set is defined in exactly one place.
    # A tuple is used for the membership test so that an unhashable value still ends in AssertionError.
    supported = tuple(TRACKER_MAP)
    if cfg.tracker_type not in supported:
        names = " and ".join(f"'{name}'" for name in supported)
        raise AssertionError(f"Only {names} are supported for now, but got '{cfg.tracker_type}'")
    tracker_cls = TRACKER_MAP[cfg.tracker_type]

    batch_size = predictor.dataset.bs
    # Stream mode gets one tracker per batch slot; every other mode shares a single tracker.
    n_trackers = batch_size if predictor.dataset.mode == "stream" else min(batch_size, 1)
    predictor.trackers = [tracker_cls(args=cfg, frame_rate=30) for _ in range(n_trackers)]
    # Compared against the current source path in on_predict_postprocess_end to decide when to reset a tracker.
    predictor.vid_path = [None] * batch_size
|
|
|
|
|
def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
    """
    Postprocess detected boxes and update with object tracking.

    For every image in the current batch, the matching tracker is fed the detections (`obb` when the task is "obb",
    otherwise `boxes`). When the tracker returns tracks, the result is re-indexed by the last column of the tracks
    array and the remaining columns are written back to the result under the same attribute name. Results with no
    detections, or for which the tracker returns no tracks, are left unchanged.

    Args:
        predictor (object): The predictor object containing the predictions.
        persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
    """
    paths, frames = predictor.batch[:2]

    field = "obb" if predictor.args.task == "obb" else "boxes"
    streaming = predictor.dataset.mode == "stream"

    for i, frame in enumerate(frames):
        # Stream mode keeps one tracker (and one remembered path) per batch slot; other modes share slot 0.
        slot = i if streaming else 0
        tracker = predictor.trackers[slot]

        source = predictor.save_dir / Path(paths[i]).name
        # A change of source path starts a new sequence, unless the caller asked to persist tracks.
        if not persist and predictor.vid_path[slot] != source:
            tracker.reset()
            predictor.vid_path[slot] = source

        detections = getattr(predictor.results[i], field).cpu().numpy()
        if len(detections) == 0:
            continue
        tracks = tracker.update(detections, frame)
        if len(tracks) == 0:
            continue

        # The last column of each track row selects the result entry; the other columns become the new data.
        keep = tracks[:, -1].astype(int)
        tracked = predictor.results[i][keep]
        predictor.results[i] = tracked
        tracked.update(**{field: torch.as_tensor(tracks[:, :-1])})
|
|
|
|
|
def register_tracker(model: object, persist: bool) -> None:
    """
    Register tracking callbacks to the model for object tracking during prediction.

    Each callback is bound to the given `persist` value before it is handed to `model.add_callback`.

    Args:
        model (object): The model object to register tracking callbacks for.
        persist (bool): Whether to persist the trackers if they already exist.
    """
    # Event name -> handler, in registration order.
    handlers = {
        "on_predict_start": on_predict_start,
        "on_predict_postprocess_end": on_predict_postprocess_end,
    }
    for event, handler in handlers.items():
        model.add_callback(event, partial(handler, persist=persist))
|
|