# Spaces: Running on A10G
from typing import List, Dict | |
import logging | |
from omegaconf import DictConfig | |
import torch | |
import torch.nn as nn | |
from tracker.model.modules import * | |
from tracker.model.big_modules import * | |
from tracker.model.aux_modules import AuxComputer | |
from tracker.model.utils.memory_utils import * | |
from tracker.model.transformer.object_transformer import QueryTransformer | |
from tracker.model.transformer.object_summarizer import ObjectSummarizer | |
from tracker.utils.tensor_utils import aggregate | |
# Module-level logger. getLogger() is called without a name, so this is the
# root logger and every message below goes through the root configuration.
log = logging.getLogger()
class CUTIE(nn.Module):
    """Top-level network that wires the individual sub-modules together.

    The class owns a pixel encoder, a key projection, a mask encoder/decoder,
    a pixel-feature fuser, an object transformer, an object summarizer and an
    auxiliary computer, and exposes every stage as its own method.
    ``forward`` is deliberately not implemented; callers invoke the stages
    (``encode_image``, ``encode_mask``, ``read_memory``, ``segment``, ...)
    directly.
    """

    def __init__(self, cfg: DictConfig, *, single_object=False):
        """Build all sub-modules from ``cfg.model``.

        Args:
            cfg: full configuration; ``cfg.model`` configures the sub-modules
                and ``cfg`` itself is handed to ``AuxComputer``.
            single_object: forwarded to the mask encoder and the pixel fuser;
                when true, ``_get_others`` returns ``None``.
        """
        super().__init__()
        model_cfg = cfg.model
        self.ms_dims = model_cfg.pixel_encoder.ms_dims
        self.key_dim = model_cfg.key_dim
        self.value_dim = model_cfg.value_dim
        self.sensory_dim = model_cfg.sensory_dim
        self.pixel_dim = model_cfg.pixel_dim
        self.embed_dim = model_cfg.embed_dim
        self.single_object = single_object

        log.info(f'Single object: {self.single_object}')

        self.pixel_encoder = PixelEncoder(model_cfg)
        self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1)
        self.key_proj = KeyProjection(model_cfg)
        self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object)
        self.mask_decoder = MaskDecoder(model_cfg)
        self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object)
        self.object_transformer = QueryTransformer(model_cfg)
        self.object_summarizer = ObjectSummarizer(model_cfg)
        self.aux_computer = AuxComputer(cfg)

        # Normalisation constants are non-persistent buffers: they follow the
        # module across devices but are not stored in the state dict.
        self.register_buffer("pixel_mean",
                             torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1),
                             persistent=False)
        self.register_buffer("pixel_std",
                             torch.Tensor(model_cfg.pixel_std).view(-1, 1, 1),
                             persistent=False)

    def _get_others(self, masks: torch.Tensor) -> torch.Tensor:
        """For each object, return the clamped sum of the masks of all other objects.

        Objects are indexed along dim 1 of ``masks``. Returns ``None`` in
        single-object mode, and an all-zero tensor when ``masks`` has no
        objects along dim 1.
        """
        if self.single_object:
            return None

        num_objects = masks.shape[1]
        if num_objects >= 1:
            # total over objects minus the object itself, limited to [0, 1]
            others = (masks.sum(dim=1, keepdim=True) - masks).clamp(0, 1)
        else:
            others = torch.zeros_like(masks)
        return others

    def encode_image(self, image: torch.Tensor) -> (Iterable[torch.Tensor], torch.Tensor):
        """Normalise ``image`` and run the pixel encoder.

        Returns:
            The multi-scale features produced by the pixel encoder, and the
            first of those features projected by ``pix_feat_proj``.
        """
        image = (image - self.pixel_mean) / self.pixel_std
        ms_image_feat = self.pixel_encoder(image)
        return ms_image_feat, self.pix_feat_proj(ms_image_feat[0])

    def encode_mask(
            self,
            image: torch.Tensor,
            ms_features: List[torch.Tensor],
            sensory: torch.Tensor,
            masks: torch.Tensor,
            *,
            deep_update: bool = True,
            chunk_size: int = -1,
            need_weights: bool = False) -> (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
        """Encode ``masks`` together with the normalised image.

        ``deep_update`` and ``chunk_size`` are forwarded to the mask encoder,
        ``need_weights`` to the object summarizer.

        Returns:
            ``(mask_value, new_sensory, object_summaries, object_logits)``
            where the first two come from the mask encoder and the last two
            from the object summarizer.
        """
        image = (image - self.pixel_mean) / self.pixel_std
        others = self._get_others(masks)
        mask_value, new_sensory = self.mask_encoder(image,
                                                    ms_features,
                                                    sensory,
                                                    masks,
                                                    others,
                                                    deep_update=deep_update,
                                                    chunk_size=chunk_size)
        object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights)
        return mask_value, new_sensory, object_summaries, object_logits

    def transform_key(self,
                      final_pix_feat: torch.Tensor,
                      *,
                      need_sk: bool = True,
                      need_ek: bool = True) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        """Project pixel features into ``(key, shrinkage, selection)``.

        ``need_sk`` / ``need_ek`` are passed to the key projection as
        ``need_s`` / ``need_e``.
        """
        key, shrinkage, selection = self.key_proj(final_pix_feat, need_s=need_sk, need_e=need_ek)
        return key, shrinkage, selection

    # Used in training only.
    # This step is replaced by MemoryManager in test time
    def read_memory(self, query_key: torch.Tensor, query_selection: torch.Tensor,
                    memory_key: torch.Tensor, memory_shrinkage: torch.Tensor,
                    msk_value: torch.Tensor, obj_memory: torch.Tensor, pix_feat: torch.Tensor,
                    sensory: torch.Tensor, last_mask: torch.Tensor,
                    selector: torch.Tensor) -> (torch.Tensor, Dict[str, torch.Tensor]):
        """
        Read from memory with pixel-level attention, fuse, then query the object transformer.

        query_key       : B * CK * H * W
        query_selection : B * CK * H * W
        memory_key      : B * CK * T * H * W
        memory_shrinkage: B * 1  * T * H * W
        msk_value       : B * num_objects * CV * T * H * W
        obj_memory      : B * num_objects * T * num_summaries * C
        pix_feat        : B * C * H * W

        Returns the transformer readout and a dict with the keys
        'sensory', 'q_logits' and 'attn_mask' (the latter two are None when
        the transformer returns no auxiliary features).
        """
        batch_size, num_objects = msk_value.shape[:2]

        # read using visual attention; autocast is disabled and every operand
        # is cast with .float() for this part
        with torch.cuda.amp.autocast(enabled=False):
            affinity = get_affinity(memory_key.float(), memory_shrinkage.float(), query_key.float(),
                                    query_selection.float())

            msk_value = msk_value.flatten(start_dim=1, end_dim=2).float()

            # B * (num_objects*CV) * H * W
            pixel_readout = readout(affinity, msk_value)
            pixel_readout = pixel_readout.view(batch_size, num_objects, self.value_dim,
                                               *pixel_readout.shape[-2:])
        # NOTE(review): fusion is placed outside the autocast-disabled region;
        # confirm against the intended layout.
        pixel_readout = self.pixel_fusion(pix_feat, pixel_readout, sensory, last_mask)

        # read from query transformer
        mem_readout, aux_features = self.readout_query(pixel_readout, obj_memory, selector=selector)

        aux_output = {
            'sensory': sensory,
            'q_logits': aux_features['logits'] if aux_features else None,
            'attn_mask': aux_features['attn_mask'] if aux_features else None,
        }

        return mem_readout, aux_output

    def pixel_fusion(self,
                     pix_feat: torch.Tensor,
                     pixel: torch.Tensor,
                     sensory: torch.Tensor,
                     last_mask: torch.Tensor,
                     *,
                     chunk_size: int = -1) -> torch.Tensor:
        """Fuse pixel features, the pixel readout, sensory memory and the last mask.

        ``last_mask`` is first resized (area interpolation) to the spatial
        size of ``sensory``; the masks of the other objects are derived from
        the resized mask.
        """
        last_mask = F.interpolate(last_mask, size=sensory.shape[-2:], mode='area')
        last_others = self._get_others(last_mask)
        fused = self.pixel_fuser(pix_feat,
                                 pixel,
                                 sensory,
                                 last_mask,
                                 last_others,
                                 chunk_size=chunk_size)
        return fused

    def readout_query(self,
                      pixel_readout,
                      obj_memory,
                      *,
                      selector=None,
                      need_weights=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
        """Run the object transformer on the pixel readout and the object memory."""
        return self.object_transformer(pixel_readout,
                                       obj_memory,
                                       selector=selector,
                                       need_weights=need_weights)

    def segment(self,
                ms_image_feat: List[torch.Tensor],
                memory_readout: torch.Tensor,
                sensory: torch.Tensor,
                *,
                selector: torch.Tensor = None,
                chunk_size: int = -1,
                update_sensory: bool = True) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        """
        Decode a segmentation from the memory readout.

        ms_image_feat is from the key encoder for skip-connection
        memory_readout is from working/long-term memory
        sensory is the sensory memory
        selector is 1 if an object exists, and 0 otherwise. We use it to filter padded objects
            during training.

        Returns (sensory, logits, prob): the sensory memory returned by the
        mask decoder, the aggregated logits upsampled by a factor of 4, and
        their softmax over dim 1.
        """
        sensory, logits = self.mask_decoder(ms_image_feat,
                                            memory_readout,
                                            sensory,
                                            chunk_size=chunk_size,
                                            update_sensory=update_sensory)

        prob = torch.sigmoid(logits)
        if selector is not None:
            # multiply per-object probabilities by the selector
            prob = prob * selector

        # Combine the per-object probabilities into joint logits, then take
        # the softmax over all objects
        logits = aggregate(prob, dim=1)
        logits = F.interpolate(logits, scale_factor=4, mode='bilinear', align_corners=False)
        prob = F.softmax(logits, dim=1)

        return sensory, logits, prob

    def compute_aux(self, pix_feat: torch.Tensor, aux_inputs: Dict[str, torch.Tensor],
                    selector: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Delegate to the auxiliary computer."""
        return self.aux_computer(pix_feat, aux_inputs, selector)

    def forward(self, *args, **kwargs):
        """Not supported; call the individual stage methods instead."""
        raise NotImplementedError

    @staticmethod
    def _append_channel(src_dict, k, pad_shape, init_as_zero_if_needed) -> None:
        """Append one slice of shape ``pad_shape`` along dim 1 of ``src_dict[k]``.

        The new slice is orthogonally initialised unless
        ``init_as_zero_if_needed`` is true, in which case it stays zero.
        """
        log.info(f'Converting {k} from single object to multiple objects.')
        pads = torch.zeros(pad_shape, device=src_dict[k].device)
        if not init_as_zero_if_needed:
            nn.init.orthogonal_(pads)
            log.info(f'Randomly initialized padding for {k}.')
        else:
            log.info(f'Zero-initialized padding for {k}.')
        src_dict[k] = torch.cat([src_dict[k], pads], 1)

    def load_weights(self, src_dict, init_as_zero_if_needed=False) -> None:
        """Load ``src_dict`` non-strictly, adapting it between object modes first.

        ``src_dict`` is modified in place when a conversion applies.

        Args:
            src_dict: state dict to load.
            init_as_zero_if_needed: when a single-object weight has to be
                padded for a multi-object model, leave the padding at zero
                instead of initialising it orthogonally.
        """
        if not self.single_object:
            # Map single-object weight to multi-object weight
            # (4 -> 5 channels along dim 1 of conv1)
            for k in list(src_dict.keys()):
                if k == 'mask_encoder.conv1.weight':
                    if src_dict[k].shape[1] == 4:
                        self._append_channel(src_dict, k, (64, 1, 7, 7), init_as_zero_if_needed)
                elif k == 'pixel_fuser.sensory_compress.weight':
                    if src_dict[k].shape[1] == self.sensory_dim + 1:
                        self._append_channel(src_dict, k, (self.value_dim, 1, 1, 1),
                                             init_as_zero_if_needed)
        else:
            # If the model is multiple-object and we are training in single-object,
            # we strip the last channel of conv1.
            # This is not supposed to happen in standard training except when users are
            # trying to finetune a trained model with single object datasets.
            #
            # The key is bound explicitly here: no loop variable exists in this
            # branch. A checkpoint without the key is tolerated, in line with
            # the non-strict load below.
            k = 'mask_encoder.conv1.weight'
            if k in src_dict and src_dict[k].shape[1] == 5:
                log.warning(f'Converting {k} from multiple objects to single object.'
                            'This is not supposed to happen in standard training.')
                src_dict[k] = src_dict[k][:, :-1]

        # Build our own state dict once rather than once per key.
        own_state = self.state_dict()

        for k in src_dict:
            if k not in own_state:
                log.info(f'Key {k} found in src_dict but not in self.state_dict()!!!')
        for k in own_state:
            if k not in src_dict:
                log.info(f'Key {k} found in self.state_dict() but not in src_dict!!!')

        self.load_state_dict(src_dict, strict=False)

    def device(self) -> torch.device:
        """Return the device on which the ``pixel_mean`` buffer lives."""
        return self.pixel_mean.device