from typing import Union, List, Dict

import torch

from tracker.inference.object_info import ObjectInfo
class ObjectManager:
    """
    Bookkeeping for the set of objects known to the tracker.

    Object IDs are immutable: the same ID always represents the same object.
    Temporary IDs are the 1-based positions of each object in the mask/feature
    tensor; they are re-assigned (compacted) whenever objects get removed.
    """
    def __init__(self):
        # object -> its current position in the tensor (1-based tmp id)
        self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
        # inverse of obj_to_tmp_id
        self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
        # immutable object id -> ObjectInfo, for direct lookup by id
        self.obj_id_to_obj: Dict[int, ObjectInfo] = {}
        # every object id ever added, including ones deleted since
        self.all_historical_object_ids: List[int] = []

    def _recompute_obj_id_to_obj_mapping(self) -> None:
        # rebuild the id -> object index from the current object set
        self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}

    def add_new_objects(
            self, objects: Union[List[ObjectInfo], ObjectInfo,
                                 List[int]]) -> (List[int], List[int]):
        """
        Register objects (ObjectInfo instances or plain int ids, single or list).

        Already-known objects keep their existing temporary id; unknown ones
        are appended with the next free temporary id.

        Returns (tmp_ids, obj_ids), one entry per input object, in input order.
        """
        if not isinstance(objects, list):
            objects = [objects]

        corresponding_tmp_ids = []
        corresponding_obj_ids = []
        for obj in objects:
            if isinstance(obj, int):
                obj = ObjectInfo(id=obj)

            if obj in self.obj_to_tmp_id:
                # old object: reuse its existing temporary id
                corresponding_tmp_ids.append(self.obj_to_tmp_id[obj])
                corresponding_obj_ids.append(obj.id)
            else:
                # new object: append at the next free temporary id.
                # Bug fix: the original did `new_obj = ObjectInfo(id=obj)`,
                # re-wrapping the already-normalized ObjectInfo so that the
                # stored object's .id was itself an ObjectInfo, corrupting
                # obj_id_to_obj and all_historical_object_ids.
                new_tmp_id = len(self.obj_to_tmp_id) + 1
                self.obj_to_tmp_id[obj] = new_tmp_id
                self.tmp_id_to_obj[new_tmp_id] = obj
                self.all_historical_object_ids.append(obj.id)
                corresponding_tmp_ids.append(new_tmp_id)
                corresponding_obj_ids.append(obj.id)

        self._recompute_obj_id_to_obj_mapping()
        # sanity check kept from the original: tmp ids handed out within one
        # call are expected to be non-decreasing
        assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
        return corresponding_tmp_ids, corresponding_obj_ids

    def delete_object(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
        """
        Delete an object (or a list of objects) by object id, then compact the
        temporary ids so they stay contiguous starting from 1, preserving the
        relative order of the surviving objects.
        """
        if isinstance(obj_ids_to_remove, int):
            obj_ids_to_remove = [obj_ids_to_remove]

        new_tmp_id = 1
        total_num_id = len(self.obj_to_tmp_id)

        local_obj_to_tmp_id = {}
        local_tmp_to_obj_id = {}

        # walk objects in current tmp-id order so survivors keep their order
        for tmp_iter in range(1, total_num_id + 1):
            obj = self.tmp_id_to_obj[tmp_iter]
            if obj.id not in obj_ids_to_remove:
                local_obj_to_tmp_id[obj] = new_tmp_id
                local_tmp_to_obj_id[new_tmp_id] = obj
                new_tmp_id += 1

        self.obj_to_tmp_id = local_obj_to_tmp_id
        self.tmp_id_to_obj = local_tmp_to_obj_id
        self._recompute_obj_id_to_obj_mapping()

    def purge_inactive_objects(
            self, max_missed_detection_count: int) -> (bool, List[int], List[int]):
        """
        Delete every object whose poke_count exceeds the given threshold.

        Returns (purge_activated, tmp_ids_to_keep, obj_ids_to_keep); the kept
        tmp ids refer to the state *before* deletion/compaction.
        """
        obj_id_to_be_deleted = []
        tmp_id_to_be_deleted = []
        tmp_id_to_keep = []
        obj_id_to_keep = []

        for obj in self.obj_to_tmp_id:
            if obj.poke_count > max_missed_detection_count:
                obj_id_to_be_deleted.append(obj.id)
                tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj])
            else:
                tmp_id_to_keep.append(self.obj_to_tmp_id[obj])
                obj_id_to_keep.append(obj.id)

        purge_activated = len(obj_id_to_be_deleted) > 0
        if purge_activated:
            self.delete_object(obj_id_to_be_deleted)
        return purge_activated, tmp_id_to_keep, obj_id_to_keep

    def tmp_to_obj_cls(self, mask) -> torch.Tensor:
        # remap a tmp-id class mask to the true (immutable) object-id mask
        new_mask = torch.zeros_like(mask)
        for tmp_id, obj in self.tmp_id_to_obj.items():
            new_mask[mask == tmp_id] = obj.id
        return new_mask

    def get_tmp_to_obj_mapping(self) -> Dict[int, int]:
        """
        Return an obj_id -> tmp_id dict in a plain format for saving with pickle.

        Bug fix: the original unpacked .items() as `for obj, tmp_id`, so the
        key variable held the int tmp id and `obj.id` raised AttributeError.
        NOTE(review): despite the method name, the mapping direction is
        obj_id -> tmp_id, matching the fixed comprehension.
        """
        return {obj.id: tmp_id for tmp_id, obj in self.tmp_id_to_obj.items()}

    def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
        """
        Stack a dict of tensors indexed by object id into one tensor, ordered
        by temporary id, along `dim`.

        Raises NotImplementedError when a currently-tracked object is missing
        from the dict (kept byte-for-byte from the original behavior, since
        callers may rely on the exception type).
        """
        output = []
        for obj in self.tmp_id_to_obj.values():
            if obj.id not in obj_dict:
                raise NotImplementedError
            output.append(obj_dict[obj.id])
        return torch.stack(output, dim=dim)

    def make_one_hot(self, cls_mask) -> torch.Tensor:
        """
        Expand a class mask into one boolean channel per tracked object
        (stacked along dim 0, ordered by tmp id). With no tracked objects,
        returns an empty (0, *cls_mask.shape) boolean tensor.
        """
        channels = [cls_mask == obj.id for obj in self.tmp_id_to_obj.values()]
        if not channels:
            return torch.zeros((0, *cls_mask.shape),
                               dtype=torch.bool,
                               device=cls_mask.device)
        return torch.stack(channels, dim=0)

    def all_obj_ids(self) -> List[int]:
        # immutable ids of all currently-tracked objects, in tmp-id order
        return [k.id for k in self.obj_to_tmp_id]

    def num_obj(self) -> int:
        # number of currently-tracked objects
        return len(self.obj_to_tmp_id)

    def has_all(self, objects: List[int]) -> bool:
        # True iff every given object is currently tracked
        return all(obj in self.obj_to_tmp_id for obj in objects)

    def find_object_by_id(self, obj_id) -> ObjectInfo:
        # look up an ObjectInfo by its immutable object id
        return self.obj_id_to_obj[obj_id]

    def find_tmp_by_id(self, obj_id) -> int:
        # current temporary (tensor position) id for an immutable object id
        return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]