File size: 19,670 Bytes
dbac20f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 |
import logging
from pathlib import Path
import einops
import torch
from omegaconf import OmegaConf
from timm.layers import trunc_normal_
from torch import nn
from mmaudio.ext.synchformer.utils import check_if_file_exists_else_download
from mmaudio.ext.synchformer.video_model_builder import VisionTransformer
# Upstream Motionformer release locations: configs are pinned to commit bf43d50 of the
# facebookresearch/Motionformer repo, checkpoints come from the public FAIR bucket.
_MFORMER_CFG_BASE_URL = 'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/'
_MFORMER_CKPT_BASE_URL = 'https://dl.fbaipublicfiles.com/motionformer/'
# The three SSv2 model variants published upstream (order determines dict order below).
_MFORMER_VARIANTS = ('motionformer', 'joint', 'divided')

# Maps a local file name (config or checkpoint) to the URL it is fetched from; passed to
# `check_if_file_exists_else_download` together with the file path.
FILE2URL = {
    # cfg
    **{
        f'{variant}_224_16x4.yaml': f'{_MFORMER_CFG_BASE_URL}{variant}_224_16x4.yaml'
        for variant in _MFORMER_VARIANTS
    },
    # ckpt
    **{
        f'ssv2_{variant}_224_16x4.pyth': f'{_MFORMER_CKPT_BASE_URL}ssv2_{variant}_224_16x4.pyth'
        for variant in _MFORMER_VARIANTS
    },
}
class MotionFormer(VisionTransformer):
    ''' Video feature extractor built on the Motionformer VisionTransformer.

    This class serves three purposes:
        1. Renames the class to MotionFormer.
        2. Downloads the cfg from the original repo and patches it if needed.
        3. Takes care of feature extraction by redefining .forward().

    `forward` takes x of shape (B, S, C, T, H, W), where S is the number of segments, and returns
    a tensor of shape (B, S, ...) whose trailing dims are the per-segment features:
        - if `extract_features=True` and `factorize_space_time` is falsy: (N, D) patch tokens
          (CLS token dropped, then `self.norm` applied), where N = t * h * w
          (the original comments give (224 // 16) * (224 // 16) * 8 tokens).
        - if `extract_features=True` and `factorize_space_time=True`: the spatial transformer
          encoder layer / pooling (`spatial_attn_agg`) followed by the temporal one
          (`temp_attn_agg`), i.e. (D,) per segment, or (t, D) when `temp_attn_agg` is `nn.Identity`.
    If additionally `add_global_repr=True`, a `global_attn_agg` module that aggregates over segments
    is built (with an extra pos emb when it is a TransformerEncoderLayer). NOTE: `forward` itself does
    not apply it — presumably the caller does.
    '''

    def __init__(
        self,
        extract_features: bool = False,
        ckpt_path: str = None,
        factorize_space_time: bool = None,
        agg_space_module: str = None,
        agg_time_module: str = None,
        add_global_repr: bool = True,
        agg_segments_module: str = None,
        max_segments: int = None,
    ):
        '''
        Args:
            extract_features: replace the classification head with identities and build the
                aggregation adapters; `forward_segments` asserts this is True.
            ckpt_path: optional Motionformer `.pyth` checkpoint (a key of `FILE2URL`) or a Stage I
                `.pt` checkpoint; fetched via `check_if_file_exists_else_download`.
            factorize_space_time: pool features with `spatial_attn_agg` then `temp_attn_agg`.
            agg_space_module: 'TransformerEncoderLayer' or 'AveragePooling'.
            agg_time_module: 'TransformerEncoderLayer', 'AveragePooling' or a name containing 'Identity'.
            add_global_repr: build `global_attn_agg` (aggregation over segments).
            agg_segments_module: 'TransformerEncoderLayer' or 'AveragePooling'.
            max_segments: number of segment positions for the global pos emb (16 if None).

        Raises:
            ValueError: if `ckpt_path` is neither a known Motionformer checkpoint nor a `.pt` file.
        '''
        self.extract_features = extract_features
        self.ckpt_path = ckpt_path
        self.factorize_space_time = factorize_space_time

        if self.ckpt_path is not None:
            check_if_file_exists_else_download(self.ckpt_path, FILE2URL)
            # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
            ckpt = torch.load(self.ckpt_path, map_location='cpu')
            mformer_ckpt2cfg = {
                'ssv2_motionformer_224_16x4.pyth': 'motionformer_224_16x4.yaml',
                'ssv2_joint_224_16x4.pyth': 'joint_224_16x4.yaml',
                'ssv2_divided_224_16x4.pyth': 'divided_224_16x4.yaml',
            }
            # init from motionformer ckpt or from our Stage I ckpt
            # depending on whether the feat extractor was pre-trained on AVCLIPMoCo or not, we need to
            # load the state dict differently
            was_pt_on_avclip = self.ckpt_path.endswith(
                '.pt')  # checks if it is a stage I ckpt (FIXME: a bit generic)
            if self.ckpt_path.endswith(tuple(mformer_ckpt2cfg.keys())):
                cfg_fname = mformer_ckpt2cfg[Path(self.ckpt_path).name]
            elif was_pt_on_avclip:
                # TODO: this is a hack, we should be able to get the cfg from the ckpt (earlier ckpt didn't have it)
                s1_cfg = ckpt.get('args', None)  # Stage I cfg
                if s1_cfg is not None:
                    s1_vfeat_extractor_ckpt_path = s1_cfg.model.params.vfeat_extractor.params.ckpt_path
                    # if the stage I ckpt was initialized from a motionformer ckpt or train from scratch
                    if s1_vfeat_extractor_ckpt_path is not None:
                        cfg_fname = mformer_ckpt2cfg[Path(s1_vfeat_extractor_ckpt_path).name]
                    else:
                        cfg_fname = 'divided_224_16x4.yaml'
                else:
                    cfg_fname = 'divided_224_16x4.yaml'
            else:
                raise ValueError(f'ckpt_path {self.ckpt_path} is not supported.')
        else:
            was_pt_on_avclip = False
            cfg_fname = 'divided_224_16x4.yaml'

        # the joint-attention variant uses a joint space-time pos emb; the others use separate ones
        if cfg_fname in ['motionformer_224_16x4.yaml', 'divided_224_16x4.yaml']:
            pos_emb_type = 'separate'
        elif cfg_fname == 'joint_224_16x4.yaml':
            pos_emb_type = 'joint'

        self.mformer_cfg_path = Path(__file__).absolute().parent / cfg_fname
        check_if_file_exists_else_download(self.mformer_cfg_path, FILE2URL)
        mformer_cfg = OmegaConf.load(self.mformer_cfg_path)
        logging.info(f'Loading MotionFormer config from {self.mformer_cfg_path.absolute()}')

        # patch the cfg (from the default cfg defined in the repo `Motionformer/slowfast/config/defaults.py`)
        mformer_cfg.VIT.ATTN_DROPOUT = 0.0
        mformer_cfg.VIT.POS_EMBED = pos_emb_type
        mformer_cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE = True
        mformer_cfg.VIT.APPROX_ATTN_TYPE = 'none'  # guessing
        mformer_cfg.VIT.APPROX_ATTN_DIM = 64  # from ckpt['cfg']

        # finally init VisionTransformer with the cfg
        super().__init__(mformer_cfg)

        # load the ckpt now if ckpt is provided and not from AVCLIPMoCo-pretrained ckpt
        if (self.ckpt_path is not None) and (not was_pt_on_avclip):
            _ckpt_load_status = self.load_state_dict(ckpt['model_state'], strict=False)
            if len(_ckpt_load_status.missing_keys) > 0 or len(
                    _ckpt_load_status.unexpected_keys) > 0:
                # the sentences are separated by a space (previously they were glued together)
                logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. '
                                f'Missing keys: {_ckpt_load_status.missing_keys}, '
                                f'Unexpected keys: {_ckpt_load_status.unexpected_keys}')
            else:
                logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')

        if self.extract_features:
            assert isinstance(self.norm,
                              nn.LayerNorm), 'early x[:, 1:, :] may not be safe for per-tr weights'
            # pre-logits are Sequential(nn.Linear(emb, emd), act) and `act` is tanh but see the logger
            self.pre_logits = nn.Identity()
            # we don't need the classification head (saving memory)
            self.head = nn.Identity()
            self.head_drop = nn.Identity()
            # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
            transf_enc_layer_kwargs = dict(
                d_model=self.embed_dim,
                nhead=self.num_heads,
                activation=nn.GELU(),
                batch_first=True,
                # nn.Linear needs an int size; mlp_ratio may be a float (e.g. 4.0) in the cfg
                dim_feedforward=int(self.mlp_ratio * self.embed_dim),
                dropout=self.drop_rate,
                layer_norm_eps=1e-6,
                norm_first=True,
            )
            # define adapters if needed
            if self.factorize_space_time:
                if agg_space_module == 'TransformerEncoderLayer':
                    self.spatial_attn_agg = SpatialTransformerEncoderLayer(
                        **transf_enc_layer_kwargs)
                elif agg_space_module == 'AveragePooling':
                    self.spatial_attn_agg = AveragePooling(avg_pattern='BS D t h w -> BS D t',
                                                           then_permute_pattern='BS D t -> BS t D')
                if agg_time_module == 'TransformerEncoderLayer':
                    self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_time_module == 'AveragePooling':
                    self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D')
                elif 'Identity' in agg_time_module:
                    self.temp_attn_agg = nn.Identity()
            # define a global aggregation layer (aggregate over segments)
            self.add_global_repr = add_global_repr
            if add_global_repr:
                if agg_segments_module == 'TransformerEncoderLayer':
                    # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
                    # we need to add pos emb (PE) because previously we added the same PE for each segment
                    pos_max_len = max_segments if max_segments is not None else 16  # 16 = 10sec//0.64sec + 1
                    self.global_attn_agg = TemporalTransformerEncoderLayer(
                        add_pos_emb=True,
                        pos_emb_drop=mformer_cfg.VIT.POS_DROPOUT,
                        pos_max_len=pos_max_len,
                        **transf_enc_layer_kwargs)
                elif agg_segments_module == 'AveragePooling':
                    self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D')

        # loaded after all adapters exist so that their weights (if present) are restored too
        if was_pt_on_avclip:
            # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
            # and keep only the state_dict of the feat extractor
            ckpt_weights = dict()
            for k, v in ckpt['state_dict'].items():
                if k.startswith(('module.v_encoder.', 'v_encoder.')):
                    k = k.replace('module.', '').replace('v_encoder.', '')
                    ckpt_weights[k] = v
            _load_status = self.load_state_dict(ckpt_weights, strict=False)
            if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
                logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. \n' \
                                f'Missing keys ({len(_load_status.missing_keys)}): ' \
                                f'{_load_status.missing_keys}, \n' \
                                f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \
                                f'{_load_status.unexpected_keys} \n' \
                                f'temp_attn_agg are expected to be missing if ckpt was pt contrastively.')
            else:
                logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')

        # patch_embed is not used in MotionFormer, only patch_embed_3d, because cfg.VIT.PATCH_SIZE_TEMP > 1
        # but it is used to calculate the number of patches, so we need to keep it (frozen)
        self.patch_embed.requires_grad_(False)

    def forward(self, x):
        '''
        x is of shape (B, S, C, T, H, W) where S is the number of segments.
        Returns a tensor of shape (B, S, ...) holding the features of each segment.
        '''
        # Batch, Segments, Channels, T=frames, Height, Width
        B, S, C, T, H, W = x.shape
        orig_shape = (B, S, C, T, H, W)
        # flatten batch and segments; `reshape` (unlike `view`) also accepts non-contiguous inputs
        x = x.reshape(B * S, C, T, H, W)
        x = self.forward_segments(x, orig_shape=orig_shape)
        # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
        x = x.reshape(B, S, *x.shape[1:])
        # x is now (B, S, D), or (B, S, t, D) if `self.temp_attn_agg` is `Identity`,
        # or (B, S, N, D) when space and time are not factorized
        return x

    def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor:
        '''x is of shape (B*S, C, T, H, W) where S is the number of segments;
        orig_shape is the (B, S, C, T, H, W) shape before batch and segments were flattened.'''
        # checked up front so that a misconfigured model fails before the expensive forward pass
        assert self.extract_features
        x, x_mask = self.forward_features(x)
        # (BS, 1 + N, D) -> (BS, N, D): drop the CLS token for efficiency
        # (should be safe for LayerNorm and FC)
        x = x[:, 1:, :]
        x = self.norm(x)
        x = self.pre_logits(x)
        if self.factorize_space_time:
            x = self.restore_spatio_temp_dims(x, orig_shape)  # (B*S, D, t, h, w) <- (B*S, t*h*w, D)
            x = self.spatial_attn_agg(x, x_mask)  # (B*S, t, D)
            x = self.temp_attn_agg(x)  # (B*S, D) or (BS, t, D) if `self.temp_attn_agg` is `Identity`
        return x

    def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
        '''
        feats are of shape (B*S, N, D), N = t * h * w patch tokens (CLS token already removed).
        Returns them as (B*S, D, t, h, w), where t is the number of temporal patches and h, w are
        the spatial patch-grid dimensions. Following the token layout of `self.patch_embed_3d`,
        this is `feats.transpose(1, 2).view(B*S, D, t, h, w)`.
        '''
        B, S, C, T, H, W = orig_shape
        D = self.embed_dim

        # num patches in each dimension
        t = T // self.patch_embed_3d.z_block_size
        h = self.patch_embed_3d.height
        w = self.patch_embed_3d.width

        feats = feats.permute(0, 2, 1)  # (B*S, D, N)
        # splitting the last dim into (t, h, w) is view-compatible even after the permute
        feats = feats.view(B * S, D, t, h, w)  # (B*S, D, t, h, w)

        return feats
