# Spaces: Running on Zero
# Copyright (c) Alibaba, Inc. and its affiliates. | |
import math | |
import os | |
from abc import abstractmethod | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import xformers | |
import xformers.ops | |
from einops import rearrange | |
from fairscale.nn.checkpoint import checkpoint_wrapper | |
from timm.models.vision_transformer import Mlp | |
# Module-level feature flag. It is not read anywhere in the visible part of
# this file.
# NOTE(review): presumably consulted by model-building code elsewhere to pick
# TemporalTransformer over the older temporal attention blocks — confirm.
USE_TEMPORAL_TRANSFORMER = True
class CaptionEmbedder(nn.Module):
    """Project caption embeddings to the model width, with optional caption dropout.

    A caption can be replaced by the ``y_embedding`` buffer (the "unconditional"
    caption), which is what enables classifier-free guidance.

    Args:
        in_channels: Feature size of the incoming caption tokens.
        hidden_size: Hidden and output feature size of the projection MLP.
        uncond_prob: Per-sample probability of replacing the caption in training.
        act_layer: Activation handed to ``Mlp``.
            NOTE(review): the default is a module *instance*; confirm whether
            ``Mlp`` expects an instance or a factory.
        token_num: Number of tokens stored in the ``y_embedding`` buffer.
    """

    def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate="tanh"), token_num=120):
        super().__init__()
        self.y_proj = Mlp(
            in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, drop=0
        )
        # Stored as a buffer, so it is saved with the module but not returned
        # by ``parameters()``.
        self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels**0.5))
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None):
        """Replace selected captions with ``y_embedding``.

        Args:
            caption: 4-D tensor whose first axis is the batch and whose last two
                axes broadcast against ``y_embedding``.
            force_drop_ids: Optional per-sample tensor; entries equal to 1 are
                dropped. When ``None`` samples are dropped at random with
                probability ``uncond_prob``.

        Returns:
            The caption tensor with dropped samples replaced.
        """
        if force_drop_ids is None:
            # Draw on the CPU (keeps the RNG stream unchanged) and move to the
            # caption's device instead of assuming CUDA is available.
            drop_ids = torch.rand(caption.shape[0]).to(caption.device) < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
        return caption

    def forward(self, caption, train, force_drop_ids=None):
        """Optionally drop captions, then project them with ``y_proj``.

        Args:
            caption: Caption tensor (see :meth:`token_drop`).
            train: Whether the call happens during training; enables the shape
                check and random dropout.
            force_drop_ids: Forwarded to :meth:`token_drop`; forces dropout even
                when ``train`` is false.
        """
        if train:
            assert caption.shape[2:] == self.y_embedding.shape
        use_dropout = self.uncond_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            caption = self.token_drop(caption, force_drop_ids)
        caption = self.y_proj(caption)
        return caption
class DropPath(nn.Module):
    r"""Zero out whole batch samples at random, without rescaling the survivors.

    ``zero`` marks samples that are always zeroed and ``keep`` marks samples
    that are never zeroed; both are optional boolean masks over the batch.
    """

    def __init__(self, p):
        super().__init__()
        self.p = p

    def forward(self, *args, zero=None, keep=None):
        """Apply the same per-sample drop decision to every tensor in ``args``.

        Returns a single tensor when one positional tensor was given, otherwise
        a tuple. In eval mode the inputs are returned untouched.
        """
        single = len(args) == 1
        if not self.training:
            return args[0] if single else args

        ref = args[0]
        batch = ref.size(0)
        # How many samples to drop: one Bernoulli(p) draw per batch element.
        num_drop = (torch.rand(batch) < self.p).sum()

        # Samples eligible for random dropping (neither forced-keep nor forced-zero).
        candidates = ref.new_ones(batch, dtype=torch.bool)
        for excluded in (keep, zero):
            if excluded is not None:
                candidates[excluded] = False

        pool = torch.where(candidates)[0]
        dropped = pool[torch.randperm(len(pool))[:num_drop]]
        if zero is not None:
            dropped = torch.cat([dropped, torch.where(zero)[0]], dim=0)

        # Per-sample multiplier: 0 for dropped samples, 1 otherwise.
        scale = ref.new_ones(batch)
        scale[dropped] = 0.0
        scaled = tuple(t * self.broadcast(scale, t) for t in args)
        return scaled[0] if single else scaled

    def broadcast(self, src, dst):
        """Reshape the per-sample vector ``src`` so it broadcasts against ``dst``."""
        assert src.size(0) == dst.size(0)
        trailing = dst.ndim - 1
        return src.view((dst.size(0), ) + (1, ) * trailing)
def sinusoidal_embedding(timesteps, dim):
    """Return sinusoidal embeddings for a 1-D tensor of timesteps.

    The first ``dim // 2`` columns are cosines and the next ``dim // 2`` are
    sines; for odd ``dim`` a trailing column of zeros is appended.
    """
    half = dim // 2
    timesteps = timesteps.float()

    # Frequencies 10000^(-i / half) for i in [0, half).
    exponents = -torch.arange(half).to(timesteps).div(half)
    freqs = torch.pow(10000, exponents)
    angles = torch.outer(timesteps, freqs)

    emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)
    if dim % 2 != 0:
        pad = torch.zeros_like(emb[:, :1])
        emb = torch.cat([emb, pad], dim=1)
    return emb
def exists(x):
    """Return ``True`` unless ``x`` is ``None``."""
    return not (x is None)
def default(val, d):
    """Return ``val`` if it is not ``None``, otherwise the fallback ``d``.

    A callable ``d`` is invoked with no arguments and its result is returned.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val
def prob_mask_like(shape, prob, device):
    """Return a boolean mask of ``shape`` whose entries are True with probability ``prob``.

    ``prob == 1`` and ``prob == 0`` give all-True / all-False masks. For other
    values the mask is sampled, and is never allowed to be all-True.
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)

    sampled = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob
    # Avoid masking everything, which would cause a find_unused_parameters error.
    if sampled.all():
        sampled[0] = False
    return sampled
class MemoryEfficientCrossAttention(nn.Module):
    """Multi-head (cross-)attention computed with xformers' memory-efficient kernel.

    Args:
        query_dim: Feature size of the query input (and of the output).
        context_dim: Feature size of the context; defaults to ``query_dim``.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        max_bs: Largest ``batch * heads`` passed to the kernel in one call;
            larger inputs are processed in slices along that axis.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 max_bs=16384,
                 dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.max_bs = max_bs
        self.heads = heads
        self.dim_head = dim_head
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        # ``None`` lets xformers choose the attention operator.
        self.attention_op = None

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Raises:
            NotImplementedError: if ``mask`` is provided.
        """
        # Reject unsupported input before doing any work.
        if exists(mask):
            raise NotImplementedError

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        b, _, _ = q.shape

        def split_heads(t):
            # (b, n, heads * dim_head) -> (b * heads, n, dim_head)
            return t.reshape(b, t.shape[1], self.heads, self.dim_head).permute(
                0, 2, 1, 3).reshape(b * self.heads, t.shape[1],
                                    self.dim_head).contiguous()

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        if q.shape[0] > self.max_bs:
            # Slice so that no piece exceeds max_bs. A floor-divided chunk count
            # left e.g. 1.9 * max_bs rows in a single, unsplit chunk.
            pieces = zip(
                torch.split(q, self.max_bs, dim=0),
                torch.split(k, self.max_bs, dim=0),
                torch.split(v, self.max_bs, dim=0))
            out = torch.cat([
                xformers.ops.memory_efficient_attention(
                    q_1, k_1, v_1, attn_bias=None, op=self.attention_op)
                for q_1, k_1, v_1 in pieces
            ], dim=0)
        else:
            out = xformers.ops.memory_efficient_attention(
                q, k, v, attn_bias=None, op=self.attention_op)

        # (b * heads, n, dim_head) -> (b, n, heads * dim_head)
        out = out.reshape(b, self.heads, out.shape[1], self.dim_head).permute(
            0, 2, 1, 3).reshape(b, out.shape[1], self.heads * self.dim_head)
        return self.to_out(out)
