# DeciDiffusion-v1-0 / pipeline.py
import itertools
from typing import Any, Optional, Dict, Tuple
import torch
from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers import Transformer2DModel, ModelMixin, ConfigMixin
from diffusers import UNet2DConditionModel
from diffusers.configuration_utils import register_to_config
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.resnet import ResnetBlock2D, Downsample2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModelOutput
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from torch import nn
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
FlexibleUnetConfigurations = {
# General parameters for all blocks
'sample_size': 64,
'temb_dim': 320 * 4,
'resnet_eps': 1e-5,
'resnet_act_fn': 'silu',
'num_attention_heads': 8,
'cross_attention_dim': 768,
    # Controls the execution order of modules in the UNet's forward pass
'mix_block_in_forward': True,
# Down blocks parameters
'down_blocks_in_channels': [320, 320, 640],
'down_blocks_out_channels': [320, 640, 1280],
'down_blocks_num_attentions': [0, 1, 3],
'down_blocks_num_resnets': [2, 2, 1],
'add_downsample': [True, True, True],
# Middle block parameters
'add_upsample_mid_block': True,
'mid_num_resnets': 4,
'mid_num_attentions': 2,
# Up block parameters
'prev_output_channels': [1280, 1280, 640],
'up_blocks_num_attentions': [6, 3, 0],
'up_blocks_num_resnets': [2, 3, 3],
'add_upsample': [True, True, False],
}
def custom_sort_order(obj):
"""
Key function for sorting order of execution in forward methods
"""
return {ResnetBlock2D: 0, Transformer2DModel: 1, FlexibleTransformer2DModel: 1}.get(obj.__class__)
class FlexibleIdentityBlock(nn.Module):
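    """
    A no-op stand-in for the mid block. Used when both `mid_num_resnets` and
    `mid_num_attentions` are 0, i.e. when the mid block is skipped entirely.
    """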
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
return hidden_states
class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
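    """
    A UNet2DConditionModel whose down/mid/up blocks are rebuilt from
    `FlexibleUnetConfigurations`, so the number of resnets and cross-attention
    transformers per block (and their execution order) is fully configurable.
    """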
configurations = FlexibleUnetConfigurations
@register_to_config
def __init__(self):
super().__init__(sample_size=self.configurations.get('sample_size',
FlexibleUnetConfigurations['sample_size']),
cross_attention_dim=self.configurations.get("cross_attention_dim",
FlexibleUnetConfigurations['cross_attention_dim']))
num_attention_heads = self.configurations.get("num_attention_heads")
cross_attention_dim = self.configurations.get("cross_attention_dim")
mix_block_in_forward = self.configurations.get("mix_block_in_forward")
resnet_act_fn = self.configurations.get("resnet_act_fn")
resnet_eps = self.configurations.get("resnet_eps")
temb_dim = self.configurations.get("temb_dim")
###############
# Down blocks #
###############
down_blocks_num_attentions = self.configurations.get("down_blocks_num_attentions")
down_blocks_out_channels = self.configurations.get("down_blocks_out_channels")
down_blocks_in_channels = self.configurations.get("down_blocks_in_channels")
down_blocks_num_resnets = self.configurations.get("down_blocks_num_resnets")
add_downsample = self.configurations.get("add_downsample")
self.down_blocks = nn.ModuleList()
for i, (in_c, out_c, n_res, n_att, add_down) in enumerate(zip(down_blocks_in_channels, down_blocks_out_channels,
down_blocks_num_resnets,
down_blocks_num_attentions,
add_downsample)):
last_block = i == len(down_blocks_in_channels) - 1
self.down_blocks.append(FlexibleCrossAttnDownBlock2D(in_channels=in_c,
out_channels=out_c,
temb_channels=temb_dim,
num_resnets=n_res,
num_attentions=n_att,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
num_attention_heads=num_attention_heads,
cross_attention_dim=cross_attention_dim,
add_downsample=add_down,
last_block=last_block,
mix_block_in_forward=mix_block_in_forward))
###############
# Mid blocks #
###############
mid_block_add_upsample = self.configurations.get("add_upsample_mid_block")
mid_num_attentions = self.configurations.get("mid_num_attentions")
mid_num_resnets = self.configurations.get("mid_num_resnets")
if mid_num_resnets == mid_num_attentions == 0:
self.mid_block = FlexibleIdentityBlock()
else:
self.mid_block = FlexibleUNetMidBlock2DCrossAttn(in_channels=down_blocks_out_channels[-1],
temb_channels=temb_dim,
resnet_act_fn=resnet_act_fn,
resnet_eps=resnet_eps,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
num_resnets=mid_num_resnets,
num_attentions=mid_num_attentions,
mix_block_in_forward=mix_block_in_forward,
add_upsample=mid_block_add_upsample
)
###############
# Up blocks #
###############
up_blocks_num_attentions = self.configurations.get("up_blocks_num_attentions")
up_blocks_num_resnets = self.configurations.get("up_blocks_num_resnets")
prev_output_channels = self.configurations.get("prev_output_channels")
up_upsample = self.configurations.get("add_upsample")
self.up_blocks = nn.ModuleList()
for in_c, out_c, prev_out, n_res, n_att, add_up in zip(reversed(down_blocks_in_channels),
reversed(down_blocks_out_channels),
prev_output_channels,
up_blocks_num_resnets, up_blocks_num_attentions,
up_upsample):
self.up_blocks.append(FlexibleCrossAttnUpBlock2D(in_channels=in_c,
out_channels=out_c,
prev_output_channel=prev_out,
temb_channels=temb_dim,
num_resnets=n_res,
num_attentions=n_att,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
num_attention_heads=num_attention_heads,
cross_attention_dim=cross_attention_dim,
add_upsample=add_up,
mix_block_in_forward=mix_block_in_forward
))
class FlexibleCrossAttnDownBlock2D(nn.Module):
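    """
    A down block with a configurable number of resnets and cross-attention
    transformers. When `mix_block_in_forward` is False, modules are sorted so
    that all resnets execute before any transformer block.
    """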
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_resnets: int = 1,
num_attentions: int = 1,
transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
downsample_padding: int = 1,
add_downsample: bool = True,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
last_block: bool = False,
mix_block_in_forward: bool = True,
):
super().__init__()
self.last_block = last_block
self.mix_block_in_forward = mix_block_in_forward
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
modules = []
add_resnets = [True] * num_resnets
add_cross_attentions = [True] * num_attentions
for i, (add_resnet, add_cross_attention) in enumerate(
itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
in_channels = in_channels if i == 0 else out_channels
if add_resnet:
modules.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if add_cross_attention:
modules.append(
FlexibleTransformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
)
)
if not mix_block_in_forward:
modules = sorted(modules, key=custom_sort_order)
self.modules_list = nn.ModuleList(modules)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample2D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
output_states = ()
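        # Run resnets and transformers in order, collecting a skip connection after each resnet.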
for module in self.modules_list:
if isinstance(module, ResnetBlock2D):
hidden_states = module(hidden_states, temb)
elif isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
hidden_states = module(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
raise ValueError(f'Got an unexpected module in modules list! {type(module)}')
if isinstance(module, ResnetBlock2D):
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
if not self.last_block:
output_states = output_states + (hidden_states,)
return hidden_states, output_states
class FlexibleCrossAttnUpBlock2D(nn.Module):
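    """
    An up block with a configurable number of resnets and cross-attention
    transformers. Each resnet consumes one skip connection from the matching
    down block; `mix_block_in_forward=False` sorts resnets ahead of transformers.
    """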
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
dropout: float = 0.0,
num_resnets: int = 1,
num_attentions: int = 1,
transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
cross_attention_dim: int = 1280,
output_scale_factor: float = 1.0,
add_upsample: bool = True,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
mix_block_in_forward: bool = True
):
super().__init__()
modules = []
        # WARNING: StableDiffusionPipeline inspects this attribute, so it is filled with one
        # placeholder entry per resnet to report the correct resnet count.
self.resnets = []
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
add_resnets = [True] * num_resnets
add_cross_attentions = [True] * num_attentions
for i, (add_resnet, add_cross_attention) in enumerate(
itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
res_skip_channels = in_channels if (i == len(add_resnets) - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
if add_resnet:
self.resnets += [True]
modules.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if add_cross_attention:
modules.append(
FlexibleTransformer2DModel(
num_attention_heads,
out_channels // num_attention_heads,
in_channels=out_channels,
num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
)
)
if not mix_block_in_forward:
modules = sorted(modules, key=custom_sort_order)
self.modules_list = nn.ModuleList(modules)
self.upsamplers = None
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
for module in self.modules_list:
if isinstance(module, ResnetBlock2D):
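                # Pop the most recent skip connection produced by the matching down block.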
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = module(hidden_states, temb)
if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
hidden_states = module(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
class FlexibleUNetMidBlock2DCrossAttn(nn.Module):
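    """
    A mid block that always starts with a resnet, followed by a configurable,
    optionally interleaved mix of cross-attention transformers and resnets,
    with an optional trailing upsampler.
    """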
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_resnets: int = 1,
num_attentions: int = 1,
transformer_layers_per_block: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
num_attention_heads: int = 1,
output_scale_factor: float = 1.0,
cross_attention_dim: int = 1280,
use_linear_projection: bool = False,
upcast_attention: bool = False,
mix_block_in_forward: bool = True,
add_upsample: bool = True,
):
super().__init__()
self.has_cross_attention = True
self.num_attention_heads = num_attention_heads
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# There is always at least one resnet
modules = [ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)]
add_resnets = [True] * num_resnets
add_cross_attentions = [True] * num_attentions
for i, (add_resnet, add_cross_attention) in enumerate(
itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
if add_cross_attention:
modules.append(
FlexibleTransformer2DModel(
num_attention_heads,
in_channels // num_attention_heads,
in_channels=in_channels,
num_layers=transformer_layers_per_block,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
)
if add_resnet:
modules.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
if not mix_block_in_forward:
modules = sorted(modules, key=custom_sort_order)
self.modules_list = nn.ModuleList(modules)
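        # Default to a no-op "upsampler"; replaced with a real Upsample2D when add_upsample is set.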
self.upsamplers = nn.ModuleList([nn.Identity()])
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
        # The first module is always a resnet; apply it once, then run the remaining modules in order.
        hidden_states = self.modules_list[0](hidden_states, temb)
        for module in self.modules_list[1:]:
if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
hidden_states = module(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
elif isinstance(module, ResnetBlock2D):
hidden_states = module(hidden_states, temb)
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
return hidden_states
class FlexibleTransformer2DModel(ModelMixin, ConfigMixin):
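    """
    A trimmed-down variant of diffusers' Transformer2DModel that supports only
    continuous (image-like) inputs: GroupNorm + projection in, a stack of
    BasicTransformerBlocks, projection out, and a residual connection.
    """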
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
only_cross_attention: bool = False,
use_linear_projection: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.in_channels = in_channels
inner_dim = num_attention_heads * attention_head_dim
# Define input layers
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.use_linear_projection = use_linear_projection
if self.use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
# Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
)
for _ in range(num_layers)
]
)
# Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
if self.use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = False
):
# 1. Input
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
class DeciDiffusionPipeline(StableDiffusionPipeline):
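    """
    A StableDiffusionPipeline that swaps the standard UNet for Deci's
    FlexibleUNet2DConditionModel and applies DeciDiffusion's default inference
    parameters (30 steps, guidance_rescale=0.7) unless the caller overrides them.
    """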
deci_default_number_of_iterations = 30
deci_default_guidance_rescale = 0.7
def __init__(self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True
):
        # Replace the provided UNet with Deci's flexible UNet
del unet
unet = FlexibleUNet2DConditionModel()
super().__init__(vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
requires_safety_checker=requires_safety_checker
)
self.register_modules(vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor)
def __call__(self, *args, **kwargs):
        # Apply DeciDiffusion's default inference parameters unless the caller overrides them
        kwargs.setdefault('guidance_rescale', self.deci_default_guidance_rescale)
        kwargs.setdefault('num_inference_steps', self.deci_default_number_of_iterations)
return super().__call__(*args, **kwargs)
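# A minimal usage sketch, kept as a comment so this module stays import-safe when it is
# loaded as a custom pipeline. It assumes the weights are published on the Hugging Face
# Hub under 'Deci/DeciDiffusion-v1-0'; adjust the checkpoint id, dtype, and device to
# your environment:
#
#     import torch
#     from diffusers import StableDiffusionPipeline
#
#     checkpoint = 'Deci/DeciDiffusion-v1-0'
#     pipeline = StableDiffusionPipeline.from_pretrained(
#         checkpoint,
#         custom_pipeline=checkpoint,
#         torch_dtype=torch.float16,
#     ).to('cuda')
#     # Defaults (num_inference_steps=30, guidance_rescale=0.7) are applied in __call__.
#     image = pipeline(prompt='A photo of an astronaut riding a horse on Mars').images[0]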