class BaseEncoderLayer(nn.TransformerEncoderLayer):
    '''
    Wrapper around nn.TransformerEncoderLayer that prepends a learnable CLS token to the
    sequence and returns the CLS token's output representation.

    This base class parents both SpatialEncoderLayer and TemporalEncoderLayer for the RGB stream
    and the FrequencyEncoderLayer and TemporalEncoderLayer for the audio stream.
    Optionally, a learnable positional embedding is added to the input sequence, which
    allows reusing it for global aggregation (of segments) for both streams.

    Args:
        add_pos_emb: add a learnable positional embedding (covering the CLS token too).
        pos_emb_drop: dropout rate applied after the positional embedding (None means 0.0).
        pos_max_len: max sequence length (excluding CLS) for the positional embedding;
            required when `add_pos_emb=True`.
        *args_transformer_enc, **kwargs_transformer_enc: forwarded to nn.TransformerEncoderLayer.

    Raises:
        ValueError: if `add_pos_emb=True` but `pos_max_len` is None.
    '''

    def __init__(self,
                 add_pos_emb: bool = False,
                 pos_emb_drop: float = None,
                 pos_max_len: int = None,
                 *args_transformer_enc,
                 **kwargs_transformer_enc):
        super().__init__(*args_transformer_enc, **kwargs_transformer_enc)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim))
        trunc_normal_(self.cls_token, std=.02)

        # add positional embedding
        self.add_pos_emb = add_pos_emb
        if add_pos_emb:
            if pos_max_len is None:
                raise ValueError('`pos_max_len` must be provided when `add_pos_emb=True`.')
            self.pos_max_len = 1 + pos_max_len  # +1 (for CLS)
            self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim))
            # nn.Dropout rejects None, so a missing rate means "no dropout"
            self.pos_drop = nn.Dropout(0.0 if pos_emb_drop is None else pos_emb_drop)
            trunc_normal_(self.pos_emb, std=.02)

        self.apply(self._init_weights)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
        ''' x is of shape (B, N, D); if provided, x_mask is a bool tensor of shape (B, N)
        with 1=keep, 0=mask. Returns the CLS token output of shape (B, D). '''
        batch_dim = x.shape[0]

        # add CLS token
        cls_tokens = self.cls_token.expand(batch_dim, -1, -1)  # expanding to match batch dimension
        x = torch.cat((cls_tokens, x), dim=-2)  # (batch_dim, 1+seq_len, D)
        if x_mask is not None:
            # the CLS token is always kept (1=keep; 0=mask)
            cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool, device=x_mask.device)
            x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1)  # (batch_dim, 1+seq_len)
            B, N = x_mask_w_cls.shape
            # torch expects an (N, N) or (B*num_heads, N, N) attention mask: broadcast the
            # per-key mask over heads and query positions
            x_mask_w_cls = x_mask_w_cls.reshape(B, 1, 1, N)\
                .expand(-1, self.self_attn.num_heads, N, -1)\
                .reshape(B * self.self_attn.num_heads, N, N)
            assert x_mask_w_cls.dtype == x_mask_w_cls.bool().dtype, 'x_mask_w_cls.dtype != bool'
            # in torch bool attention masks True means "may not attend", so invert (1=mask)
            x_mask_w_cls = ~x_mask_w_cls
        else:
            x_mask_w_cls = None

        # add positional embedding
        if self.add_pos_emb:
            seq_len = x.shape[1]  # (don't even think about moving it before the CLS token concatenation)
            assert seq_len <= self.pos_max_len, f'Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})'
            x = x + self.pos_emb[:, :seq_len, :]
            x = self.pos_drop(x)

        # apply encoder layer (calls nn.TransformerEncoderLayer.forward);
        x = super().forward(src=x, src_mask=x_mask_w_cls)  # (batch_dim, 1+seq_len, D)

        # CLS token is expected to hold the aggregated information of the sequence
        x = x[:, 0, :]  # (batch_dim, D)
        return x

    def _init_weights(self, m):
        '''Truncated-normal init for Linear weights; zero biases; LayerNorm to identity.
        Affine parameters that are absent (bias=False / elementwise_affine=False) are skipped.'''
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        '''Parameter names that should be excluded from weight decay.'''
        return {'cls_token', 'pos_emb'}