class RelativePositionBias(nn.Module):
    """Learned per-head attention bias indexed by bucketed relative position.

    Args:
        heads: Number of attention heads (size of each bias embedding).
        num_buckets: Total number of relative-position buckets.
        max_distance: Distance at and beyond which positions share the last bucket.
    """

    def __init__(self, heads=8, num_buckets=32, max_distance=128):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    # The function takes no ``self``; without ``@staticmethod`` the instance
    # call in ``forward`` raised a TypeError.
    @staticmethod
    def _relative_position_bucket(relative_position,
                                  num_buckets=32,
                                  max_distance=128):
        """Map a tensor of signed relative positions to bucket indices.

        Half of the buckets serve each sign. Within a half, small distances get
        one bucket each and larger ones are binned logarithmically up to
        ``max_distance``.
        """
        ret = 0
        n = -relative_position
        num_buckets //= 2
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)
        max_exact = num_buckets // 2
        is_small = n < max_exact
        val_if_large = max_exact + (
            torch.log(n.float() / max_exact)
            / math.log(max_distance / max_exact) *  # noqa
            (num_buckets - max_exact)).long()
        val_if_large = torch.min(
            val_if_large, torch.full_like(val_if_large, num_buckets - 1))
        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, n, device):
        """Return the bias for a length-``n`` sequence, shaped ``(heads, n, n)``."""
        pos = torch.arange(n, dtype=torch.long, device=device)
        # rel_pos[i, j] = key position j - query position i
        rel_pos = pos[None, :] - pos[:, None]
        rp_bucket = self._relative_position_bucket(
            rel_pos,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        # (i, j, heads) -> (heads, i, j)
        return values.permute(2, 0, 1)
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs

    Args:
        in_channels: Channels of the input feature map (must suit GroupNorm(32)).
        n_heads: Number of attention heads.
        d_head: Feature size per head; the inner width is ``n_heads * d_head``.
        depth: Number of transformer blocks.
        dropout: Dropout probability inside the blocks.
        context_dim: Context feature size, or a list with one entry per block.
            ``None`` means no external context.
        disable_self_attn: Forwarded to each block.
        use_linear: Use ``nn.Linear`` instead of 1x1 convolutions for the
            input/output projections.
        use_checkpoint: Forwarded to each block as ``checkpoint``.
        is_ctrl: Forwarded to each block.
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 disable_self_attn=False,
                 use_linear=False,
                 use_checkpoint=True,
                 is_ctrl=False):
        super().__init__()
        # Always index a list below; wrapping ``None`` too avoids a TypeError
        # on ``context_dim[d]`` when no context size is given.
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                disable_self_attn=disable_self_attn,
                checkpoint=use_checkpoint,
                local_type='space',
                is_ctrl=is_ctrl) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(
                    inner_dim, in_channels, kernel_size=1, stride=1,
                    padding=0))
        else:
            # Maps inner_dim back to in_channels (the mirror of proj_in); the
            # arguments used to be swapped, which only worked when both sizes
            # were equal.
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        """Run the blocks over the flattened spatial grid and add the residual.

        Args:
            x: Feature map shaped ``(b, c, h, w)``.
            context: Context tensor, or a list with one entry per block.
        """
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        _, _, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i], h=h, w=w)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
# Precision used for the attention logits in CrossAttention; read once at import
# time from the ATTN_PRECISION environment variable (default 'fp32').
_ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')


class CrossAttention(nn.Module):
    """Plain softmax multi-head (cross-)attention.

    Args:
        query_dim: Feature size of the query input (and of the output).
        context_dim: Feature size of the context; defaults to ``query_dim``.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head**-0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Args:
            x: Query input shaped ``(b, n, query_dim)``.
            context: Optional key/value input shaped ``(b, m, context_dim)``.
            mask: Optional boolean mask with a leading batch axis; it is
                flattened to ``(b, m)`` and positions that are False are
                excluded from attention.
        """
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                      (q, k, v))
        # force cast to fp32 to avoid overflowing
        if _ATTN_PRECISION == 'fp32':
            with torch.autocast(enabled=False, device_type='cuda'):
                q, k = q.float(), k.float()
                sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        else:
            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        del q, k
        if exists(mask):
            # Pure-torch equivalent of the einops calls
            #   rearrange(mask, 'b ... -> b (...)') and
            #   repeat(mask, 'b j -> (b h) () j', h=h);
            # ``repeat`` was never imported, so this branch raised NameError.
            mask = mask.reshape(mask.shape[0], -1)
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = mask.repeat_interleave(h, dim=0).unsqueeze(1)
            sim.masked_fill_(~mask, max_neg_value)
        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)
        out = torch.einsum('b i j, b j d -> b i d', sim, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)
class SpatialAttention(nn.Module):
    """Gate a feature map with a sigmoid map built from channel-wise max and mean."""

    def __init__(self):
        super().__init__()
        # Two input planes (max, mean) -> one gating plane; 7x7 kernel, same size.
        self.conv1 = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by the gate; the channel axis is dim 1."""
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)
        stacked = torch.cat([channel_max, channel_mean], dim=1)
        gate = self.sigmoid(self.conv1(stacked))
        return gate * x
class TemporalLocalAttention(nn.Module):  # b c t h w
    """Gate the input with a sigmoid weight built from last-axis max and mean.

    ``dim`` and ``kernel_size`` are accepted but not used by this module.
    """

    def __init__(self, dim, kernel_size=7):
        super().__init__()
        # Mixes the (max, mean) pair into a single gating value.
        self.conv1 = nn.Linear(2, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by a gate that is shared across the last axis."""
        last_max = x.max(dim=-1, keepdim=True).values
        last_mean = x.mean(dim=-1, keepdim=True)
        stats = torch.cat([last_max, last_mean], dim=-1)
        gate = self.sigmoid(self.conv1(stats))
        return gate * x
class BasicTransformerBlock(nn.Module):
    """Attention + feed-forward block with an optional local-attention gate.

    Args:
        dim: Feature size of the token sequence.
        n_heads: Number of attention heads.
        d_head: Feature size per head.
        dropout: Dropout probability for attention outputs and feed-forward.
        context_dim: Context feature size for ``attn2`` (and for ``attn1`` when
            ``disable_self_attn`` is set).
        gated_ff: Use the GEGLU variant of the feed-forward layer.
        checkpoint: Whether :meth:`forward_` uses activation checkpointing.
        disable_self_attn: Make ``attn1`` attend to the context instead of ``x``.
        local_type: ``'space'`` or ``'temp'``; selects the local gate.
        is_ctrl: Enables the local-gate code path.

    NOTE(review): ``forward`` only transforms ``x`` when ``is_ctrl`` is true and
    ``local_type`` is ``'space'`` or ``'temp'``; otherwise the input is returned
    unchanged. Confirm that this is intended.
    """

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 dropout=0.,
                 context_dim=None,
                 gated_ff=True,
                 checkpoint=True,
                 disable_self_attn=False,
                 local_type=None,
                 is_ctrl=False):
        super().__init__()
        self.local_type = local_type
        self.is_ctrl = is_ctrl
        attn_cls = MemoryEfficientCrossAttention
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(  # self-attn
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        attn_cls2 = MemoryEfficientCrossAttention
        self.attn2 = attn_cls2(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint
        if self.local_type == 'space' and self.is_ctrl:
            self.local1 = SpatialAttention()
        if self.local_type == 'temp' and self.is_ctrl:
            self.local1 = TemporalLocalAttention(dim=dim)
            self.local2 = TemporalLocalAttention(dim=dim)

    def forward_(self, x, context=None, h=None, w=None):
        """Run :meth:`forward`, with activation checkpointing when enabled.

        The previous body called an undefined ``checkpoint`` helper and a
        non-existent ``self._forward``, so it could never run.
        """
        if not self.checkpoint:
            return self.forward(x, context=context, h=h, w=w)
        # Imported locally so the module-level imports stay untouched.
        from torch.utils.checkpoint import checkpoint
        return checkpoint(self.forward, x, context, h, w, use_reentrant=False)

    def forward(self, x, context=None, h=None, w=None):
        """Apply the block.

        Args:
            x: Token sequence shaped ``(batch, tokens, dim)``.
            context: Optional context for cross-attention.
            h: Grid height; required for the ``'space'`` path.
            w: Grid width (not used by the visible code paths).
        """
        attn1_context = context if self.disable_self_attn else None

        if self.local_type == 'space' and self.is_ctrl:  # [b*t,(hw), c]
            # Gate on the 2-D grid, then feed the gated tokens to attn1 only.
            x_local = rearrange(x, 'b (h w) c -> b c h w', h=h)
            x_local = self.local1(x_local)
            x_local = rearrange(x_local, 'b c h w -> b (h w) c')
            x = self.attn1(self.norm1(x_local), context=attn1_context) + x
            x = self.attn2(self.norm2(x), context=context) + x  # cross attention or self-attention
            x = self.ff(self.norm3(x)) + x

        if self.local_type == 'temp' and self.is_ctrl:
            # Separate gates feed attn1 and attn2.
            x_local = self.local1(x)
            x = self.attn1(self.norm1(x_local), context=attn1_context) + x
            x_local = self.local2(x)
            x = self.attn2(self.norm2(x_local), context=context) + x
            x = self.ff(self.norm3(x)) + x

        return x
# feedforward | |
class GEGLU(nn.Module):
    """GELU-gated linear unit: one projection split into a value and a gate half."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Twice the output width: first half is the value, second half the gate.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = torch.chunk(projected, 2, dim=-1)
        return value * F.gelu(gate)
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and hand the module back.
    """
    for param in module.parameters():
        # Work on a detached alias so the in-place write is not tracked by autograd.
        detached = param.detach()
        detached.zero_()
    return module
class FeedForward(nn.Module):
    """Two-layer feed-forward network with GELU or GEGLU input projection.

    Args:
        dim: Input feature size.
        dim_out: Output feature size; defaults to ``dim``.
        mult: Hidden width multiplier (hidden size is ``int(dim * mult)``).
        glu: Use :class:`GEGLU` instead of ``Linear + GELU`` for the first layer.
        dropout: Dropout probability between the two layers.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        dim_out = default(dim_out, dim)
        if glu:
            project_in = GEGLU(dim, hidden)
        else:
            project_in = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())
        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(hidden, dim_out),
        )

    def forward(self, x):
        return self.net(x)
class Upsample(nn.Module):
    """
    Nearest-neighbour 2x upsampling, optionally followed by a 3x3 convolution.
    :param channels: channels expected in the input.
    :param use_conv: whether to apply the convolution after upsampling.
    :param dims: 3 upsamples only the last two axes of a 5-D input; any other
                 value upsamples with ``scale_factor=2`` and then drops the
                 first and last row of the second-to-last axis.
    :param out_channels: channels produced by the convolution (defaults to
                 ``channels``).
    :param padding: padding of the convolution.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = nn.Conv2d(
                self.channels, self.out_channels, 3, padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims != 3:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
            # Trim one row from each end of the second-to-last axis.
            x = x[..., 1:-1, :]
        else:
            target = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target, mode='nearest')
        return self.conv(x) if self.use_conv else x
class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: if True, the embedding is split into a scale
        and a shift applied after the output normalisation; otherwise it is
        added to the features.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param use_temporal_conv: if True, append a temporal convolution block.
    :param use_image_dataset: forwarded to the temporal convolution block.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        use_temporal_conv=True,
        use_image_dataset=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.use_temporal_conv = use_temporal_conv
        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            nn.Conv2d(channels, self.out_channels, 3, padding=1),
        )
        self.updown = up or down
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels
                if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
        )
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            # NOTE(review): ``conv_nd`` is not defined in the visible part of
            # this file — confirm it exists at module level.
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
        if self.use_temporal_conv:
            # The attribute name (including its spelling) is part of the
            # state-dict keys and must not be changed.
            self.temopral_conv = TemporalConvBlock_v2(
                self.out_channels,
                self.out_channels,
                dropout=0.1,
                use_image_dataset=use_image_dataset)

    def forward(self, x, emb, batch_size, variant_info=None):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :param batch_size: number of videos in the batch; used to unfold the
            leading axis into (batch, frames) for the temporal convolution.
        :param variant_info: forwarded to the temporal convolution block.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return self._forward(x, emb, batch_size, variant_info)

    def _forward(self, x, emb, batch_size, variant_info):
        if self.updown:
            # Resample between the norm/activation and the convolution, and
            # resample the skip path the same way.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append singleton axes until the embedding broadcasts against h.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            # Was ``th.chunk``; ``th`` is not a defined alias in this file.
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        h = self.skip_connection(x) + h
        if self.use_temporal_conv:
            h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size)
            h = self.temopral_conv(h, variant_info=variant_info)
            h = rearrange(h, 'b c f h w -> (b f) c h w')
        return h
