Spaces:
Runtime error
Runtime error
import logging | |
import torch | |
from einops import rearrange, repeat | |
from lvdm.models.utils_diffusion import timestep_embedding | |
# Probe for the optional xformers package. The flag name's spelling is kept
# as-is because other modules may import it under this exact name.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    # Best-effort optional dependency: any import failure just disables it.
    # Catching Exception (not a bare except) lets KeyboardInterrupt/SystemExit
    # propagate.
    XFORMERS_IS_AVAILBLE = False
mainlogger = logging.getLogger('mainlogger')
def TemporalTransformer_forward(self, x, context=None, is_imgbatch=False):
    """Run temporal attention along the frame axis at every spatial location.

    Written as a free function taking ``self`` (presumably bound onto a
    TemporalTransformer-like module class elsewhere -- confirm at call site).

    Args:
        x: 5-D tensor unpacked as ``(b, c, t, h, w)``.
        context: When ``self.only_self_att`` is true it is handed unchanged to
            every transformer block (tensor, dict or None). Otherwise it must
            be a tensor rearranged from ``(b t) l con``. Only the first frame's
            entry of each batch element is used, broadcast over all spatial
            positions.
        is_imgbatch: If true, an identity mask restricts every frame to attend
            only to itself; this takes precedence over ``self.causal_attention``.

    Returns:
        Tensor in ``(b, c, t, h, w)`` layout: ``out + x``, or ``0.0 * out + x``
        when ``self.use_image_dataset`` is set.
    """
    b, _, t, h, w = x.shape
    x_in = x

    x = self.norm(x)
    # Fold spatial positions into the batch: each (b, h, w) location becomes an
    # independent length-t sequence.
    x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
    if not self.use_linear:
        x = self.proj_in(x)
    x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
    if self.use_linear:
        x = self.proj_in(x)

    # Build the (1, t, t) frame mask directly on x's device instead of
    # allocating it on the CPU and copying it over on every forward pass.
    if is_imgbatch:
        temp_mask = torch.eye(t, device=x.device).unsqueeze(0)
    elif self.causal_attention:
        temp_mask = torch.tril(torch.ones([1, t, t], device=x.device))
    else:
        temp_mask = None
    if temp_mask is not None:
        # One copy of the mask per folded (b, h, w) sequence.
        mask = repeat(temp_mask, 'l i j -> (l bhw) i j', bhw=b * h * w)
    else:
        mask = None

    if self.only_self_att:
        # NOTE: if no context is given, cross-attention defaults to self-attention.
        for block in self.transformer_blocks:
            x = block(x, context=context, mask=mask)
        x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
    else:
        x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
        context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous()
        for block in self.transformer_blocks:
            # Process one batch element at a time: some packages cannot handle
            # a dimension larger than 65,535.
            for j in range(b):
                # Use only the first frame's context, repeated for every
                # spatial position of this batch element.
                unit_context = context[j][0:1]
                context_j = repeat(unit_context, 't l con -> (t r) l con', r=(h * w)).contiguous()
                # NOTE: the causal / identity mask is not applied in the cross-attention case.
                x[j] = block(x[j], context=context_j)

    if self.use_linear:
        x = self.proj_out(x)
        x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
    else:
        x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
        x = self.proj_out(x)
        x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous()

    if self.use_image_dataset:
        # The temporal output is zeroed but kept in the graph -- presumably so
        # its parameters still participate in autograd; confirm intent.
        return 0.0 * x + x_in
    return x + x_in
