# ObjCtrl-2.5D/cameractrl/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import logging
from diffusers.models.attention import Attention
from diffusers.utils import USE_PEFT_BACKEND, is_xformers_available
from typing import Optional, Callable
from einops import rearrange
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
logger = logging.getLogger(__name__)
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
"""
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
        pose_feature=None,  # unused here; accepted so all processors share one call signature
) -> torch.Tensor:
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
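# Usage sketch (illustrative, not part of the original file): a processor instance is attached to a
# diffusers `Attention` layer via `set_processor` and is then invoked by the layer's forward pass.
# The layer hyper-parameters below are assumptions chosen only for the example.
#
#     attn = Attention(query_dim=320, heads=8, dim_head=40)
#     attn.set_processor(AttnProcessor())
#     out = attn(hidden_states)  # hidden_states: (batch, seq_len, 320)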
class AttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or newer.")
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
        pose_feature=None,  # unused here; accepted so all processors share one call signature
) -> torch.FloatTensor:
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
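# Note: diffusers selects `AttnProcessor2_0` automatically when
# `torch.nn.functional.scaled_dot_product_attention` is available (PyTorch >= 2.0).
# A minimal sketch of setting it explicitly (hedged; `attn` is a diffusers `Attention` layer):
#
#     if hasattr(F, "scaled_dot_product_attention"):
#         attn.set_processor(AttnProcessor2_0())
#     else:
#         attn.set_processor(AttnProcessor())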
class XFormersAttnProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
Args:
attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase)
            to use as the attention operator. It is recommended to leave this as `None` and let xFormers
            choose the best available operator.
"""
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
        pose_feature=None,  # unused here; accepted so all processors share one call signature
) -> torch.FloatTensor:
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, key_tokens, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
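# Note: this processor is normally installed by calling
# `enable_xformers_memory_efficient_attention()` on a diffusers model or pipeline rather than being
# constructed by hand. A hedged sketch of explicit use:
#
#     if is_xformers_available():
#         attn.set_processor(XFormersAttnProcessor(attention_op=None))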
class PoseAdaptorAttnProcessor(nn.Module):
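    r"""
    Attention processor that injects a pose/camera feature into the attention inputs.

    Depending on `query_condition` / `key_value_condition`, the pose feature is added to the hidden
    states (query path), to the encoder hidden states (key/value path), or to both, and passed through
    a zero-initialized linear "merge" layer scaled by `scale`. Because the merge layers start at zero,
    the processor initially reproduces the base attention behavior.
    """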
def __init__(self,
hidden_size, # dimension of hidden state
pose_feature_dim=None, # dimension of the pose feature
cross_attention_dim=None, # dimension of the text embedding
query_condition=False,
key_value_condition=False,
scale=1.0):
super().__init__()
self.hidden_size = hidden_size
self.pose_feature_dim = pose_feature_dim
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.query_condition = query_condition
self.key_value_condition = key_value_condition
        assert hidden_size == pose_feature_dim, "pose_feature_dim must equal hidden_size"
if self.query_condition and self.key_value_condition:
self.qkv_merge = nn.Linear(hidden_size, hidden_size)
init.zeros_(self.qkv_merge.weight)
init.zeros_(self.qkv_merge.bias)
elif self.query_condition:
self.q_merge = nn.Linear(hidden_size, hidden_size)
init.zeros_(self.q_merge.weight)
init.zeros_(self.q_merge.bias)
else:
self.kv_merge = nn.Linear(hidden_size, hidden_size)
init.zeros_(self.kv_merge.weight)
init.zeros_(self.kv_merge.bias)
def forward(self,
attn,
hidden_states,
pose_feature,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
scale=None,):
assert pose_feature is not None
pose_embedding_scale = (scale or self.scale)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
assert hidden_states.ndim == 3 and pose_feature.ndim == 3
if self.query_condition and self.key_value_condition:
assert encoder_hidden_states is None
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
assert encoder_hidden_states.ndim == 3
batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        if self.query_condition and self.key_value_condition:  # self-attention only: pose conditions both query and key/value
query_hidden_state = self.qkv_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
key_value_hidden_state = query_hidden_state
elif self.query_condition:
query_hidden_state = self.q_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
key_value_hidden_state = encoder_hidden_states
else:
key_value_hidden_state = self.kv_merge(encoder_hidden_states + pose_feature) * pose_embedding_scale + encoder_hidden_states
query_hidden_state = hidden_states
# original attention
query = attn.to_q(query_hidden_state)
key = attn.to_k(key_value_hidden_state)
value = attn.to_v(key_value_hidden_state)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
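# Construction sketch for the pose-conditioned processor (the hyper-parameter values below are
# illustrative assumptions; `pose_feature_dim` must equal `hidden_size`, as asserted in __init__):
#
#     processor = PoseAdaptorAttnProcessor(
#         hidden_size=320,
#         pose_feature_dim=320,
#         query_condition=True,
#         key_value_condition=True,  # self-attention case: pose conditions both Q and K/V
#         scale=1.0,
#     )
#     out = processor(attn, hidden_states, pose_feature)  # pose_feature: (batch, seq_len, 320)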
class PoseAdaptorAttnProcessor2_0(nn.Module):
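    r"""
    PyTorch 2.0 variant of `PoseAdaptorAttnProcessor` that computes attention with
    `F.scaled_dot_product_attention` instead of explicit attention-score matrices. The pose-feature
    merging logic (zero-initialized linear layers on the query and/or key/value path) is identical.
    """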
def __init__(self,
hidden_size, # dimension of hidden state
pose_feature_dim=None, # dimension of the pose feature
cross_attention_dim=None, # dimension of the text embedding
query_condition=False,
key_value_condition=False,
scale=1.0):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("PoseAdaptorAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or newer.")
self.hidden_size = hidden_size
self.pose_feature_dim = pose_feature_dim
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.query_condition = query_condition
self.key_value_condition = key_value_condition
        assert hidden_size == pose_feature_dim, "pose_feature_dim must equal hidden_size"
if self.query_condition and self.key_value_condition:
self.qkv_merge = nn.Linear(hidden_size, hidden_size)
init.zeros_(self.qkv_merge.weight)
init.zeros_(self.qkv_merge.bias)
elif self.query_condition:
self.q_merge = nn.Linear(hidden_size, hidden_size)
init.zeros_(self.q_merge.weight)
init.zeros_(self.q_merge.bias)
else:
self.kv_merge = nn.Linear(hidden_size, hidden_size)
init.zeros_(self.kv_merge.weight)
init.zeros_(self.kv_merge.bias)
def forward(self,
attn,
hidden_states,
pose_feature,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
scale=None,):
assert pose_feature is not None
pose_embedding_scale = (scale or self.scale)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
assert hidden_states.ndim == 3 and pose_feature.ndim == 3
if self.query_condition and self.key_value_condition:
assert encoder_hidden_states is None
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
assert encoder_hidden_states.ndim == 3
batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        if self.query_condition and self.key_value_condition:  # self-attention only: pose conditions both query and key/value
query_hidden_state = self.qkv_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
key_value_hidden_state = query_hidden_state
elif self.query_condition:
query_hidden_state = self.q_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
key_value_hidden_state = encoder_hidden_states
else:
key_value_hidden_state = self.kv_merge(encoder_hidden_states + pose_feature) * pose_embedding_scale + encoder_hidden_states
query_hidden_state = hidden_states
# original attention
query = attn.to_q(query_hidden_state)
key = attn.to_k(key_value_hidden_state)
value = attn.to_v(key_value_hidden_state)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # [bs, seq_len, nhead, head_dim] -> [bs, nhead, seq_len, head_dim]
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False) # [bs, nhead, seq_len, head_dim]
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) # [bs, seq_len, dim]
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class PoseAdaptorXFormersAttnProcessor(nn.Module):
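    r"""
    xFormers variant of `PoseAdaptorAttnProcessor` that computes attention with
    `xformers.ops.memory_efficient_attention`. The pose-feature merging logic (zero-initialized
    linear layers on the query and/or key/value path) is identical; `attention_op` optionally
    selects a specific xFormers attention operator.
    """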
def __init__(self,
hidden_size, # dimension of hidden state
pose_feature_dim=None, # dimension of the pose feature
cross_attention_dim=None, # dimension of the text embedding
query_condition=False,
key_value_condition=False,
scale=1.0,
attention_op: Optional[Callable] = None):
super().__init__()
self.hidden_size = hidden_size
self.pose_feature_dim = pose_feature_dim
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.query_condition = query_condition
self.key_value_condition = key_value_condition
self.attention_op = attention_op
        assert hidden_size == pose_feature_dim, "pose_feature_dim must equal hidden_size"
if self.query_condition and self.key_value_condition:
self.qkv_merge = nn.Linear(hidden_size, hidden_size)
init.zeros_(self.qkv_merge.weight)
init.zeros_(self.qkv_merge.bias)
elif self.query_condition:
self.q_merge = nn.Linear(hidden_size, hidden_size)
init.zeros_(self.q_merge.weight)
init.zeros_(self.q_merge.bias)
else:
self.kv_merge = nn.Linear(hidden_size, hidden_size)
init.zeros_(self.kv_merge.weight)
init.zeros_(self.kv_merge.bias)
def forward(self,
attn,
hidden_states,
pose_feature,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
scale=None,):
assert pose_feature is not None
pose_embedding_scale = (scale or self.scale)
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
assert hidden_states.ndim == 3 and pose_feature.ndim == 3
if self.query_condition and self.key_value_condition:
assert encoder_hidden_states is None
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
assert encoder_hidden_states.ndim == 3
batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size)
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        if self.query_condition and self.key_value_condition:  # self-attention only: pose conditions both query and key/value
query_hidden_state = self.qkv_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
key_value_hidden_state = query_hidden_state
elif self.query_condition:
query_hidden_state = self.q_merge(hidden_states + pose_feature) * pose_embedding_scale + hidden_states
key_value_hidden_state = encoder_hidden_states
else:
key_value_hidden_state = self.kv_merge(encoder_hidden_states + pose_feature) * pose_embedding_scale + encoder_hidden_states
query_hidden_state = hidden_states
# original attention
query = attn.to_q(query_hidden_state)
key = attn.to_k(key_value_hidden_state)
value = attn.to_v(key_value_hidden_state)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
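# Registration sketch (assumption, not part of this file): pose-adaptor processors are typically
# installed on selected attention layers of a UNet through the standard diffusers processor
# dictionary, with `hidden_size` matched to each layer's channel dimension. The selection rule and
# the `dim_of(name)` helper below are hypothetical and shown only to illustrate the wiring.
#
#     procs = {}
#     for name, default_proc in unet.attn_processors.items():
#         if "motion_modules" in name:  # hypothetical rule for which layers receive pose conditioning
#             procs[name] = PoseAdaptorAttnProcessor(
#                 hidden_size=dim_of(name), pose_feature_dim=dim_of(name),
#                 query_condition=True, key_value_condition=True,
#             )
#         else:
#             procs[name] = default_proc
#     unet.set_attn_processor(procs)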