class Downsample(nn.Module):
    """
    Stride-2 downsampling by convolution or by average pooling.
    :param channels: channels expected in the input.
    :param use_conv: if True use a 3x3 strided convolution, otherwise average
                 pooling built by ``avg_pool_nd``.
    :param dims: 3 selects stride ``(1, 2, 2)``; any other value selects stride 2.
    :param out_channels: channels produced by the convolution (defaults to
                 ``channels``; must equal ``channels`` when pooling).
    :param padding: padding of the convolution.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=(2, 1)):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = nn.Conv2d(
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
class Resample(nn.Module):
    """Resize the last two axes: identity, nearest upsample, or 2x average-pool.

    ``in_dim`` and ``out_dim`` are stored but not used by ``forward``.
    """

    def __init__(self, in_dim, out_dim, mode):
        assert mode in ['none', 'upsample', 'downsample']
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.mode = mode

    def forward(self, x, reference=None):
        """Resample ``x``; ``reference`` supplies the target size when upsampling."""
        if self.mode == 'downsample':
            halved = tuple(side // 2 for side in x.shape[-2:])
            return F.adaptive_avg_pool2d(x, output_size=halved)
        if self.mode == 'upsample':
            assert reference is not None
            return F.interpolate(x, size=reference.shape[-2:], mode='nearest')
        return x
class ResidualBlock(nn.Module):
    """Embedding-conditioned residual block with optional resampling.

    Args:
        in_dim: Input channels.
        embed_dim: Feature size of the conditioning embedding.
        out_dim: Output channels.
        use_scale_shift_norm: Split the projected embedding into scale/shift
            applied after the second normalisation instead of adding it.
        mode: Resampling mode passed to :class:`Resample`.
        dropout: Dropout probability in the second layer.
    """

    def __init__(self,
                 in_dim,
                 embed_dim,
                 out_dim,
                 use_scale_shift_norm=True,
                 mode='none',
                 dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.use_scale_shift_norm = use_scale_shift_norm
        self.mode = mode

        # Scale and shift each need out_dim values.
        embed_out = out_dim * 2 if use_scale_shift_norm else out_dim

        self.layer1 = nn.Sequential(
            nn.GroupNorm(32, in_dim),
            nn.SiLU(),
            nn.Conv2d(in_dim, out_dim, 3, padding=1),
        )
        self.resample = Resample(in_dim, in_dim, mode)
        self.embedding = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, embed_out),
        )
        self.layer2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_dim, out_dim, 3, padding=1),
        )
        if in_dim == out_dim:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_dim, out_dim, 1)

        # Start with a zero weight in the final convolution.
        nn.init.zeros_(self.layer2[-1].weight)

    def forward(self, x, e, reference=None):
        """Apply the block to ``x`` conditioned on embedding ``e``."""
        skip = self.resample(x, reference)

        # Resample between norm/activation and the first convolution.
        pre_act, conv_in = self.layer1[:-1], self.layer1[-1]
        h = conv_in(self.resample(pre_act(x), reference))

        e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(h.dtype)
        if not self.use_scale_shift_norm:
            h = self.layer2(h + e)
        else:
            scale, shift = e.chunk(2, dim=1)
            h = self.layer2[0](h) * (1 + scale) + shift
            h = self.layer2[1:](h)
        return h + self.shortcut(skip)
class AttentionBlock(nn.Module):
    """Spatial self-attention over a feature map, optionally extended by context tokens.

    Args:
        dim: Channels of the feature map.
        context_dim: Feature size of optional context tokens.
        num_heads: Number of heads (used when ``head_dim`` is not given).
        head_dim: Channels per head; takes precedence over ``num_heads``.
    """

    def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None):
        # head_dim wins over num_heads when both are supplied
        if head_dim:
            num_heads = dim // head_dim
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super().__init__()
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Applied to both q and k, so their product is scaled by head_dim ** -0.5.
        self.scale = math.pow(head_dim, -0.25)

        self.norm = nn.GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        if context_dim is not None:
            self.context_kv = nn.Linear(context_dim, dim * 2)
        self.proj = nn.Conv2d(dim, dim, 1)

        # Start with a zero weight in the output projection.
        nn.init.zeros_(self.proj.weight)

    def forward(self, x, context=None):
        r"""x:       [B, C, H, W].
            context: [B, L, C] or None.
        """
        residual = x
        b, c, h, w = x.size()
        n, d = self.num_heads, self.head_dim

        # q, k, v: [B, n, d, H*W]
        qkv = self.to_qkv(self.norm(x)).view(b, n * 3, d, h * w)
        q, k, v = qkv.chunk(3, dim=1)

        if context is not None:
            # Context keys/values are prepended along the token axis.
            ctx = self.context_kv(context).reshape(b, -1, n * 2, d)
            ck, cv = ctx.permute(0, 2, 3, 1).chunk(2, dim=1)
            k = torch.cat([ck, k], dim=-1)
            v = torch.cat([cv, v], dim=-1)

        scores = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale)
        weights = F.softmax(scores, dim=-1)

        mixed = torch.matmul(v, weights.transpose(-1, -2))
        mixed = mixed.reshape(b, c, h, w)
        return self.proj(mixed) + residual
class TemporalAttentionBlock(nn.Module):
    """Self-attention along the frame axis of a ``b c f h w`` tensor.

    Args:
        dim: Channels of the input.
        heads: Number of attention heads.
        dim_head: Ignored; the per-head size is always ``dim // heads``.
        rotary_emb: Optional object providing ``rotate_queries_or_keys``.
        use_image_dataset: Multiply the attention output by zero before the
            residual addition.
        use_sim_mask: Restrict each frame to attend to itself and earlier frames.
    """

    def __init__(self,
                 dim,
                 heads=4,
                 dim_head=32,
                 rotary_emb=None,
                 use_image_dataset=False,
                 use_sim_mask=False):
        super().__init__()
        # num_heads is authoritative because pos_bias needs a fixed head count
        dim_head = dim // heads
        assert heads * dim_head == dim
        self.use_image_dataset = use_image_dataset
        self.use_sim_mask = use_sim_mask

        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = nn.GroupNorm(32, dim)
        self.rotary_emb = rotary_emb
        self.to_qkv = nn.Linear(dim, hidden_dim * 3)
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self,
                x,
                pos_bias=None,
                focus_present_mask=None,
                video_mask=None):
        """Attend across frames and add the result to ``x``.

        Args:
            x: Input shaped ``(b, c, f, h, w)``.
            pos_bias: Optional bias added to the attention logits.
            focus_present_mask: Optional per-sample boolean tensor; samples
                marked True attend only to their own frame.
            video_mask: Optional ``[B, n]`` boolean mask of valid frames; only
                used when ``focus_present_mask`` is ``None``.
        """
        residual = x
        num_frames, height, device = x.shape[2], x.shape[-2], x.device

        tokens = rearrange(self.norm(x), 'b c f h w -> b (h w) f c')
        qkv = self.to_qkv(tokens).chunk(3, dim=-1)

        if exists(focus_present_mask) and focus_present_mask.all():
            # Every sample attends only to its own frame, so attention reduces
            # to passing the values (qkv[-1]) straight to the output projection.
            out = self.to_out(qkv[-1])
            out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)
            return out + residual

        # split out heads
        q, k, v = (
            rearrange(t, '... n (h d) -> ... h n d', h=self.heads) for t in qkv)
        q = q * self.scale

        # rotate positions into queries and keys for time attention
        if exists(self.rotary_emb):
            q = self.rotary_emb.rotate_queries_or_keys(q)
            k = self.rotary_emb.rotate_queries_or_keys(k)

        # logits shaped [b (hw) h n n], n = frames
        sim = torch.einsum('... h i d, ... h j d -> ... h i j', q, k)
        if exists(pos_bias):
            sim = sim + pos_bias

        if focus_present_mask is None and video_mask is not None:
            # video_mask: [B, n] -> pairwise validity [B, 1, 1, n, n]
            pair_mask = video_mask[:, None, :] * video_mask[:, :, None]
            pair_mask = pair_mask.unsqueeze(1).unsqueeze(1)
            sim = sim.masked_fill(~pair_mask, -torch.finfo(sim.dtype).max)
        elif exists(focus_present_mask) and not (~focus_present_mask).all():
            everything = torch.ones(
                (num_frames, num_frames), device=device, dtype=torch.bool)
            only_self = torch.eye(num_frames, device=device, dtype=torch.bool)
            pair_mask = torch.where(
                rearrange(focus_present_mask, 'b -> b 1 1 1 1'),
                rearrange(only_self, 'i j -> 1 1 1 i j'),
                rearrange(everything, 'i j -> 1 1 1 i j'),
            )
            sim = sim.masked_fill(~pair_mask, -torch.finfo(sim.dtype).max)

        if self.use_sim_mask:
            causal = torch.tril(
                torch.ones(
                    (num_frames, num_frames), device=device, dtype=torch.bool),
                diagonal=0)
            sim = sim.masked_fill(~causal, -torch.finfo(sim.dtype).max)

        # numerical stability
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate values and merge heads
        out = torch.einsum('... h i j, ... h j d -> ... h i d', attn, v)
        out = rearrange(out, '... h n d -> ... n (h d)')
        out = self.to_out(out)
        out = rearrange(out, 'b (h w) f c -> b c f h w', h=height)

        if self.use_image_dataset:
            return residual + 0 * out
        return residual + out
class TemporalTransformer(nn.Module):
    """
    Transformer block applied along the frame axis of video features.
    First, project the input (aka embedding)
    and reshape to (b h w), f, c.
    Then apply standard transformer action.
    Finally, reshape back to b, c, f, h, w and add the residual.

    Args:
        in_channels: Channels of the input (must suit GroupNorm(32)).
        n_heads: Number of attention heads.
        d_head: Feature size per head; the inner width is ``n_heads * d_head``.
        depth: Number of transformer blocks.
        dropout: Dropout probability inside the blocks.
        context_dim: Context feature size, or a list with one entry per block.
            Ignored when ``only_self_att`` is true.
        disable_self_attn: Accepted but not forwarded to the blocks.
        use_linear: Use ``nn.Linear`` instead of 1x1 ``Conv1d`` projections.
        use_checkpoint: Forwarded to each block as ``checkpoint``.
        only_self_att: Ignore any context and run self-attention only.
        multiply_zero: Multiply the block output by zero before the residual.
        is_ctrl: Forwarded to each block.
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 disable_self_attn=False,
                 use_linear=False,
                 use_checkpoint=True,
                 only_self_att=True,
                 multiply_zero=False,
                 is_ctrl=False):
        super().__init__()
        self.multiply_zero = multiply_zero
        self.only_self_att = only_self_att
        # Always False. The adaptor layers that used to be guarded by this flag
        # referenced an undefined ``frames`` name and have been removed.
        self.use_adaptor = False
        if self.only_self_att:
            context_dim = None
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        if not use_linear:
            self.proj_in = nn.Conv1d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim[d],
                checkpoint=use_checkpoint,
                local_type='temp',
                is_ctrl=is_ctrl) for d in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv1d(
                    inner_dim, in_channels, kernel_size=1, stride=1,
                    padding=0))
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        """Apply the temporal blocks.

        Args:
            x: Features shaped ``(b, c, f, h, w)``.
            context: Context tensor shaped ``((b f), l, con)``, or a list with
                one entry per block. Ignored when ``only_self_att`` is true.
        """
        # note: if no context is given, cross-attention defaults to self-attention
        if self.only_self_att:
            context = None
        if not isinstance(context, list):
            context = [context]
        # The frame count comes from the input; ``self.frames`` was never set.
        b, _, frames, h, w = x.shape
        x_in = x
        x = self.norm(x)

        # Both projection variants end with x shaped (b h w), c, f.
        if not self.use_linear:
            x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous()
            x = self.proj_in(x)
        if self.use_linear:
            x = rearrange(
                x, 'b c f h w -> (b h w) f c').contiguous()
            x = self.proj_in(x)
            x = rearrange(
                x, 'bhw f c -> bhw c f').contiguous()

        if self.only_self_att:  # no cross-attention
            x = rearrange(x, 'bhw c f -> bhw f c').contiguous()
            for i, block in enumerate(self.transformer_blocks):
                x = block(x, h=h, w=w)
            x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous()
        else:
            x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous()
            for i, block in enumerate(self.transformer_blocks):
                # Work on a local view so the caller's context list is not modified.
                block_context = rearrange(
                    context[i], '(b f) l con -> b f l con',
                    f=frames).contiguous()
                # calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
                for j in range(b):
                    # Pure-torch equivalent of
                    #   repeat(ctx, 'f l con -> (f r) l con', r=(h * w) // frames);
                    # ``repeat`` was never imported.
                    context_i_j = block_context[j].repeat_interleave(
                        (h * w) // frames, dim=0).contiguous()
                    x[j] = block(x[j], context=context_i_j)

        if self.use_linear:
            x = rearrange(x, 'b hw f c -> (b hw) f c').contiguous()
            x = self.proj_out(x)
            x = rearrange(
                x, '(b h w) f c -> b c f h w', b=b, h=h, w=w).contiguous()
        if not self.use_linear:
            x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous()
            x = self.proj_out(x)
            x = rearrange(
                x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous()

        if self.multiply_zero:
            x = 0.0 * x + x_in
        else:
            x = x + x_in
        return x
class TemporalAttentionMultiBlock(nn.Module):
    """Chain of ``temporal_attn_times`` ``TemporalAttentionBlock`` layers applied in sequence."""

    def __init__(
        self,
        dim,
        heads=4,
        dim_head=32,
        rotary_emb=None,
        use_image_dataset=False,
        use_sim_mask=False,
        temporal_attn_times=1,
    ):
        super().__init__()
        # Every layer is built with the same configuration.
        layers = []
        for _ in range(temporal_attn_times):
            layers.append(
                TemporalAttentionBlock(dim, heads, dim_head, rotary_emb,
                                       use_image_dataset, use_sim_mask))
        self.att_layers = nn.ModuleList(layers)

    def forward(self,
                x,
                pos_bias=None,
                focus_present_mask=None,
                video_mask=None):
        """Feed ``x`` through each attention layer in order and return the result."""
        out = x
        for attn in self.att_layers:
            out = attn(out, pos_bias, focus_present_mask, video_mask)
        return out
class InitTemporalConvBlock(nn.Module):
    """Residual temporal convolution (kernel ``(3, 1, 1)``) that starts out as the identity.

    The GroupNorm is sized with ``out_dim`` and applied directly to the
    input, and the convolution maps ``out_dim -> in_dim``, so the block is
    only shape-consistent when ``in_dim == out_dim`` (the default).
    """

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super().__init__()
        out_dim = in_dim if out_dim is None else out_dim
        self.in_dim, self.out_dim = in_dim, out_dim
        self.use_image_dataset = use_image_dataset

        # norm -> activation -> dropout -> temporal conv
        temporal_conv = nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))
        self.conv = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            temporal_conv,
        )

        # Zero weight and bias on the last layer: the residual branch outputs
        # zeros at initialisation, so the whole block is the identity.
        last = self.conv[-1]
        nn.init.zeros_(last.weight)
        nn.init.zeros_(last.bias)

    def forward(self, x):
        """Return ``x`` plus the conv branch (branch scaled by 0 when ``use_image_dataset``)."""
        branch = self.conv(x)
        if self.use_image_dataset:
            return x + 0 * branch
        return x + branch
