depthanyvideo
update
e9f3e75
raw
history blame
19.4 kB
from functools import partial
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.utils import deprecate
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import SpatialNorm
from diffusers.models.downsampling import ( # noqa
Downsample2D,
downsample_2d,
)
from diffusers.models.normalization import AdaGroupNorm
from diffusers.models.upsampling import ( # noqa
Upsample2D,
upsample_2d,
)
class ResnetBlock2D(nn.Module):
    r"""
    A 2D Resnet block: GroupNorm -> activation -> (optional resampling) -> conv, an optional timestep-embedding
    injection, a second GroupNorm -> activation -> dropout -> conv, and a residual connection.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, default to `False`):
            Stored on the module as `use_conv_shortcut`; this class does not read it anywhere else.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
            If `None`, no timestep projection layer is created.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        groups_out (`int`, *optional*, default to None):
            The number of groups to use for the second normalization layer. if set to None, same as `groups`.
        pre_norm (`bool`, *optional*, default to `True`):
            Accepted for signature compatibility only; `self.pre_norm` is always set to `True`.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
        non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
        skip_time_act (`bool`, *optional*, default to `False`):
            If `True`, the activation is not applied to `temb` before it is projected.
        time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift"
            for a stronger conditioning with scale and shift.
        kernel (`str`, *optional*, default to None):
            Selects the resampling operation used when `up` or `down` is set. `"fir"` uses
            `upsample_2d` / `downsample_2d` with the filter `(1, 3, 3, 1)`; `"sde_vp"` uses nearest-neighbour
            interpolation (x2) / 2x2 average pooling; any other value uses `Upsample2D` / `Downsample2D`
            without a convolution.
        output_scale_factor (`float`, *optional*, default to be `1.0`):
            The sum of shortcut and residual branch is divided by this value.
        use_in_shortcut (`bool`, *optional*, default to `None`):
            If `True`, add a 1x1 nn.conv2d layer for skip-connection. If `None`, the layer is added only when
            `in_channels` differs from the number of channels produced by the second convolution.
        up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
        down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. Ignored when `up`
            is also `True`.
        conv_shortcut_bias (`bool`, *optional*, default to `True`):  If `True`, adds a learnable bias to the
            `conv_shortcut` output.
        conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
            If None, same as `out_channels`.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: Optional[int] = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        skip_time_act: bool = False,
        time_embedding_norm: str = "default",  # default, scale_shift,
        kernel: Optional[str] = None,
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        if time_embedding_norm == "ada_group":
            raise ValueError(
                "This class cannot be used with `time_embedding_norm==ada_group`, please use `ResnetBlockCondNorm2D` instead",
            )
        if time_embedding_norm == "spatial":
            raise ValueError(
                "This class cannot be used with `time_embedding_norm==spatial`, please use `ResnetBlockCondNorm2D` instead",
            )

        # The `pre_norm` argument is intentionally not consulted here; the attribute is always True.
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.up = up
        self.down = down
        self.output_scale_factor = output_scale_factor
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(
            num_groups=groups, num_channels=in_channels, eps=eps, affine=True
        )
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                self.time_emb_proj = nn.Linear(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
                # Twice the channels: the projection is later split into a scale half and a shift half.
                self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
            else:
                raise ValueError(
                    f"unknown time_embedding_norm : {self.time_embedding_norm} "
                )
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(
            num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
        )
        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.conv2 = nn.Conv2d(
            out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1
        )

        self.nonlinearity = get_activation(non_linearity)

        self.upsample = self.downsample = None
        if self.up:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
            else:
                self.upsample = Upsample2D(in_channels, use_conv=False)
        elif self.down:
            if kernel == "fir":
                fir_kernel = (1, 3, 3, 1)
                self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
            elif kernel == "sde_vp":
                self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
            else:
                self.downsample = Downsample2D(
                    in_channels, use_conv=False, padding=1, name="op"
                )

        # When not specified, a 1x1 projection is only needed if the channel count changes.
        self.use_in_shortcut = (
            self.in_channels != conv_2d_out_channels
            if use_in_shortcut is None
            else use_in_shortcut
        )

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = nn.Conv2d(
                in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=conv_shortcut_bias,
            )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
        temb: Optional[torch.FloatTensor],
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        r"""
        Run the block.

        Args:
            input_tensor (`torch.FloatTensor`):
                Feature map fed to the GroupNorm / Conv2d layers (presumably `(batch, channels, height, width)`).
            temb (`torch.FloatTensor` or `None`):
                Timestep embedding. When a projection layer exists and `temb` is given, it is projected and
                broadcast over the two trailing axes. `None` skips the conditioning in `"default"` mode and
                raises in `"scale_shift"` mode.
            *args, **kwargs:
                Only inspected to emit a deprecation notice for the removed `scale` argument.

        Returns:
            `torch.FloatTensor`: `(shortcut + residual) / output_scale_factor`.

        Raises:
            ValueError: If `time_embedding_norm == "scale_shift"` and `temb` is `None`.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        # The same resampling is applied to both branches so they can be summed at the end.
        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        # Only project when an embedding was actually supplied; a missing `temb` is handled
        # explicitly by the branches below instead of failing inside the activation.
        if self.time_emb_proj is not None and temb is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)[:, :, None, None]

        if self.time_embedding_norm == "default":
            if temb is not None:
                hidden_states = hidden_states + temb
            hidden_states = self.norm2(hidden_states)
        elif self.time_embedding_norm == "scale_shift":
            if temb is None:
                raise ValueError(
                    f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
                )
            time_scale, time_shift = torch.chunk(temb, 2, dim=1)
            hidden_states = self.norm2(hidden_states)
            hidden_states = hidden_states * (1 + time_scale) + time_shift
        else:
            hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
