from typing import Any, Mapping

import torch
from torch import nn

from mmaudio.ext.synchformer.motionformer import MotionFormer


class Synchformer(nn.Module):
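    """Visual branch of Synchformer, reduced to segment-level feature extraction."""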

    def __init__(self):
        super().__init__()

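        # factorized space-time attention; spatial tokens are aggregated with a
        # TransformerEncoderLayer, the time axis is left untouched (Identity),
        # and no global clip-level representation is computed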
        self.vfeat_extractor = MotionFormer(extract_features=True,
                                            factorize_space_time=True,
                                            agg_space_module='TransformerEncoderLayer',
                                            agg_time_module='torch.nn.Identity',
                                            add_global_repr=False)

    def forward(self, vis):
        B, S, Tv, C, H, W = vis.shape
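        # (B, S, Tv, C, H, W) -> (B, S, C, Tv, H, W): MotionFormer expects the
        # channel axis ahead of the per-segment time axis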
        vis = vis.permute(0, 1, 3, 2, 4, 5)
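        # segment-level features of shape (B, S, tv, D), e.g. (B, 7, 8, 768)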
        vis = self.vfeat_extractor(vis)
        return vis

    def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
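        # the full Synchformer checkpoint carries more than the visual branch
        # (audio extractor, sync head); keep only the vfeat_extractor weights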
        sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')}
        return super().load_state_dict(sd, strict)


if __name__ == "__main__":
    model = Synchformer().cuda().eval()
    sd = torch.load('./ext_weights/synchformer_state_dict.pth', weights_only=True)
    model.load_state_dict(sd)

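    # dummy input: 2 clips, 7 segments each, 16 RGB frames per segment at 224x224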
    vid = torch.randn(2, 7, 16, 3, 224, 224).cuda()
    features = model(vid).detach().cpu()  # forward() runs the visual extractor
    print(features.shape)