def selfattn_forward_unet(self, x, timesteps, context=None, y=None, features_adapter=None, is_imgbatch=False, T=None, **kwargs):
    """UNet forward pass that can thread an optional pose embedding to its blocks.

    Written as a free function taking ``self`` (presumably bound onto the UNet
    class elsewhere -- confirm at call site).

    Args:
        x: 5-D input unpacked as ``(b, _, t, _, _)`` and processed in the
            ``(b t) c h w`` layout.
        timesteps: Values embedded with ``timestep_embedding``.
        context: Conditioning tensor. Unless ``is_imgbatch`` is set (or it is
            None), it is repeated ``t`` times along dim 0.
        y: Optional micro-condition values; added to the time embedding only
            when ``self.micro_condition`` is true.
        features_adapter: Optional list of 5-D adapter features. One entry is
            added after every third input block, and the list length must
            match the number of such injection points.
        is_imgbatch: Forwarded to every block; also skips repeating ``context``.
        T: Unused; kept for call-site compatibility.
        **kwargs: If ``pose_emb`` is present it is popped and the context is
            wrapped as ``{'context': context, 'pose_emb': pose_emb}``.

    Returns:
        Output of ``self.out`` rearranged back to ``(b, c, t, h, w)``.

    Raises:
        ValueError: If ``features_adapter`` has more entries than injection points.
    """
    b, _, t, _, _ = x.shape
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb)
    if self.micro_condition and y is not None:
        micro_emb = timestep_embedding(y, self.model_channels, repeat_only=False)
        emb = emb + self.micro_embed(micro_emb)

    # Repeat context and time embedding t times to match the per-frame
    # (b t) layout. A missing context is left as None instead of crashing.
    if not is_imgbatch and context is not None:
        context = context.repeat_interleave(repeats=t, dim=0)
    if 'pose_emb' in kwargs:
        context = {'context': context, 'pose_emb': kwargs.pop('pose_emb')}
    emb = emb.repeat_interleave(repeats=t, dim=0)

    # Always in shape (b t) c h w, except inside the temporal layers.
    x = rearrange(x, 'b c t h w -> (b t) c h w')
    if features_adapter is not None:
        features_adapter = [rearrange(feature, 'b c t h w -> (b t) c h w') for feature in features_adapter]

    h = x.type(self.dtype)
    adapter_idx = 0
    hs = []
    for block_idx, module in enumerate(self.input_blocks):
        h = module(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
        if block_idx == 0 and self.addition_attention:
            h = self.init_attn(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
        # Plug in adapter features after every third input block.
        if (block_idx + 1) % 3 == 0 and features_adapter is not None:
            h = h + features_adapter[adapter_idx]
            adapter_idx += 1
        hs.append(h)
    # Explicit check (not assert) so validation survives `python -O`.
    if features_adapter is not None and len(features_adapter) != adapter_idx:
        raise ValueError('Wrong features_adapter')

    h = self.middle_block(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
    for module in self.output_blocks:
        # Skip connection from the matching input block.
        h = torch.cat([h, hs.pop()], dim=1)
        h = module(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
    h = h.type(x.dtype)
    y = self.out(h)

    # Reshape back to (b c t h w).
    return rearrange(y, '(b t) c h w -> b c t h w', b=b)
def spatial_forward_BasicTransformerBlock(self, x, context=None, mask=None):
    """Spatial transformer-block forward that accepts a dict-wrapped context.

    If ``context`` is a dict (e.g. a ``{'context': ..., 'pose_emb': ...}``
    wrapper), only its ``'context'`` entry is used; any other value is used
    as-is.

    Returns:
        ``x`` after self-attention, cross-attention and feed-forward, each
        applied as ``sublayer(norm(x)) + x``.
    """
    cond = context['context'] if isinstance(context, dict) else context
    # attn1 only receives the conditioning when self-attention is disabled.
    attn1_ctx = cond if self.disable_self_attn else None

    out = self.attn1(self.norm1(x), context=attn1_ctx, mask=mask)
    x = out + x
    out = self.attn2(self.norm2(x), context=cond, mask=mask)
    x = out + x
    out = self.ff(self.norm3(x))
    return out + x
def temporal_selfattn_forward_BasicTransformerBlock(self, x, context=None, mask=None):
    """Temporal transformer-block forward with optional camera-pose injection.

    Both attention layers always run with ``context=None``: any text context
    is discarded here. If ``context`` is a dict holding ``'pose_emb'``, that
    tensor is concatenated to the features after ``attn1`` and passed through
    ``self.cc_projection``.

    Args:
        x: Tensor whose leading dim is laid out batch-major as ``B * hw``, and
            whose dim 1 must match the pose embedding's ``t`` (required by the
            concatenation below).
        context: Optional. Only a dict's ``'pose_emb'`` entry is used: a 4-D
            tensor unpacked as ``(B, t, _, _)``.
        mask: Forwarded to both attention layers.

    Returns:
        The updated ``x``.
    """
    # Only the pose embedding is taken from the context. Both branches of the
    # original code reset context to None, so the attention layers never see it.
    pose_emb = context.get('pose_emb') if isinstance(context, dict) else None

    x = self.attn1(self.norm1(x), context=None, mask=mask) + x

    # Add camera pose.
    if pose_emb is not None:
        B, t, _, _ = pose_emb.shape
        hw = x.shape[0] // B
        # Flatten the trailing pose dims, then repeat each batch entry hw times
        # to line up with x's batch-major (B * hw) leading dim.
        pose_emb = pose_emb.reshape(B, t, -1).repeat_interleave(repeats=hw, dim=0)
        # The projection's output replaces x (no residual connection here).
        x = self.cc_projection(torch.cat([x, pose_emb], dim=-1))

    x = self.attn2(self.norm2(x), context=None, mask=mask) + x
    x = self.ff(self.norm3(x)) + x
    return x