class TemporalResnetBlock(nn.Module):
    r"""
    A Resnet block built from 3D convolutions with kernel size `(3, 1, 1)`, i.e. each convolution only mixes
    information along the first of the three trailing axes (the frame axis, as used by the callers in this file).

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv3d layer. If None, same as `in_channels`.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
            If `None`, no timestep projection layer is created.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        temb_channels: Optional[int] = 512,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        kernel_size = (3, 1, 1)
        # "Same" padding per axis: [1, 0, 0] for the kernel above.
        padding = [k // 2 for k in kernel_size]

        self.norm1 = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=eps, affine=True
        )
        self.conv1 = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
        )

        if temb_channels is not None:
            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(
            num_groups=32, num_channels=out_channels, eps=eps, affine=True
        )

        self.dropout = torch.nn.Dropout(0.0)
        self.conv2 = nn.Conv3d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
        )

        self.nonlinearity = get_activation("silu")

        # A 1x1x1 projection on the skip path is only needed when the channel count changes.
        self.use_in_shortcut = self.in_channels != out_channels

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )

    def forward(
        self, input_tensor: torch.FloatTensor, temb: Optional[torch.FloatTensor]
    ) -> torch.FloatTensor:
        r"""
        Run the block.

        Args:
            input_tensor (`torch.FloatTensor`):
                Input to the Conv3d layers (presumably `(batch, channels, frames, height, width)`).
            temb (`torch.FloatTensor` or `None`):
                Timestep embedding with three dimensions whose last one has `temb_channels` entries
                (presumably `(batch, frames, temb_channels)`). If `None`, or if the block was built with
                `temb_channels=None`, no embedding is added.

        Returns:
            `torch.FloatTensor`: The sum of the (optionally projected) input and the residual branch.
        """
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states)

        # Skip the conditioning when no embedding is supplied rather than failing inside the activation.
        if self.time_emb_proj is not None and temb is not None:
            temb = self.nonlinearity(temb)
            # Append two singleton axes, then swap axes 1 and 2 so the projected channels sit on axis 1.
            temb = self.time_emb_proj(temb)[:, :, :, None, None]
            temb = temb.permute(0, 2, 1, 3, 4)
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor
# VideoResBlock
class SpatioTemporalResBlock(nn.Module):
    r"""
    A SpatioTemporal Resnet block: a `ResnetBlock2D` applied to the flattened frames, followed by a
    `TemporalResnetBlock`, with both results blended by an `AlphaBlender`.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resnet.
        temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.
        merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.
        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
            The merge strategy to use for the temporal mixing.
        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
            If `True`, switch the spatial and temporal mixing.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        temb_channels: int = 512,
        eps: float = 1e-6,
        temporal_eps: Optional[float] = None,
        merge_factor: float = 0.5,
        merge_strategy="learned_with_images",
        switch_spatial_to_temporal_mix: bool = False,
    ):
        super().__init__()

        # The temporal block operates on the output of the spatial block, so it uses the
        # resolved output width for both its input and its output.
        temporal_channels = out_channels if out_channels is not None else in_channels

        self.spatial_res_block = ResnetBlock2D(
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            eps=eps,
        )

        self.temporal_res_block = TemporalResnetBlock(
            in_channels=temporal_channels,
            out_channels=temporal_channels,
            temb_channels=temb_channels,
            eps=temporal_eps if temporal_eps is not None else eps,
        )

        self.time_mixer = AlphaBlender(
            alpha=merge_factor,
            merge_strategy=merge_strategy,
            switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        image_only_indicator: Optional[torch.Tensor] = None,
    ):
        r"""
        Run the spatial block, the temporal block and blend their outputs.

        Args:
            hidden_states (`torch.FloatTensor`):
                Input whose spatial-block output has shape `(batch * frames, channels, height, width)`.
            temb (`torch.FloatTensor`, *optional*):
                Timestep embedding with one row per frame; reshaped to `(batch, frames, -1)` before it
                is passed to the temporal block.
            image_only_indicator (`torch.Tensor`):
                Required despite the `None` default: its last dimension supplies the number of frames,
                and it is forwarded to the blender.

        Returns:
            `torch.FloatTensor`: Blended features of shape `(batch * frames, channels, height, width)`.

        Raises:
            ValueError: If `image_only_indicator` is `None`.
        """
        if image_only_indicator is None:
            raise ValueError(
                "`image_only_indicator` must be provided: its last dimension defines the number of frames."
            )
        num_frames = image_only_indicator.shape[-1]

        hidden_states = self.spatial_res_block(hidden_states, temb)

        batch_frames, channels, height, width = hidden_states.shape
        batch_size = batch_frames // num_frames

        # Unflatten the frame axis and move channels ahead of frames for the Conv3d-based block.
        # The spatial result is kept (same tensor, no copy) as the spatial input of the blender.
        hidden_states = hidden_states.reshape(
            batch_size, num_frames, channels, height, width
        ).permute(0, 2, 1, 3, 4)
        hidden_states_mix = hidden_states

        if temb is not None:
            temb = temb.reshape(batch_size, num_frames, -1)

        hidden_states = self.temporal_res_block(hidden_states, temb)
        hidden_states = self.time_mixer(
            x_spatial=hidden_states_mix,
            x_temporal=hidden_states,
            image_only_indicator=image_only_indicator,
        )

        # Back to the flattened `(batch * frames, channels, height, width)` layout.
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(
            batch_frames, channels, height, width
        )
        return hidden_states
