Spaces:
Running
on
A10G
Running
on
A10G
File size: 1,753 Bytes
320e465 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
import warnings
from typing import Iterable, Tuple

import torch

from tracker.model.cutie import CUTIE
class ImageFeatureStore:
    """
    A cache for image features.

    These features might be reused at different parts of the inference pipeline.
    This class provides an interface for reusing these features.
    It is the user's responsibility to delete redundant features.

    Feature of a frame should be associated with a unique index -- typically the frame id.
    """

    def __init__(self, network: CUTIE, no_warning: bool = False):
        """
        Args:
            network: model used to compute the features; this class calls its
                ``encode_image(image)`` and ``transform_key(feature)`` methods.
            no_warning: if True, stay silent about entries that are still
                cached when the store itself is finalized.
        """
        self.network = network
        # index -> (ms_features, pix_feat, key, shrinkage, selection)
        self._store = {}
        self.no_warning = no_warning

    def _encode_feature(self, index: int, image: torch.Tensor) -> None:
        """Run the network on ``image`` and cache the results under ``index``.

        An existing entry for ``index`` is overwritten.
        """
        ms_features, pix_feat = self.network.encode_image(image)
        # The key projection is computed from the first multi-scale feature only.
        key, shrinkage, selection = self.network.transform_key(ms_features[0])
        self._store[index] = (ms_features, pix_feat, key, shrinkage, selection)

    def get_features(self, index: int,
                     image: torch.Tensor) -> Tuple[Iterable[torch.Tensor], torch.Tensor]:
        """Return ``(ms_features, pix_feat)`` for ``index``.

        ``image`` is only encoded when ``index`` is not cached yet; on a cache
        hit it is ignored.
        """
        if index not in self._store:
            self._encode_feature(index, image)
        return self._store[index][:2]

    def get_key(self, index: int,
                image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(key, shrinkage, selection)`` for ``index``.

        ``image`` is only encoded when ``index`` is not cached yet; on a cache
        hit it is ignored.
        """
        if index not in self._store:
            self._encode_feature(index, image)
        return self._store[index][2:]

    def delete(self, index: int) -> None:
        """Drop the cached entry for ``index``; a missing index is a no-op."""
        self._store.pop(index, None)

    def __len__(self):
        """Return the number of indices currently cached."""
        return len(self._store)

    def __del__(self):
        """Warn about entries that were never deleted, unless ``no_warning`` is set."""
        # The finalizer also runs on instances whose __init__ never ran (for
        # example a constructor call rejected with TypeError), so the
        # attributes may not exist; read them defensively instead of raising
        # from inside a finalizer.
        store = getattr(self, '_store', None)
        if store and not getattr(self, 'no_warning', False):
            warnings.warn(f'Leaking {store.keys()} in the image feature store')
|