class SpatialTransformerEncoderLayer(BaseEncoderLayer):
    ''' Pools each frame's spatial patch grid into a single vector via CLS-token attention. '''

    def __init__(self, *args, **kwargs):
        # no extra state: all configuration is handled by BaseEncoderLayer
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        ''' Input x: (B*S, D, t, h, w), S being the number of segments.
        Optional x_mask: (B*S, t, h, w) with 0=masked, 1=kept.
        Output: (B*S, t, D), one pooled vector per frame. '''
        # unpacking also enforces the 5-D input layout
        n_clips, _, n_frames, _, _ = x.shape

        # every frame becomes its own sequence of h*w spatial tokens
        frame_tokens = einops.rearrange(x, 'BS D t h w -> (BS t) (h w) D')
        if x_mask is None:
            frame_mask = None
        else:
            frame_mask = einops.rearrange(x_mask, 'BS t h w -> (BS t) (h w)')

        # BaseEncoderLayer prepends a CLS token and returns its output: (B*S*t, D)
        pooled = super().forward(x=frame_tokens, x_mask=frame_mask)

        # regroup the per-frame vectors by clip
        return einops.rearrange(pooled, '(BS t) D -> BS t D', BS=n_clips, t=n_frames)
class TemporalTransformerEncoderLayer(BaseEncoderLayer):
    ''' Pools a sequence of per-frame vectors into one vector via CLS-token attention.
    With a positional embedding enabled, it is also used for global aggregation over
    segments in both streams. '''

    def __init__(self, *args, **kwargs):
        # no extra state: all configuration is handled by BaseEncoderLayer
        super().__init__(*args, **kwargs)

    def forward(self, x):
        ''' Input x: (B*S, t, D), S being the number of segments.
        Output: (B*S, D) summarizing the temporal axis. '''
        # unpacking enforces a 3-D input before delegating
        _n_clips, _n_steps, _dim = x.shape
        # BaseEncoderLayer prepends a CLS token and returns its output representation
        return super().forward(x)
class AveragePooling(nn.Module):
    ''' Parameter-free pooling: averages over the axes dropped by an einops reduction pattern,
    then optionally rearranges the result with a second einops pattern. '''

    def __init__(self, avg_pattern: str, then_permute_pattern: str = None) -> None:
        ''' Patterns are einops expressions, e.g. "bs t d -> bs d". '''
        super().__init__()
        # plain string attributes (not part of the state_dict)
        self.reduce_fn = 'mean'
        self.avg_pattern = avg_pattern
        self.then_permute_pattern = then_permute_pattern

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        # NOTE(review): x_mask is accepted for interface compatibility with the attention
        # aggregators but is ignored — masked positions are included in the mean.
        pooled = einops.reduce(x, self.avg_pattern, self.reduce_fn)
        if self.then_permute_pattern is None:
            return pooled
        return einops.rearrange(pooled, self.then_permute_pattern)
|