class AlphaBlender(nn.Module):
    r"""
    A module to blend spatial and temporal features as `alpha * x_spatial + (1 - alpha) * x_temporal`.

    Parameters:
        alpha (`float`): The initial value of the blending factor. Used as-is for the `fixed` strategy and
            passed through a sigmoid for the learned strategies.
        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
            The merge strategy to use for the temporal mixing. One of `learned`, `fixed` or
            `learned_with_images`.
        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
            If `True`, switch the spatial and temporal mixing.
    """

    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        switch_spatial_to_temporal_mix: bool = False,
    ):
        super().__init__()
        # Validate first; every value that passes is handled by one of the two branches below.
        if merge_strategy not in self.strategies:
            raise ValueError(f"merge_strategy needs to be in {self.strategies}")

        self.merge_strategy = merge_strategy
        self.switch_spatial_to_temporal_mix = (
            switch_spatial_to_temporal_mix  # For TemporalVAE
        )

        if self.merge_strategy == "fixed":
            # Constant factor: stored as a buffer so it is not trained.
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        else:
            # "learned" and "learned_with_images": trainable factor.
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )

    def get_alpha(
        self, image_only_indicator: Optional[torch.Tensor], ndims: int
    ) -> torch.Tensor:
        r"""
        Compute the blending weight applied to the spatial input.

        Args:
            image_only_indicator (`torch.Tensor` or `None`):
                Only used by `learned_with_images`, where it is required. Entries that are truthy get a
                weight of 1; all others get `sigmoid(mix_factor)`.
            ndims (`int`):
                Number of dimensions of the tensors to blend. Only used by `learned_with_images`,
                which supports 3 and 5.

        Returns:
            `torch.Tensor`: The weight, shaped so that it broadcasts against the blended tensors.

        Raises:
            ValueError: In `learned_with_images` mode if `image_only_indicator` is `None` or `ndims`
                is neither 3 nor 5.
            NotImplementedError: If `merge_strategy` holds an unsupported value.
        """
        if self.merge_strategy == "fixed":
            alpha = self.mix_factor

        elif self.merge_strategy == "learned":
            alpha = torch.sigmoid(self.mix_factor)

        elif self.merge_strategy == "learned_with_images":
            if image_only_indicator is None:
                raise ValueError(
                    "Please provide image_only_indicator to use learned_with_images merge strategy"
                )

            alpha = torch.where(
                image_only_indicator.bool(),
                torch.ones(1, 1, device=image_only_indicator.device),
                torch.sigmoid(self.mix_factor)[..., None],
            )

            # (batch, channel, frames, height, width)
            if ndims == 5:
                alpha = alpha[:, None, :, None, None]
            # (batch*frames, height*width, channels)
            elif ndims == 3:
                alpha = alpha.reshape(-1)[:, None, None]
            else:
                raise ValueError(
                    f"Unexpected ndims {ndims}. Dimensions should be 3 or 5"
                )

        else:
            # Only reachable if `merge_strategy` was reassigned after construction.
            raise NotImplementedError

        return alpha

    def forward(
        self,
        x_spatial: torch.Tensor,
        x_temporal: torch.Tensor,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""
        Blend `x_spatial` and `x_temporal`.

        Args:
            x_spatial (`torch.Tensor`): Spatial features; its dtype and rank drive the weight's dtype/shape.
            x_temporal (`torch.Tensor`): Temporal features, broadcast-compatible with `x_spatial`.
            image_only_indicator (`torch.Tensor`, *optional*): Forwarded to `get_alpha`.

        Returns:
            `torch.Tensor`: `alpha * x_spatial + (1 - alpha) * x_temporal`, with `alpha` replaced by
            `1 - alpha` when `switch_spatial_to_temporal_mix` is set.
        """
        alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
        alpha = alpha.to(x_spatial.dtype)

        if self.switch_spatial_to_temporal_mix:
            alpha = 1.0 - alpha

        x = alpha * x_spatial + (1.0 - alpha) * x_temporal
        return x