import torch.nn as nn
from einops import rearrange

from lvdm.models.ddpm3d import LatentDiffusion
from motionctrl.lvdm_modified_modules import (
    TemporalTransformer_forward, selfattn_forward_unet,
    spatial_forward_BasicTransformerBlock,
    temporal_selfattn_forward_BasicTransformerBlock)
from utils.utils import instantiate_from_config


def _bind_method(func, obj):
    """Return ``func`` bound to ``obj`` as an instance method.

    Assigning the result as an instance attribute overrides the class-level
    method for ``obj`` only (other instances of the class are untouched).
    """
    return func.__get__(obj, obj.__class__)


def _make_cc_projection(width, pose_dim):
    """Build a ``Linear(width + pose_dim -> width)`` initialised near identity.

    The ``[:width, :width]`` block of the weight is set to the identity and the
    bias to zero, so the first ``width`` input features pass straight through
    at initialisation.

    NOTE(review): the trailing ``pose_dim`` weight columns keep PyTorch's
    default ``nn.Linear`` initialisation (they are not zeroed), so the layer is
    not an exact identity map when the extra inputs are non-zero. This matches
    the previous behaviour; confirm it is intended.
    """
    proj = nn.Linear(width + pose_dim, width)
    nn.init.eye_(proj.weight[:width, :width])
    nn.init.zeros_(proj.bias)
    return proj


class MotionCtrl(LatentDiffusion):
    """LatentDiffusion variant with object- and camera-motion conditioning.

    Object motion is handled by an optional ``omcm`` module built from
    ``omcm_config``. Camera motion is handled by replacing the ``forward`` /
    ``_forward`` methods of the wrapped UNet (``self.model.diffusion_model``)
    and of its transformer submodules with the versions from
    ``motionctrl.lvdm_modified_modules``, and by attaching a ``cc_projection``
    layer to the temporal ``BasicTransformerBlock`` instances.
    """

    def __init__(self, omcm_config=None, pose_dim=12, context_dim=1024, *args, **kwargs):
        """Build the base model, then install the motion-control modules.

        Args:
            omcm_config: config passed to ``instantiate_from_config`` to build
                the object motion control module; ``None`` disables it.
            pose_dim: number of extra input features each ``cc_projection``
                accepts on top of the block width (presumably a flattened
                camera pose, e.g. 3x4 = 12 -- TODO confirm against callers).
            context_dim: cross-attention context width. Blocks whose
                ``attn2.to_k.in_features`` differ from it are treated as
                temporal blocks without cross-attention.
            *args, **kwargs: forwarded unchanged to ``LatentDiffusion``.
        """
        super().__init__(*args, **kwargs)

        # object motion control module
        self.omcm = instantiate_from_config(omcm_config) if omcm_config is not None else None

        # camera motion control module
        self._install_camera_motion_control(pose_dim, context_dim)

    def _install_camera_motion_control(self, pose_dim, context_dim):
        """Patch the UNet's forward methods and add ``cc_projection`` layers."""
        unet = self.model.diffusion_model
        unet.forward = _bind_method(selfattn_forward_unet, unet)

        # Snapshot the module list first: the loop adds cc_projection layers to
        # the module tree, and we neither need to visit them nor want to rely
        # on the tree being mutated while a generator walks it.
        for module in list(unet.modules()):
            class_name = type(module).__name__
            if class_name == 'TemporalTransformer':
                module.forward = _bind_method(TemporalTransformer_forward, module)
            elif class_name == 'BasicTransformerBlock':
                self._patch_basic_transformer_block(module, pose_dim, context_dim)

    @staticmethod
    def _patch_basic_transformer_block(block, pose_dim, context_dim):
        """Install the MotionCtrl ``_forward`` on one ``BasicTransformerBlock``.

        A block whose ``attn2`` key projection does not take ``context_dim``
        inputs belongs to a TemporalTransformer without cross-attention. It gets
        the temporal self-attention ``_forward`` plus a ``cc_projection``
        submodule. Every other block (spatial, with cross-attention) gets the
        spatial ``_forward``.
        """
        width = block.attn2.to_k.in_features
        if width == context_dim:
            block._forward = _bind_method(spatial_forward_BasicTransformerBlock, block)
            return

        block._forward = _bind_method(temporal_selfattn_forward_BasicTransformerBlock, block)
        block.add_module('cc_projection', _make_cc_projection(width, pose_dim))

    def get_traj_features(self, extra_cond):
        """Encode per-frame trajectory maps with the object motion module.

        Args:
            extra_cond: 5-D tensor laid out as ``(b, c, t, h, w)``. This is the
                layout the rearrange patterns below require.

        Returns:
            A list with one tensor per feature map returned by ``self.omcm``.
            Each map must be 4-D ``(b*t, c', h', w')`` and is rearranged back to
            ``(b, c', t, h', w')``. Presumably there is one map per UNet
            resolution level -- TODO confirm.

        Raises:
            TypeError: if the model was built without ``omcm_config``.
        """
        if self.omcm is None:
            # TypeError keeps the exception type callers previously saw when
            # ``None`` was called, but with an actionable message.
            raise TypeError(
                'MotionCtrl.get_traj_features requires an object motion control '
                'module; construct the model with a non-None omcm_config.')

        b, _, t, _, _ = extra_cond.shape
        # Fold time into batch so omcm processes every frame as a 2D image.
        frames = rearrange(extra_cond, 'b c t h w -> (b t) c h w')
        features = self.omcm(frames)
        return [rearrange(feat, '(b t) c h w -> b c t h w', b=b, t=t) for feat in features]