class TemporalConvBlock(nn.Module):
    """Two-stage residual temporal convolution (``in_dim -> out_dim -> in_dim``).

    Both convolutions use a ``(3, 1, 1)`` kernel. The second one is
    zero-initialised, so the block is the identity right after construction.
    """

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super().__init__()
        out_dim = in_dim if out_dim is None else out_dim
        self.in_dim, self.out_dim = in_dim, out_dim
        self.use_image_dataset = use_image_dataset

        kernel, pad = (3, 1, 1), (1, 0, 0)
        # expand: norm -> activation -> conv
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim),
            nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, kernel, padding=pad),
        )
        # project back: norm -> activation -> dropout -> conv
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, kernel, padding=pad),
        )

        # Zero the last conv so the residual branch contributes nothing at start.
        final_conv = self.conv2[-1]
        nn.init.zeros_(final_conv.weight)
        nn.init.zeros_(final_conv.bias)

    def forward(self, x):
        """Return ``x`` plus the two-conv branch (branch scaled by 0 when ``use_image_dataset``)."""
        branch = self.conv2(self.conv1(x))
        if self.use_image_dataset:
            return x + 0 * branch
        return x + branch
class TemporalConvBlock_v2(nn.Module):
    """Four-stage residual temporal convolution with optional chunked variants.

    All convolutions use a ``(3, 1, 1)`` kernel. ``conv4`` is
    zero-initialised, so the block is the identity right after construction.

    ``conv2``..``conv4`` each normalise ``out_dim`` channels and emit
    ``in_dim`` channels, so the chain is only shape-consistent when
    ``in_dim == out_dim`` (the default).

    ``forward`` accepts an optional ``variant_info`` dict:

    * ``{'type': 'variant2', 'alpha': a}`` -- blend the full-clip result
      with the result computed on 4 independent frame chunks:
      ``full * (1 - a) + chunked * a``.
    * ``{'type': 'variant1'}`` -- split the batch in two; the first half is
      processed as full clips, the second half as 4 frame chunks.
    * anything else (``None``, missing or unknown ``'type'``) -- the
      standard full-clip path.
    """

    def __init__(self,
                 in_dim,
                 out_dim=None,
                 dropout=0.0,
                 use_image_dataset=False):
        super(TemporalConvBlock_v2, self).__init__()
        if out_dim is None:
            out_dim = in_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.use_image_dataset = use_image_dataset

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_dim), nn.SiLU(),
            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)))

        # zero out the last layer params, so the conv block is identity
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def _conv_stack(self, x):
        """Apply conv1..conv4 in order."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return x

    def _conv_stack_chunked(self, x, n=4):
        """Apply the conv stack to ``n`` equal frame chunks folded into the batch axis."""
        x = rearrange(x, "b c (n s) h w -> (n b) c s h w", n=n)
        x = self._conv_stack(x)
        return rearrange(x, "(n b) c s h w -> b c (n s) h w", n=n)

    def forward(self, x, variant_info=None):
        """Return ``x`` plus the conv branch selected by ``variant_info``.

        Args:
            x: tensor unpacked as ``(b, c, f, h, w)`` in the variant paths.
            variant_info: optional dict, see the class docstring.

        Returns:
            ``x + branch``, or ``x + 0.0 * branch`` when
            ``self.use_image_dataset`` is set.
        """
        # Bound up-front: previously an unrecognised variant type skipped
        # every branch and failed on the unbound `identity` below.
        identity = x
        variant = variant_info.get('type') if variant_info is not None else None

        if variant == 'variant2':
            _, _, f, _, _ = x.shape
            assert f % 4 == 0, "f must be divisible by 4"
            # Chunked pass first, then the full pass (keeps dropout RNG order).
            x_short = self._conv_stack_chunked(x)
            x = self._conv_stack(x)
            x = x * (1 - variant_info['alpha']) + x_short * variant_info['alpha']
        elif variant == 'variant1':
            x_long, x_short = x.chunk(2, dim=0)
            x_short = self._conv_stack_chunked(x_short)
            x_long = self._conv_stack(x_long)
            x = torch.cat([x_long, x_short], dim=0)
        else:
            # No variant requested, or an unknown/missing 'type'.
            x = self._conv_stack(x)

        if self.use_image_dataset:
            x = identity + 0.0 * x
        else:
            x = identity + x
        return x
class Vid2VidSDUNet(nn.Module):
    """U-Net that interleaves 2-D residual/spatial layers with temporal layers.

    ``forward`` takes a 5-D input (``b c f h w``), folds the frame axis into
    the batch axis for the 2-D layers and unfolds it again around every
    temporal layer (see ``_forward_single``).

    Module attribute names (``time_embed``, ``input_blocks``,
    ``middle_block``, ``output_blocks``, ``out``) and construction order are
    part of the checkpoint layout and must not change.

    NOTE(review): the ``USE_TEMPORAL_TRANSFORMER = False`` code paths read
    ``self.rotary_emb`` and ``self.time_rel_pos_bias``, neither of which is
    assigned in this class -- confirm before flipping that flag.
    """

    def __init__(self,
                 in_dim=4,
                 dim=320,
                 y_dim=1024,
                 context_dim=1024,
                 out_dim=4,
                 dim_mult=(1, 2, 4, 4),
                 num_heads=8,
                 head_dim=64,
                 num_res_blocks=2,
                 attn_scales=(1 / 1, 1 / 2, 1 / 4),
                 use_scale_shift_norm=True,
                 dropout=0.1,
                 temporal_attn_times=1,
                 temporal_attention=True,
                 use_checkpoint=True,
                 use_image_dataset=False,
                 use_fps_condition=False,
                 use_sim_mask=False,
                 training=False,
                 inpainting=True):
        embed_dim = dim * 4
        num_heads = num_heads if num_heads else dim // 32
        super(Vid2VidSDUNet, self).__init__()

        # Private list copies: the defaults are immutable tuples and a
        # caller-supplied sequence is never aliased by the module.
        dim_mult = list(dim_mult)
        attn_scales = list(attn_scales)

        self.in_dim = in_dim
        self.dim = dim
        self.y_dim = y_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.dim_mult = dim_mult
        # for temporal attention
        self.num_heads = num_heads
        # for spatial attention
        self.head_dim = head_dim
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.use_scale_shift_norm = use_scale_shift_norm
        self.temporal_attn_times = temporal_attn_times
        self.temporal_attention = temporal_attention
        self.use_checkpoint = use_checkpoint
        self.use_image_dataset = use_image_dataset
        self.use_fps_condition = use_fps_condition
        self.use_sim_mask = use_sim_mask
        self.training = training
        self.inpainting = inpainting

        # params
        enc_dims = [dim * u for u in [1] + dim_mult]
        dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        shortcut_dims = []
        scale = 1.0

        # embeddings
        self.time_embed = nn.Sequential(
            nn.Linear(dim, embed_dim), nn.SiLU(),
            nn.Linear(embed_dim, embed_dim))
        if self.use_fps_condition:
            self.fps_embedding = nn.Sequential(
                nn.Linear(dim, embed_dim), nn.SiLU(),
                nn.Linear(embed_dim, embed_dim))
            nn.init.zeros_(self.fps_embedding[-1].weight)
            nn.init.zeros_(self.fps_embedding[-1].bias)

        # encoder
        self.input_blocks = nn.ModuleList()
        init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
        # need an initial temporal attention?
        if temporal_attention:
            init_block.append(
                self._make_temporal_layer(dim, num_heads, use_sim_mask=False))
        self.input_blocks.append(init_block)
        shortcut_dims.append(dim)

        for i, (ch_in, ch_out) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
            for j in range(num_res_blocks):
                block = nn.ModuleList([
                    ResBlock(
                        ch_in,
                        embed_dim,
                        dropout,
                        out_channels=ch_out,
                        use_scale_shift_norm=False,
                        use_image_dataset=use_image_dataset,
                    )
                ])
                if scale in attn_scales:
                    block.append(self._make_spatial_layer(ch_out))
                if self.temporal_attention:
                    block.append(
                        self._make_temporal_layer(
                            ch_out, ch_out // head_dim,
                            use_sim_mask=use_sim_mask))
                ch_in = ch_out
                self.input_blocks.append(block)
                shortcut_dims.append(ch_out)

                # downsample after the last res block of every level but the deepest
                if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
                    downsample = Downsample(
                        ch_out, True, dims=2, out_channels=ch_out)
                    shortcut_dims.append(ch_out)
                    scale /= 2.0
                    self.input_blocks.append(downsample)

        # middle
        self.middle_block = nn.ModuleList([
            ResBlock(
                ch_out,
                embed_dim,
                dropout,
                use_scale_shift_norm=False,
                use_image_dataset=use_image_dataset,
            ),
            self._make_spatial_layer(ch_out),
        ])
        if self.temporal_attention:
            self.middle_block.append(
                self._make_temporal_layer(
                    ch_out, ch_out // head_dim, use_sim_mask=use_sim_mask))
        self.middle_block.append(
            ResBlock(ch_out, embed_dim, dropout, use_scale_shift_norm=False))

        # decoder
        self.output_blocks = nn.ModuleList()
        for i, (ch_in, ch_out) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
            for j in range(num_res_blocks + 1):
                block = nn.ModuleList([
                    ResBlock(
                        # skip connection is concatenated on the channel axis
                        ch_in + shortcut_dims.pop(),
                        embed_dim,
                        dropout,
                        ch_out,
                        use_scale_shift_norm=False,
                        use_image_dataset=use_image_dataset,
                    )
                ])
                if scale in attn_scales:
                    # Uses the configured context width; this layer used to
                    # hard-code 1024, which only matched the default.
                    block.append(self._make_spatial_layer(ch_out))
                if self.temporal_attention:
                    block.append(
                        self._make_temporal_layer(
                            ch_out, ch_out // head_dim,
                            use_sim_mask=use_sim_mask))
                ch_in = ch_out

                # upsample after the last res block of every level but the last
                if i != len(dim_mult) - 1 and j == num_res_blocks:
                    # NOTE(review): dims is the float 2.0 here but the int 2
                    # for Downsample -- looks like a typo; left untouched
                    # because Upsample is defined outside this block.
                    upsample = Upsample(
                        ch_out, True, dims=2.0, out_channels=ch_out)
                    scale *= 2.0
                    block.append(upsample)
                self.output_blocks.append(block)

        # head
        self.out = nn.Sequential(
            nn.GroupNorm(32, ch_out), nn.SiLU(),
            nn.Conv2d(ch_out, self.out_dim, 3, padding=1))
        # zero out the last layer params (weight only, bias is left as initialised)
        nn.init.zeros_(self.out[-1].weight)

    def _make_spatial_layer(self, channels):
        """Build the SpatialTransformer used at attention scales for ``channels`` features."""
        return SpatialTransformer(
            channels,
            channels // self.head_dim,
            self.head_dim,
            depth=1,
            context_dim=self.context_dim,
            disable_self_attn=False,
            use_linear=True,
            is_ctrl=True)

    def _make_temporal_layer(self, channels, heads, use_sim_mask):
        """Build the temporal layer selected by ``USE_TEMPORAL_TRANSFORMER``.

        ``heads`` is only used by the transformer variant; the legacy
        attention variant always uses ``self.num_heads``.
        """
        if USE_TEMPORAL_TRANSFORMER:
            return TemporalTransformer(
                channels,
                heads,
                self.head_dim,
                depth=1,
                context_dim=self.context_dim,
                disable_self_attn=False,
                use_linear=False,
                multiply_zero=self.use_image_dataset,
                is_ctrl=True)
        return TemporalAttentionMultiBlock(
            channels,
            self.num_heads,
            self.head_dim,
            rotary_emb=self.rotary_emb,
            use_image_dataset=self.use_image_dataset,
            use_sim_mask=use_sim_mask,
            temporal_attn_times=self.temporal_attn_times)

    def forward(self,
                x,
                t,
                y,
                x_lr=None,
                fps=None,
                video_mask=None,
                focus_present_mask=None,
                prob_focus_present=0.,
                mask_last_frame_num=0):
        """Run encoder, middle block and decoder on a video tensor.

        Args:
            x: input tensor, unpacked as ``(batch, c, f, h, w)``.
            t: timesteps, embedded with ``sinusoidal_embedding``.
            y: conditioning used as cross-attention context; repeated once
                per frame along dim 0.
            x_lr, fps: accepted but not used by this method.
            video_mask: forwarded to temporal attention layers; its last
                ``mask_last_frame_num`` entries are set to ``False`` in place
                when that count is positive (so it must not be ``None`` then).
            focus_present_mask: optional mask; sampled with
                ``prob_mask_like`` when omitted.
            prob_focus_present: probability used for that sampling.
            mask_last_frame_num: see ``video_mask``.

        Returns:
            Tensor rearranged back to ``b c f h w``.
        """
        batch, _, f, _, _ = x.shape
        device = x.device
        self.batch = batch

        # image and video joint training, if mask_last_frame_num is set,
        # prob_focus_present will be ignored
        if mask_last_frame_num > 0:
            focus_present_mask = None
            video_mask[-mask_last_frame_num:] = False
        else:
            focus_present_mask = default(
                focus_present_mask, lambda: prob_mask_like(
                    (batch, ), prob_focus_present, device=device))

        if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
            time_rel_pos_bias = self.time_rel_pos_bias(
                x.shape[2], device=x.device)
        else:
            time_rel_pos_bias = None

        # embeddings, repeated f times so they line up with the (b f) batch axis
        e = self.time_embed(sinusoidal_embedding(t, self.dim))
        e = e.repeat_interleave(repeats=f, dim=0)
        context = y.repeat_interleave(repeats=f, dim=0)

        # always in shape (b f) c h w, except for temporal layer
        x = rearrange(x, 'b c f h w -> (b f) c h w')

        # encoder
        xs = []
        for block in self.input_blocks:
            x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                     focus_present_mask, video_mask)
            xs.append(x)

        # middle
        for block in self.middle_block:
            x = self._forward_single(block, x, e, context, time_rel_pos_bias,
                                     focus_present_mask, video_mask)

        # decoder
        for block in self.output_blocks:
            x = torch.cat([x, xs.pop()], dim=1)
            x = self._forward_single(
                block,
                x,
                e,
                context,
                time_rel_pos_bias,
                focus_present_mask,
                video_mask,
                reference=xs[-1] if len(xs) > 0 else None)

        # head
        x = self.out(x)

        # reshape back to (b c f h w)
        x = rearrange(x, '(b f) c h w -> b c f h w', b=batch)
        return x

    def _maybe_checkpoint(self, module):
        """Wrap ``module`` with ``checkpoint_wrapper`` when ``use_checkpoint`` is set."""
        return checkpoint_wrapper(module) if self.use_checkpoint else module

    def _run_over_time(self, module, x, *args):
        """Call ``module`` on ``x`` unfolded to ``b c f h w`` and fold the result back."""
        module = self._maybe_checkpoint(module)
        x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch)
        x = module(x, *args)
        return rearrange(x, 'b c f h w -> (b f) c h w')

    def _forward_single(self,
                        module,
                        x,
                        e,
                        context,
                        time_rel_pos_bias,
                        focus_present_mask,
                        video_mask,
                        reference=None):
        """Dispatch ``x`` through ``module`` with the arguments its type expects.

        The branch order is the original one; only consecutive branches with
        identical bodies were merged, so dispatch is unchanged even if some
        of these classes subclass one another.
        """
        if isinstance(module, ResidualBlock):
            x = self._maybe_checkpoint(module)(x.contiguous(), e, reference)
        elif isinstance(module, ResBlock):
            x = self._maybe_checkpoint(module)(x.contiguous(), e, self.batch)
        elif isinstance(module, SpatialTransformer):
            x = self._maybe_checkpoint(module)(x, context)
        elif isinstance(module, TemporalTransformer):
            x = self._run_over_time(module, x, context)
        elif isinstance(module, (CrossAttention, MemoryEfficientCrossAttention,
                                 BasicTransformerBlock)):
            x = self._maybe_checkpoint(module)(x, context)
        elif isinstance(module, FeedForward):
            x = module(x, context)
        elif isinstance(module, (Upsample, Downsample)):
            x = module(x)
        elif isinstance(module, Resample):
            x = module(x, reference)
        elif isinstance(module, (TemporalAttentionBlock,
                                 TemporalAttentionMultiBlock)):
            x = self._run_over_time(module, x, time_rel_pos_bias,
                                    focus_present_mask, video_mask)
        elif isinstance(module, (InitTemporalConvBlock, TemporalConvBlock)):
            x = self._run_over_time(module, x)
        elif isinstance(module, nn.ModuleList):
            # containers are applied element by element
            for block in module:
                x = self._forward_single(block, x, e, context,
                                         time_rel_pos_bias, focus_present_mask,
                                         video_mask, reference)
        else:
            x = module(x)
        return x
class ControlledV2VUNet(Vid2VidSDUNet):
    """``Vid2VidSDUNet`` whose bottleneck and skip connections receive residuals from a ``VideoControlNet``.

    Both networks are built with their default constructor arguments.
    """

    def __init__(self):
        super(ControlledV2VUNet, self).__init__()
        self.VideoControlNet = VideoControlNet()

    def forward(self,
                x,
                t,
                y,
                hint=None,
                variant_info=None,
                hint_chunk=None,
                t_hint=None,
                s_cond=None,
                mask_cond=None,
                x_lr=None,
                fps=None,
                mask=None,
                video_mask=None,
                focus_present_mask=None,
                prob_focus_present=0.,
                mask_last_frame_num=0,
                ):
        """Run the control net, then the U-Net with the control residuals added.

        ``hint_chunk`` takes precedence over ``hint`` when both are given.
        The control net returns a list of residuals that is consumed from
        the end: one after the middle block, then one per decoder block
        (added to the matching encoder skip before concatenation).
        ``x_lr``, ``fps`` and ``mask`` are accepted but not used here.

        Returns:
            Tensor rearranged back to ``b c f h w``.
        """
        batch, _, f, _, _ = x.shape
        device = x.device
        self.batch = batch

        if hint_chunk is not None:
            hint = hint_chunk
        control = self.VideoControlNet(
            x, t, y,
            hint=hint,
            t_hint=t_hint,
            mask_cond=mask_cond,
            s_cond=s_cond,
            variant_info=variant_info)

        # image and video joint training, if mask_last_frame_num is set,
        # prob_focus_present will be ignored
        if mask_last_frame_num > 0:
            focus_present_mask = None
            video_mask[-mask_last_frame_num:] = False
        else:
            focus_present_mask = default(
                focus_present_mask, lambda: prob_mask_like(
                    (batch, ), prob_focus_present, device=device))

        time_rel_pos_bias = None
        if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
            time_rel_pos_bias = self.time_rel_pos_bias(
                x.shape[2], device=x.device)

        # per-frame copies of the time embedding and the context
        e = self.time_embed(sinusoidal_embedding(t, self.dim))
        e = e.repeat_interleave(repeats=f, dim=0)
        context = y.repeat_interleave(repeats=f, dim=0)

        # always in shape (b f) c h w, except for temporal layer
        x = rearrange(x, 'b c f h w -> (b f) c h w')

        # arguments shared by every _forward_single call below
        shared = (e, context, time_rel_pos_bias, focus_present_mask, video_mask)

        # encoder
        xs = []
        for block in self.input_blocks:
            x = self._forward_single(block, x, *shared, variant_info=variant_info)
            xs.append(x)

        # middle
        for block in self.middle_block:
            x = self._forward_single(block, x, *shared, variant_info=variant_info)
        if control is not None:
            x = control.pop() + x

        # decoder
        for block in self.output_blocks:
            skip = xs.pop()
            if control is not None:
                skip = skip + control.pop()
            x = torch.cat([x, skip], dim=1)
            x = self._forward_single(
                block, x, *shared,
                reference=xs[-1] if xs else None,
                variant_info=variant_info)

        # head
        x = self.out(x)

        # reshape back to (b c f h w)
        return rearrange(x, '(b f) c h w -> b c f h w', b=batch)

    def _forward_single(self,
                        module,
                        x,
                        e,
                        context,
                        time_rel_pos_bias,
                        focus_present_mask,
                        video_mask,
                        reference=None,
                        variant_info=None):
        """Dispatch ``x`` through ``module`` with the arguments its type expects.

        NOTE(review): ``variant_info`` is unconditionally reset to ``None``
        below (marked "For Debug" by the author), so ``ResBlock`` never sees
        the caller's value in this network. Kept as is, because removing it
        would change model output -- confirm intent.
        """
        variant_info = None  # For Debug

        def wrapped(mod):
            # optional activation checkpointing
            return checkpoint_wrapper(mod) if self.use_checkpoint else mod

        def across_frames(mod, feats, *extra):
            # temporal layers work on b c f h w; everything else on (b f) c h w
            mod = wrapped(mod)
            feats = rearrange(feats, '(b f) c h w -> b c f h w', b=self.batch)
            feats = mod(feats, *extra)
            return rearrange(feats, 'b c f h w -> (b f) c h w')

        if isinstance(module, ResidualBlock):
            return wrapped(module)(x.contiguous(), e, reference)
        if isinstance(module, ResBlock):
            return wrapped(module)(x.contiguous(), e, self.batch, variant_info)
        if isinstance(module, SpatialTransformer):
            return wrapped(module)(x, context)
        if isinstance(module, TemporalTransformer):
            return across_frames(module, x, context)
        if isinstance(module, (CrossAttention, MemoryEfficientCrossAttention,
                               BasicTransformerBlock)):
            return wrapped(module)(x, context)
        if isinstance(module, FeedForward):
            return module(x, context)
        if isinstance(module, (Upsample, Downsample)):
            return module(x)
        if isinstance(module, Resample):
            return module(x, reference)
        if isinstance(module, (TemporalAttentionBlock,
                               TemporalAttentionMultiBlock)):
            return across_frames(module, x, time_rel_pos_bias,
                                 focus_present_mask, video_mask)
        if isinstance(module, (InitTemporalConvBlock, TemporalConvBlock)):
            return across_frames(module, x)
        if isinstance(module, nn.ModuleList):
            # containers are applied element by element
            for child in module:
                x = self._forward_single(child, x, e, context,
                                         time_rel_pos_bias, focus_present_mask,
                                         video_mask, reference, variant_info)
            return x
        return module(x)
class VideoControlNet(nn.Module): | |
def __init__(self,
             in_dim=4,
             dim=320,
             y_dim=1024,
             context_dim=1024,
             out_dim=4,
             dim_mult=(1, 2, 4, 4),
             num_heads=8,
             head_dim=64,
             num_res_blocks=2,
             attn_scales=(1 / 1, 1 / 2, 1 / 4),
             use_scale_shift_norm=True,
             dropout=0.1,
             temporal_attn_times=1,
             temporal_attention=True,
             use_checkpoint=True,
             use_image_dataset=False,
             use_fps_condition=False,
             use_sim_mask=False,
             training=False,
             inpainting=True):
    """Build the control-net encoder, middle block, zero convs and hint projection.

    One zero-initialised 1x1 conv (``zero_convs``) is created per entry of
    ``input_blocks``, plus ``middle_block_out`` for the middle block.
    Module attribute names and construction order are part of the
    checkpoint layout and must not change.

    NOTE(review): the ``USE_TEMPORAL_TRANSFORMER = False`` branch reads
    ``self.rotary_emb``, which is not assigned here -- confirm before
    flipping that flag.
    """
    embed_dim = dim * 4
    num_heads = num_heads if num_heads else dim // 32
    super(VideoControlNet, self).__init__()

    # Private list copies: the defaults are immutable tuples and a
    # caller-supplied sequence is never aliased by the module.
    dim_mult = list(dim_mult)
    attn_scales = list(attn_scales)

    self.in_dim = in_dim
    self.dim = dim
    self.y_dim = y_dim
    self.context_dim = context_dim
    self.embed_dim = embed_dim
    self.out_dim = out_dim
    self.dim_mult = dim_mult
    # for temporal attention
    self.num_heads = num_heads
    # for spatial attention
    self.head_dim = head_dim
    self.num_res_blocks = num_res_blocks
    self.attn_scales = attn_scales
    self.use_scale_shift_norm = use_scale_shift_norm
    self.temporal_attn_times = temporal_attn_times
    self.temporal_attention = temporal_attention
    self.use_checkpoint = use_checkpoint
    self.use_image_dataset = use_image_dataset
    self.use_fps_condition = use_fps_condition
    self.use_sim_mask = use_sim_mask
    self.training = training
    self.inpainting = inpainting

    use_linear_in_temporal = False
    transformer_depth = 1
    disabled_sa = False

    # params
    enc_dims = [dim * u for u in [1] + dim_mult]
    scale = 1.0

    def spatial_layer(channels):
        # spatial cross-attention over the configured context width
        return SpatialTransformer(
            channels,
            channels // head_dim,
            head_dim,
            depth=1,
            context_dim=self.context_dim,
            disable_self_attn=False,
            use_linear=True,
            is_ctrl=True)

    def temporal_layer(channels, heads, sim_mask):
        # `heads` only applies to the transformer variant; the legacy
        # attention variant always uses `num_heads`.
        if USE_TEMPORAL_TRANSFORMER:
            return TemporalTransformer(
                channels,
                heads,
                head_dim,
                depth=transformer_depth,
                context_dim=context_dim,
                disable_self_attn=disabled_sa,
                use_linear=use_linear_in_temporal,
                multiply_zero=use_image_dataset,
                is_ctrl=True)
        return TemporalAttentionMultiBlock(
            channels,
            num_heads,
            head_dim,
            rotary_emb=self.rotary_emb,
            use_image_dataset=use_image_dataset,
            use_sim_mask=sim_mask,
            temporal_attn_times=temporal_attn_times)

    # embeddings
    self.time_embed = nn.Sequential(
        nn.Linear(dim, embed_dim), nn.SiLU(),
        nn.Linear(embed_dim, embed_dim))
    if self.use_fps_condition:
        self.fps_embedding = nn.Sequential(
            nn.Linear(dim, embed_dim), nn.SiLU(),
            nn.Linear(embed_dim, embed_dim))
        nn.init.zeros_(self.fps_embedding[-1].weight)
        nn.init.zeros_(self.fps_embedding[-1].bias)

    # encoder
    self.input_blocks = nn.ModuleList()
    init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
    # need an initial temporal attention?
    if temporal_attention:
        init_block.append(temporal_layer(dim, num_heads, False))
    self.input_blocks.append(init_block)
    self.zero_convs = nn.ModuleList([self.make_zero_conv(dim)])

    for i, (ch_in, ch_out) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
        for j in range(num_res_blocks):
            block = nn.ModuleList([
                ResBlock(
                    ch_in,
                    embed_dim,
                    dropout,
                    out_channels=ch_out,
                    use_scale_shift_norm=False,
                    use_image_dataset=use_image_dataset,
                )
            ])
            if scale in attn_scales:
                block.append(spatial_layer(ch_out))
            if self.temporal_attention:
                block.append(
                    temporal_layer(ch_out, ch_out // head_dim, use_sim_mask))
            ch_in = ch_out
            self.input_blocks.append(block)
            self.zero_convs.append(self.make_zero_conv(ch_out))

            # downsample after the last res block of every level but the deepest
            if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
                downsample = Downsample(
                    ch_out, True, dims=2, out_channels=ch_out)
                scale /= 2.0
                self.input_blocks.append(downsample)
                self.zero_convs.append(self.make_zero_conv(ch_out))

    # middle
    self.middle_block = nn.ModuleList([
        ResBlock(
            ch_out,
            embed_dim,
            dropout,
            use_scale_shift_norm=False,
            use_image_dataset=use_image_dataset,
        ),
        spatial_layer(ch_out),
    ])
    if self.temporal_attention:
        self.middle_block.append(
            temporal_layer(ch_out, ch_out // head_dim, use_sim_mask))
    self.middle_block.append(
        ResBlock(ch_out, embed_dim, dropout, use_scale_shift_norm=False))

    # The zero conv is applied to the middle block output, which has ch_out
    # channels. It used to be sized with embed_dim (dim * 4), which only
    # coincides with ch_out when dim_mult[-1] == 4.
    self.middle_block_out = self.make_zero_conv(ch_out)

    # Hint projection. forward() adds its output in place to the first
    # conv's `dim`-channel output, so its width must follow `dim`
    # (it used to be hard-coded to 320, the default `dim`).
    add_dim = dim
    self.add_dim = add_dim
    self.input_hint_block = zero_module(nn.Conv2d(4, add_dim, 3, padding=1))
def make_zero_conv(self, in_channels, out_channels=None):
    """Return a 1x1 ``Conv2d`` passed through ``zero_module`` and wrapped in ``TimestepEmbedSequential``.

    ``out_channels`` defaults to ``in_channels``.
    """
    if out_channels is None:
        out_channels = in_channels
    conv = nn.Conv2d(in_channels, out_channels, 1, padding=0)
    return TimestepEmbedSequential(zero_module(conv))
def forward(self,
            x,
            t,
            y,
            s_cond=None,
            hint=None,
            variant_info=None,
            t_hint=None,
            mask_cond=None,
            fps=None,
            video_mask=None,
            focus_present_mask=None,
            prob_focus_present=0.,
            mask_last_frame_num=0):
    """Run the control encoder and return its zero-conv outputs.

    Args:
        x: input tensor, unpacked as ``(batch, c, f, h, w)``.
        t: timesteps, embedded with ``sinusoidal_embedding``.
        y: conditioning used as cross-attention context; repeated once per
            frame along dim 0.
        hint: optional ``b c f h w`` tensor. It is projected by
            ``input_hint_block`` and added once, in place, to the output of
            the first layer that is not a ``TemporalTransformer``.
        variant_info: forwarded to ``_forward_single``.
        s_cond, t_hint, mask_cond, fps: accepted but not used by this method.
        video_mask, focus_present_mask, prob_focus_present,
        mask_last_frame_num: masking options forwarded to temporal
            attention layers.

    Returns:
        List with one tensor per entry of ``input_blocks`` followed by one
        for the middle block.
    """
    batch, _, f, _, _ = x.shape
    device = x.device
    self.batch = batch

    # image and video joint training, if mask_last_frame_num is set,
    # prob_focus_present will be ignored
    if mask_last_frame_num > 0:
        focus_present_mask = None
        video_mask[-mask_last_frame_num:] = False
    else:
        focus_present_mask = default(
            focus_present_mask, lambda: prob_mask_like(
                (batch, ), prob_focus_present, device=device))

    time_rel_pos_bias = None
    if self.temporal_attention and not USE_TEMPORAL_TRANSFORMER:
        time_rel_pos_bias = self.time_rel_pos_bias(
            x.shape[2], device=x.device)

    if hint is not None:
        # fold frames into the batch axis before the 2-D projection
        hint = self.input_hint_block(
            rearrange(hint, 'b c f h w -> (b f) c h w'))

    # per-frame copies of the time embedding and the context
    e = self.time_embed(sinusoidal_embedding(t, self.dim))
    e = e.repeat_interleave(repeats=f, dim=0)
    context = y.repeat_interleave(repeats=f, dim=0)

    # always in shape (b f) c h w, except for temporal layer
    x = rearrange(x, 'b c f h w -> (b f) c h w')

    # arguments shared by every _forward_single call below
    shared = (e, context, time_rel_pos_bias, focus_present_mask, video_mask)

    # encoder
    outputs = []
    for stage, zero_conv in zip(self.input_blocks, self.zero_convs):
        if hint is None:
            x = self._forward_single(stage, x, *shared, variant_info=variant_info)
        else:
            # Hint still pending: walk the stage layer by layer so it can be
            # injected right after the first non-temporal layer.
            for layer in stage:
                x = self._forward_single(layer, x, *shared,
                                         variant_info=variant_info)
                if hint is not None and not isinstance(layer, TemporalTransformer):
                    x += hint
                    hint = None
        outputs.append(zero_conv(x, e, context))

    # middle
    for layer in self.middle_block:
        x = self._forward_single(layer, x, *shared, variant_info=variant_info)
    outputs.append(self.middle_block_out(x, e, context))
    return outputs
def _forward_single(self,
                    module,
                    x,
                    e,
                    context,
                    time_rel_pos_bias,
                    focus_present_mask,
                    video_mask,
                    reference=None,
                    variant_info=None):
    """Run one layer (or a ``nn.ModuleList`` of layers) on ``x``.

    The layer type decides which of the extra arguments are forwarded.
    Types are tested with ``isinstance`` in a fixed order and the first
    match wins; anything unrecognised is simply called as ``module(x)``.

    Args:
        module: The layer to apply, or a ``nn.ModuleList`` whose entries are
            applied one after another through this same method.
        x: Input tensor. Layers handled by the temporal branches receive it
            rearranged from ``(b f) c h w`` to ``b c f h w`` (with
            ``b=self.batch``) and their output is rearranged back.
        e: Forwarded as the second argument to ``ResidualBlock`` / ``ResBlock``.
        context: Forwarded to the transformer / attention layers and to
            ``FeedForward``.
        time_rel_pos_bias: Forwarded to the temporal attention blocks.
        focus_present_mask: Forwarded to the temporal attention blocks.
        video_mask: Forwarded to the temporal attention blocks.
        reference: Forwarded to ``ResidualBlock`` and ``Resample``.
        variant_info: Forwarded to ``ResBlock``.

    Returns:
        Whatever the selected layer returns (rearranged back to
        ``(b f) c h w`` for the temporal branches).
    """

    def maybe_checkpoint(layer):
        # Activation checkpointing is opt-in through ``self.use_checkpoint``.
        return checkpoint_wrapper(layer) if self.use_checkpoint else layer

    def run_on_video(layer, frames, *extra):
        # Temporal layers want the frame axis split out of the batch axis.
        video = rearrange(frames, '(b f) c h w -> b c f h w', b=self.batch)
        video = layer(video, *extra)
        return rearrange(video, 'b c f h w -> (b f) c h w')

    # NOTE: the order of the checks below is part of the behaviour; keep it.
    if isinstance(module, ResidualBlock):
        layer = maybe_checkpoint(module)
        return layer(x.contiguous(), e, reference)

    if isinstance(module, ResBlock):
        layer = maybe_checkpoint(module)
        return layer(x.contiguous(), e, self.batch, variant_info)

    if isinstance(module, SpatialTransformer):
        return maybe_checkpoint(module)(x, context)

    if isinstance(module, TemporalTransformer):
        return run_on_video(maybe_checkpoint(module), x, context)

    if isinstance(module, (CrossAttention,
                           MemoryEfficientCrossAttention,
                           BasicTransformerBlock)):
        return maybe_checkpoint(module)(x, context)

    if isinstance(module, FeedForward):
        # Not checkpointed, unlike the attention layers above.
        return module(x, context)

    if isinstance(module, (Upsample, Downsample)):
        return module(x)

    if isinstance(module, Resample):
        return module(x, reference)

    if isinstance(module, (TemporalAttentionBlock,
                           TemporalAttentionMultiBlock)):
        return run_on_video(maybe_checkpoint(module), x, time_rel_pos_bias,
                            focus_present_mask, video_mask)

    if isinstance(module, (InitTemporalConvBlock, TemporalConvBlock)):
        return run_on_video(maybe_checkpoint(module), x)

    if isinstance(module, nn.ModuleList):
        # Containers are unrolled: each entry is dispatched individually.
        for block in module:
            x = self._forward_single(block, x, e, context,
                                     time_rel_pos_bias, focus_present_mask,
                                     video_mask, reference, variant_info)
        return x

    # Fallback for any other layer type: call it with the input only.
    return module(x)
class TimestepBlock(nn.Module):
    """
    Marker base class for modules whose forward() accepts timestep
    embeddings as its second positional argument.
    """

    def forward(self, x, emb):
        """
        Process `x` using the timestep embeddings `emb`.

        This base implementation is a stub that does no work and yields
        None; subclasses are expected to supply the real computation.
        """
        return None
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    Sequential container that hands each child the extra input it accepts:
    timestep embeddings for TimestepBlock children, `context` for
    SpatialTransformer children, and nothing extra for all others.
    """

    def forward(self, x, emb, context=None):
        """
        Feed `x` through every child in order and return the final output.

        The TimestepBlock check comes first, so a child that is both a
        TimestepBlock and a SpatialTransformer is called with `emb`.
        """
        out = x
        for child in self:
            if isinstance(child, TimestepBlock):
                extra = (emb,)
            elif isinstance(child, SpatialTransformer):
                extra = (context,)
            else:
                extra = ()
            out = child(out, *extra)
        return out