# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from einops import rearrange
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.resnet import Downsample2D, ResnetBlock2D
from diffusers.utils import BaseOutput, logging

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@dataclass
class ControlNetOutput(BaseOutput):
    """
    The output of [`ControlNetModel`].

    Args:
        down_block_res_samples (`tuple[torch.Tensor]`):
            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
            be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
            used to condition the original UNet's downsampling activations.
        mid_block_res_sample (`torch.Tensor`):
            The activation of the middle block (the lowest sample resolution). The tensor should be of shape
            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
            Output can be used to condition the original UNet's middle block activation.
    """

    down_block_res_samples: Tuple[torch.Tensor]
    mid_block_res_sample: torch.Tensor

# NOTE: this class is redefined further below; the BasicBlock-based definition
# later in this file is the one in effect at import time.
class Block2D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_padding: int = 1,
    ):
        super().__init__()
        resnets = []
        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample2D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
        output_states = ()

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states

class IdentityModule(nn.Module):
    """Pass-through module: returns its first argument unchanged (or None if called with none)."""

    def __init__(self):
        super().__init__()

    def forward(self, *args):
        return args[0] if len(args) > 0 else None

class BasicBlock(nn.Module):
    # Most keyword arguments below are unused; they are accepted only so that
    # `BasicBlock` can be constructed with the same arguments as `ResnetBlock2D`
    # (see the `Block2D` definition below).
    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        stride: int = 1,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        pre_norm: bool = True,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        skip_time_act: bool = False,
        time_embedding_norm: str = "default",  # default, scale_shift, ada_group, spatial
        kernel: Optional[torch.FloatTensor] = None,
        output_scale_factor: float = 1.0,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3 if stride != 1 else 1,
                    stride=stride,
                    padding=1 if stride != 1 else 0,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x, *args):
        # Extra positional arguments (e.g. the time embedding) are ignored.
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
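
# A minimal usage sketch (not part of the original module): `BasicBlock` keeps
# the spatial size at stride=1 and halves it at stride=2; the ResnetBlock2D-style
# keyword arguments and any extra forward arguments are ignored.
def _basic_block_shape_example():
    block = BasicBlock(in_channels=128, out_channels=256, stride=2)
    x = torch.randn(2, 128, 64, 64)
    out = block(x, None)  # second argument (a time embedding) is ignored
    assert out.shape == (2, 256, 32, 32)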

class Block2D(nn.Module):
    # This redefinition supersedes the `Block2D` above: the per-layer
    # `ResnetBlock2D`s and the `Downsample2D` stage are swapped for plain
    # `BasicBlock`s, and all but the last per-layer block are replaced by an
    # `IdentityModule`.
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_downsample: bool = True,
        downsample_padding: int = 1,
    ):
        super().__init__()
        resnets = []
        for i in range(num_layers):
            resnets.append(
                BasicBlock(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
                if i == num_layers - 1
                else IdentityModule()
            )
        self.resnets = nn.ModuleList(resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    BasicBlock(
                        in_channels=out_channels,
                        out_channels=out_channels,
                        temb_channels=temb_channels,
                        stride=2,
                        eps=resnet_eps,
                        groups=resnet_groups,
                        dropout=dropout,
                        time_embedding_norm=resnet_time_scale_shift,
                        non_linearity=resnet_act_fn,
                        output_scale_factor=output_scale_factor,
                        pre_norm=resnet_pre_norm,
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
        output_states = ()

        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states
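
# A minimal usage sketch (not part of the original module): with num_layers=2,
# the effective `Block2D` above applies an IdentityModule, then a BasicBlock,
# then a stride-2 BasicBlock downsampler, returning the final feature map plus
# every intermediate state.
def _block2d_shape_example():
    block = Block2D(in_channels=64, out_channels=128, temb_channels=256, num_layers=2)
    x = torch.randn(2, 64, 32, 32)
    out, states = block(x, temb=None)
    assert out.shape == (2, 128, 16, 16)  # halved by the stride-2 downsampler
    assert len(states) == 3  # one state per resnet layer plus the downsampler output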

class ControlProject(nn.Module):
    def __init__(self, num_channels, scale=8, is_empty=False) -> None:
        super().__init__()
        # `scale` must be a power of two.
        assert scale and scale & (scale - 1) == 0
        self.is_empty = is_empty
        self.scale = scale
        if not is_empty:
            if scale > 1:
                self.down_scale = nn.AvgPool2d(scale, scale)
            else:
                self.down_scale = nn.Identity()
            # Zero-initialized 1x1 projection, so the control branch starts as a no-op.
            self.out = nn.Conv2d(num_channels, num_channels, kernel_size=1, stride=1, bias=False)
            for p in self.out.parameters():
                nn.init.zeros_(p)

    def forward(self, hidden_states: torch.FloatTensor):
        if self.is_empty:
            shape = list(hidden_states.shape)
            shape[-2] = shape[-2] // self.scale
            shape[-1] = shape[-1] // self.scale
            return torch.zeros(shape).to(hidden_states)

        if len(hidden_states.shape) == 5:
            # Video input: fold the frame axis into the batch before the 2D ops.
            B, F, C, H, W = hidden_states.shape
            hidden_states = rearrange(hidden_states, "B F C H W -> (B F) C H W")
            hidden_states = self.down_scale(hidden_states)
            hidden_states = self.out(hidden_states)
            hidden_states = rearrange(hidden_states, "(B F) C H W -> B F C H W", F=F)
        else:
            hidden_states = self.down_scale(hidden_states)
            hidden_states = self.out(hidden_states)
        return hidden_states
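
# A minimal usage sketch (not part of the original module): `ControlProject`
# average-pools by `scale` and applies the zero-initialized 1x1 conv, so right
# after construction it outputs exact zeros of the downscaled shape; with
# is_empty=True it skips the computation and returns zeros directly.
def _control_project_example():
    proj = ControlProject(num_channels=320, scale=8)
    x = torch.randn(2, 320, 64, 64)
    out = proj(x)
    assert out.shape == (2, 320, 8, 8)
    assert torch.all(out == 0)  # the 1x1 conv weights start at zero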

class ControlNetModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    # Register the constructor arguments so ConfigMixin can serialize the config
    # (e.g. for `save_pretrained`).
    @register_to_config
    def __init__(
        self,
        in_channels: List[int] = [128, 128],
        out_channels: List[int] = [128, 256],
        groups: List[int] = [4, 8],
        time_embed_dim: int = 256,
        final_out_channels: int = 320,
    ):
        super().__init__()

        self.time_proj = Timesteps(128, True, downscale_freq_shift=0)
        self.time_embedding = TimestepEmbedding(128, time_embed_dim)

        # Stem: downscale the 3-channel control image by 2x and lift it to 128 channels.
        self.embedding = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(2, 64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.GroupNorm(2, 64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.GroupNorm(2, 128),
            nn.ReLU(),
        )

        self.down_res = nn.ModuleList()
        self.down_sample = nn.ModuleList()
        for i in range(len(in_channels)):
            self.down_res.append(
                ResnetBlock2D(
                    in_channels=in_channels[i],
                    out_channels=out_channels[i],
                    temb_channels=time_embed_dim,
                    groups=groups[i],
                ),
            )
            self.down_sample.append(
                Downsample2D(
                    out_channels[i],
                    use_conv=True,
                    out_channels=out_channels[i],
                    padding=1,
                    name="op",
                )
            )

        self.mid_convs = nn.ModuleList()
        self.mid_convs.append(
            nn.Sequential(
                nn.Conv2d(
                    in_channels=out_channels[-1],
                    out_channels=out_channels[-1],
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
                nn.ReLU(),
                nn.GroupNorm(8, out_channels[-1]),
                nn.Conv2d(
                    in_channels=out_channels[-1],
                    out_channels=out_channels[-1],
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
                nn.GroupNorm(8, out_channels[-1]),
            )
        )
        self.mid_convs.append(
            nn.Conv2d(
                in_channels=out_channels[-1],
                out_channels=final_out_channels,
                kernel_size=1,
                stride=1,
            )
        )
        self.scale = 1.0  # nn.Parameter(torch.tensor(1.))

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
    ) -> Dict[str, Any]:
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        batch_size = sample.shape[0]
        timesteps = timesteps.expand(batch_size)

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors,
        # but time_embedding might actually be running in fp16, so we need to cast here.
        # There might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb)

        sample = self.embedding(sample)

        for res, downsample in zip(self.down_res, self.down_sample):
            sample = res(sample, emb)
            sample = downsample(sample, emb)

        sample = self.mid_convs[0](sample) + sample  # residual mid block
        sample = self.mid_convs[1](sample)

        return {
            "out": sample,
            "scale": self.scale,
        }

def zero_module(module):
    # Zero out all parameters of `module` in place and return it.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
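
# A minimal end-to-end sketch (not part of the original module): the model maps
# a 3-channel control image to a `final_out_channels` feature map downscaled by
# 8x overall (one stride-2 stem conv plus two Downsample2D stages), matching the
# resolution at which it would be added to a UNet's early hidden states.
def _controlnet_example():
    model = ControlNetModel()
    sample = torch.randn(1, 3, 512, 512)  # e.g. an edge or pose control image
    result = model(sample, timestep=10)
    assert result["out"].shape == (1, 320, 64, 64)
    assert result["scale"] == 1.0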