MV-Adapter-T2MV-SDXL / mvadapter /models /attention_processor.py
huanngzh's picture
init
6ef620e
import math
from typing import Callable, List, Optional, Union
import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention
from diffusers.models.unets import UNet2DConditionModel
from diffusers.utils import deprecate, logging
from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
from einops import rearrange
from torch import nn
def default_set_attn_proc_func(
    name: str,
    hidden_size: int,
    cross_attention_dim: Optional[int],
    ori_attn_proc: object,
) -> object:
    """Identity processor factory: keep the processor that is already installed.

    Serves as the default for the ``set_*_attn_proc_func`` callbacks, so an
    attention layer whose callback was not overridden keeps its processor.

    Args:
        name: Name of the attention processor (unused).
        hidden_size: Channel width of the owning block (unused).
        cross_attention_dim: Cross-attention width, or ``None`` (unused).
        ori_attn_proc: The processor currently attached to the layer.

    Returns:
        ``ori_attn_proc``, unchanged.
    """
    unchanged_processor = ori_attn_proc
    return unchanged_processor
def set_unet_2d_condition_attn_processor(
    unet: UNet2DConditionModel,
    set_self_attn_proc_func: Callable = default_set_attn_proc_func,
    set_cross_attn_proc_func: Callable = default_set_attn_proc_func,
    set_custom_attn_proc_func: Callable = default_set_attn_proc_func,
    set_self_attn_module_names: Optional[List[str]] = None,
    set_cross_attn_module_names: Optional[List[str]] = None,
    set_custom_attn_module_names: Optional[List[str]] = None,
) -> None:
    """Rebuild and install the attention processors of ``unet``.

    Every entry of ``unet.attn_processors`` is classified as custom,
    self-attention or cross-attention, and the matching factory callback is
    asked for a replacement. The resulting mapping is installed with
    ``unet.set_attn_processor``.

    Each callback is called as
    ``func(name, hidden_size, cross_attention_dim, ori_attn_proc)`` and must
    return the processor to install for that layer.

    Args:
        unet: Model whose attention processors are replaced in place.
        set_self_attn_proc_func: Factory for self-attention layers (names
            ending in ``"attn1.processor"``, or containing
            ``"motion_modules"``).
        set_cross_attn_proc_func: Factory for the remaining (cross-attention)
            layers.
        set_custom_attn_proc_func: Factory for layers whose name contains
            ``"attn_mid_blocks"`` or ``"attn_post_blocks"``; it always
            receives ``cross_attention_dim=None``.
        set_self_attn_module_names: Name prefixes selecting which
            self-attention layers are replaced; ``None`` selects all.
        set_cross_attn_module_names: Same, for cross-attention layers.
        set_custom_attn_module_names: Same, for custom layers.

    Raises:
        ValueError: If a processor name does not start with ``"mid_block"``,
            ``"up_blocks"`` or ``"down_blocks"``, so its hidden size cannot
            be derived.
    """

    def should_replace(name: str, module_names: Optional[List[str]]) -> bool:
        # ``None`` means "no filter"; otherwise prefix match against any entry.
        if module_names is None:
            return True
        return name.startswith(tuple(module_names))

    def block_index(name: str, prefix: str) -> int:
        # Parse the whole number after the prefix, not just a single digit,
        # so block indices >= 10 are handled.
        return int(name[len(prefix):].split(".", 1)[0])

    block_out_channels = unet.config.block_out_channels
    reversed_out_channels = list(reversed(block_out_channels))

    attn_procs = {}
    for name, attn_processor in unet.attn_processors.items():
        if name.startswith("mid_block"):
            hidden_size = block_out_channels[-1]
        elif name.startswith("up_blocks"):
            # Up blocks walk the channel list in reverse order.
            hidden_size = reversed_out_channels[block_index(name, "up_blocks.")]
        elif name.startswith("down_blocks"):
            hidden_size = block_out_channels[block_index(name, "down_blocks.")]
        else:
            # Fail loudly instead of reusing the previous layer's hidden size
            # (or hitting an unbound local on the first iteration).
            raise ValueError(
                f"Cannot infer hidden size for attention processor '{name}': "
                "expected a name starting with 'mid_block', 'up_blocks' or 'down_blocks'."
            )

        if "attn_mid_blocks" in name or "attn_post_blocks" in name:
            # custom attention
            cross_attention_dim = None
            proc_func = set_custom_attn_proc_func
            enabled = should_replace(name, set_custom_attn_module_names)
        else:
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else unet.config.cross_attention_dim
            )
            if cross_attention_dim is None or "motion_modules" in name:
                # self attention
                proc_func = set_self_attn_proc_func
                enabled = should_replace(name, set_self_attn_module_names)
            else:
                # cross attention
                proc_func = set_cross_attn_proc_func
                enabled = should_replace(name, set_cross_attn_module_names)

        attn_procs[name] = (
            proc_func(name, hidden_size, cross_attention_dim, attn_processor)
            if enabled
            else attn_processor
        )

    unet.set_attn_processor(attn_procs)
class DecoupledMVRowSelfAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for Decoupled Row-wise Self-Attention and Image Cross-Attention for PyTorch 2.0.

    On top of the regular attention computed with the projections of the
    wrapped ``Attention`` module, this processor owns up to two extra sets of
    projections whose outputs are added to the regular attention output:

    * ``*_mv``: multi-view row-wise self-attention, where each row of a view
      attends to the same row of all ``num_views`` views;
    * ``*_ref``: attention against cached reference hidden states.
    """

    def __init__(
        self,
        query_dim: int,
        inner_dim: int,
        num_views: int = 1,
        name: Optional[str] = None,
        use_mv: bool = True,
        use_ref: bool = False,
    ):
        """
        Args:
            query_dim: Feature size of the incoming hidden states.
            inner_dim: Output size of the q/k/v projections.
            num_views: Number of views grouped together along the batch axis.
            name: Key used to read/write per-layer entries in the
                ``ref_hidden_states`` / ``cache_hidden_states`` mappings.
            use_mv: Whether to create the multi-view projections.
            use_ref: Whether to create the reference projections.

        Raises:
            ImportError: If ``F.scaled_dot_product_attention`` is unavailable.
        """
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "DecoupledMVRowSelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

        super().__init__()

        self.num_views = num_views
        self.name = name  # NOTE: need for image cross-attention
        self.use_mv = use_mv
        self.use_ref = use_ref

        if self.use_mv:
            self.to_q_mv = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_k_mv = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_v_mv = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_out_mv = nn.ModuleList(
                [
                    nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
                    nn.Dropout(0.0),
                ]
            )

        if self.use_ref:
            self.to_q_ref = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_k_ref = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_v_ref = nn.Linear(
                in_features=query_dim, out_features=inner_dim, bias=False
            )
            self.to_out_ref = nn.ModuleList(
                [
                    nn.Linear(in_features=inner_dim, out_features=query_dim, bias=True),
                    nn.Dropout(0.0),
                ]
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        mv_scale: float = 1.0,
        ref_hidden_states: Optional[dict] = None,
        ref_scale: float = 1.0,
        cache_hidden_states: Optional[dict] = None,
        use_mv: bool = True,
        use_ref: bool = True,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        Run regular attention plus the optional multi-view and reference branches.

        New args:
            mv_scale (float): scale for multi-view self-attention.
            ref_hidden_states (dict): mapping from processor name to the reference
                hidden states for image cross-attention; indexed with ``self.name``.
            ref_scale (float): scale for image cross-attention.
            cache_hidden_states (dict): mapping filled in place with a clone of the
                incoming hidden states under ``self.name`` (cache from reference unet).
            use_mv (bool): per-call switch; the branch runs only if the processor
                was also built with ``use_mv=True``.
            use_ref (bool): per-call switch; the branch runs only if the processor
                was also built with ``use_ref=True``.

        Returns:
            Tensor with the same layout as the input ``hidden_states``.

        Raises:
            ValueError: If the reference branch is active but ``ref_hidden_states``
                is ``None``, or if the multi-view branch is active and the token
                count does not match the spatial grid.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)

        # NEW: cache hidden states for reference unet
        if cache_hidden_states is not None:
            cache_hidden_states[self.name] = hidden_states.clone()

        # NEW: whether to use multi-view attention and image cross-attention
        use_mv = self.use_mv and use_mv
        use_ref = self.use_ref and use_ref

        if use_ref and ref_hidden_states is None:
            # Fail with a clear message instead of "'NoneType' is not subscriptable".
            raise ValueError(
                "`ref_hidden_states` must be provided when reference attention is enabled "
                "(pass `use_ref=False` to skip it)."
            )

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten (batch, channel, height, width) to (batch, height * width, channel).
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )

        query = attn.to_q(hidden_states)

        # NEW: for decoupled multi-view attention
        if use_mv:
            query_mv = self.to_q_mv(hidden_states)

        # NEW: for decoupled reference cross attention
        if use_ref:
            query_ref = self.to_q_ref(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states
            )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        ####### Decoupled multi-view self-attention ########
        if use_mv:
            if input_ndim == 4:
                # Use the real spatial size; kept in separate variables so the
                # final 4D reshape below is not affected.
                grid_height, grid_width = height, width
            else:
                # A flattened input carries no spatial size: assume a square grid.
                grid_height = grid_width = math.isqrt(sequence_length)
            if grid_height * grid_width != sequence_length:
                raise ValueError(
                    f"Multi-view row attention cannot arrange {sequence_length} tokens on a "
                    f"{grid_height}x{grid_width} grid; flattened inputs must come from square feature maps."
                )
            hidden_states_mv = self._multi_view_row_attention(
                attn,
                query_mv,
                encoder_hidden_states,
                batch_size,
                head_dim,
                grid_height,
                grid_width,
                query.dtype,
            )

        if use_ref:
            hidden_states_ref = self._reference_attention(
                attn,
                query_ref,
                ref_hidden_states[self.name],
                batch_size,
                head_dim,
                query.dtype,
            )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if use_mv:
            hidden_states = hidden_states + hidden_states_mv * mv_scale

        if use_ref:
            hidden_states = hidden_states + hidden_states_ref * ref_scale

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def _rows_across_views(
        self,
        tokens: torch.Tensor,
        batch_size: int,
        heads: int,
        head_dim: int,
        grid_height: int,
        grid_width: int,
    ) -> torch.Tensor:
        """Lay out key/value tokens so each row holds that row from every view."""
        tokens = rearrange(
            tokens,
            "(b nv) (ih iw) h c -> b ih (nv iw) h c",
            nv=self.num_views,
            ih=grid_height,
            iw=grid_width,
        )
        # Every view of a group attends to the same concatenated rows, so the
        # grouped rows are repeated once per view.
        return (
            tokens.repeat_interleave(self.num_views, dim=0)
            .view(batch_size * grid_height, -1, heads, head_dim)
            .transpose(1, 2)
        )

    def _multi_view_row_attention(
        self,
        attn: Attention,
        query_mv: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        batch_size: int,
        head_dim: int,
        grid_height: int,
        grid_width: int,
        out_dtype: torch.dtype,
    ) -> torch.Tensor:
        """Row-wise self-attention across views, followed by ``to_out_mv``."""
        key_mv = self.to_k_mv(encoder_hidden_states)
        value_mv = self.to_v_mv(encoder_hidden_states)

        query_mv = query_mv.view(batch_size, -1, attn.heads, head_dim)
        key_mv = key_mv.view(batch_size, -1, attn.heads, head_dim)
        value_mv = value_mv.view(batch_size, -1, attn.heads, head_dim)

        # row self-attention: each image row becomes its own attention problem
        query_mv = rearrange(
            query_mv,
            "(b nv) (ih iw) h c -> (b nv ih) iw h c",
            nv=self.num_views,
            ih=grid_height,
            iw=grid_width,
        ).transpose(1, 2)
        key_mv = self._rows_across_views(
            key_mv, batch_size, attn.heads, head_dim, grid_height, grid_width
        )
        value_mv = self._rows_across_views(
            value_mv, batch_size, attn.heads, head_dim, grid_height, grid_width
        )

        hidden_states_mv = F.scaled_dot_product_attention(
            query_mv,
            key_mv,
            value_mv,
            dropout_p=0.0,
            is_causal=False,
        )
        hidden_states_mv = rearrange(
            hidden_states_mv,
            "(b nv ih) h iw c -> (b nv) (ih iw) (h c)",
            nv=self.num_views,
            ih=grid_height,
        )
        hidden_states_mv = hidden_states_mv.to(out_dtype)

        # linear proj
        hidden_states_mv = self.to_out_mv[0](hidden_states_mv)
        # dropout
        hidden_states_mv = self.to_out_mv[1](hidden_states_mv)
        return hidden_states_mv

    def _reference_attention(
        self,
        attn: Attention,
        query_ref: torch.Tensor,
        reference_hidden_states: torch.Tensor,
        batch_size: int,
        head_dim: int,
        out_dtype: torch.dtype,
    ) -> torch.Tensor:
        """Attention from the current tokens to the reference tokens, followed by ``to_out_ref``."""
        key_ref = self.to_k_ref(reference_hidden_states)
        value_ref = self.to_v_ref(reference_hidden_states)

        query_ref = query_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
            1, 2
        )
        key_ref = key_ref.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value_ref = value_ref.view(batch_size, -1, attn.heads, head_dim).transpose(
            1, 2
        )

        hidden_states_ref = F.scaled_dot_product_attention(
            query_ref, key_ref, value_ref, dropout_p=0.0, is_causal=False
        )

        hidden_states_ref = hidden_states_ref.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states_ref = hidden_states_ref.to(out_dtype)

        # linear proj
        hidden_states_ref = self.to_out_ref[0](hidden_states_ref)
        # dropout
        hidden_states_ref = self.to_out_ref[1](hidden_states_ref)
        return hidden_states_ref

    def set_num_views(self, num_views: int) -> None:
        """Update the number of views grouped along the batch axis."""
        self.num_views = num_views