import logging

import torch
from einops import rearrange, repeat

from lvdm.models.utils_diffusion import timestep_embedding

try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except ImportError:
    XFORMERS_IS_AVAILBLE = False

mainlogger = logging.getLogger('mainlogger')


def TemporalTransformer_forward(self, x, context=None, is_imgbatch=False):
    """Forward pass of the temporal transformer over a (b, c, t, h, w) feature map.

    Attention runs along the temporal axis; each spatial location is treated as
    an independent batch element.
    """
    b, c, t, h, w = x.shape
    x_in = x
    x = self.norm(x)
    x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
    if not self.use_linear:
        x = self.proj_in(x)
    x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
    if self.use_linear:
        x = self.proj_in(x)

    temp_mask = None
    if self.causal_attention:
        temp_mask = torch.tril(torch.ones([1, t, t]))
    if is_imgbatch:
        # frames of an image batch are unrelated: restrict attention to the diagonal
        temp_mask = torch.eye(t).unsqueeze(0)
    if temp_mask is not None:
        mask = temp_mask.to(x.device)
        mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b * h * w)
    else:
        mask = None

    if self.only_self_att:
        ## note: if no context is given, cross-attention defaults to self-attention
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context, mask=mask)
        x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
    else:
        x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
        context = rearrange(context, '(b t) l con -> b t l con', t=t).contiguous()
        for i, block in enumerate(self.transformer_blocks):
            # process one batch element at a time (some backends cannot handle a batch dimension larger than 65,535)
            for j in range(b):
                unit_context = context[j][0:1]
                context_j = repeat(unit_context, 't l con -> (t r) l con', r=(h * w)).contiguous()
                ## note: the causal mask is not applied in the cross-attention case
                x[j] = block(x[j], context=context_j)

    if self.use_linear:
        x = self.proj_out(x)
        x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
    if not self.use_linear:
        x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
        x = self.proj_out(x)
        x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h, w=w).contiguous()

    if self.use_image_dataset:
        # zero out the temporal branch when training on image data, keeping only the residual
        x = 0.0 * x + x_in
    else:
        x = x + x_in
    return x


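# The functions in this module take an explicit `self`, which suggests they are
# meant to be bound onto existing lvdm modules at runtime rather than called
# directly. A minimal monkey-patching sketch; the import path and class name
# `lvdm.modules.attention.TemporalTransformer` are assumptions, not confirmed
# by this file:
#
#   from lvdm.modules.attention import TemporalTransformer
#   TemporalTransformer.forward = TemporalTransformer_forward      # class-level patch
#   # or, per instance:
#   # module.forward = TemporalTransformer_forward.__get__(module)

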
def selfattn_forward_unet(self, x, timesteps, context=None, y=None, features_adapter=None, is_imgbatch=False, T=None, **kwargs):
    """UNet forward pass on a (b, c, t, h, w) latent video.

    Frames are flattened into the batch dimension for the spatial layers;
    optional camera-pose embeddings are routed to the temporal blocks through
    the `context` dict, and adapter features are added after every third
    input block.
    """
    b, _, t, _, _ = x.shape
    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
    emb = self.time_embed(t_emb)
    if self.micro_condition and y is not None:
        micro_emb = timestep_embedding(y, self.model_channels, repeat_only=False)
        emb = emb + self.micro_embed(micro_emb)

    ## repeat the context [(b t) 77 768] and the time embedding t times
    if not is_imgbatch:
        context = context.repeat_interleave(repeats=t, dim=0)
        if 'pose_emb' in kwargs:
            pose_emb = kwargs.pop('pose_emb')
            context = {'context': context, 'pose_emb': pose_emb}
        emb = emb.repeat_interleave(repeats=t, dim=0)

    ## always in shape (b t) c h w, except for the temporal layers
    x = rearrange(x, 'b c t h w -> (b t) c h w')
    if features_adapter is not None:
        features_adapter = [rearrange(feature, 'b c t h w -> (b t) c h w') for feature in features_adapter]

    h = x.type(self.dtype)
    adapter_idx = 0
    hs = []
    for id, module in enumerate(self.input_blocks):
        h = module(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
        if id == 0 and self.addition_attention:
            h = self.init_attn(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
        ## plug in the adapter features
        if ((id + 1) % 3 == 0) and features_adapter is not None:
            h = h + features_adapter[adapter_idx]
            adapter_idx += 1
        hs.append(h)
    if features_adapter is not None:
        assert len(features_adapter) == adapter_idx, 'Wrong features_adapter'

    h = self.middle_block(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
    for module in self.output_blocks:
        h = torch.cat([h, hs.pop()], dim=1)
        h = module(h, emb, context=context, batch_size=b, is_imgbatch=is_imgbatch)
    h = h.type(x.dtype)
    y = self.out(h)

    # reshape back to (b c t h w)
    y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
    return y


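# A minimal call sketch for selfattn_forward_unet, assuming it has been bound as the
# UNet's forward method. Shapes follow the rearranges above; the concrete tensor sizes
# below are illustrative assumptions, not values taken from this file:
#
#   x = torch.randn(2, 4, 16, 32, 32)            # (b, c, t, h, w) latent video
#   timesteps = torch.randint(0, 1000, (2,))     # one diffusion step per sample
#   context = torch.randn(2, 77, 768)            # text embedding, repeated t times internally
#   pose_emb = torch.randn(2, 16, 3, 4)          # per-frame camera pose (hypothetical sizes)
#   out = unet(x, timesteps, context=context, pose_emb=pose_emb)   # -> (2, 4, 16, 32, 32)

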
def spatial_forward_BasicTransformerBlock(self, x, context=None, mask=None):
    if isinstance(context, dict):
        context = context['context']
    x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, mask=mask) + x
    x = self.attn2(self.norm2(x), context=context, mask=mask) + x
    x = self.ff(self.norm3(x)) + x
    return x


def temporal_selfattn_forward_BasicTransformerBlock(self, x, context=None, mask=None):
    if isinstance(context, dict) and 'pose_emb' in context:
        pose_emb = context['pose_emb']  # [B, video_length, pose_dim, pose_embedding_dim]
        context = None
    else:
        pose_emb = None
        context = None

    x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None, mask=mask) + x

    # Add the camera pose embedding
    if pose_emb is not None:
        B, t, _, _ = pose_emb.shape  # [B, video_length, pose_dim, pose_embedding_dim]
        hw = x.shape[0] // B
        pose_emb = pose_emb.reshape(B, t, -1)
        pose_emb = pose_emb.repeat_interleave(repeats=hw, dim=0)
        x = self.cc_projection(torch.cat([x, pose_emb], dim=-1))

    x = self.attn2(self.norm2(x), context=context, mask=mask) + x
    x = self.ff(self.norm3(x)) + x
    return x


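# Data-flow note (a reading of the code above, with the layer wiring assumed rather than
# shown here): pose_emb travels from selfattn_forward_unet through the `context` dict,
# is unpacked by the temporal block, flattened to (B, t, pose_dim * pose_embedding_dim),
# broadcast over the h*w spatial positions, concatenated to the per-frame features, and
# projected back to the feature width by self.cc_projection. cc_projection is therefore
# expected to map (c + pose_dim * pose_embedding_dim) -> c; its definition lives outside
# this file.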