Spaces:
Runtime error
Runtime error
from einops import repeat, rearrange | |
from typing import Callable, Optional, Union | |
from t2v_enhanced.model.diffusers_conditional.models.controlnet.attention_processor import Attention | |
# from t2v_enhanced.model.diffusers_conditional.controldiffusers.models.attention import Attention | |
from diffusers.utils.import_utils import is_xformers_available | |
from t2v_enhanced.model.pl_module_params_controlnet import AttentionMaskParams | |
import torch | |
import torch.nn.functional as F | |
if is_xformers_available(): | |
import xformers | |
import xformers.ops | |
else: | |
xformers = None | |
def set_use_memory_efficient_attention_xformers(
    model, num_frame_conditioning: int, num_frames: int, attention_mask_params: AttentionMaskParams, valid: bool = True, attention_op: Optional[Callable] = None
) -> None:
    """Install a new ``XFormersAttnProcessor`` on every descendant of ``model`` that has ``set_processor``.

    The traversal is depth-first and pre-order over ``model.children()`` and
    all of their descendants. ``model`` itself is not visited. Each matching
    module receives its own freshly constructed processor.

    Args:
        model: Root whose sub-modules are configured (the root is skipped).
        num_frame_conditioning: Forwarded to ``XFormersAttnProcessor``.
        num_frames: Forwarded to ``XFormersAttnProcessor``.
        attention_mask_params: Forwarded to ``XFormersAttnProcessor``.
        valid: Accepted for signature compatibility but never read; passing
            ``False`` does NOT disable the xformers processor here.
        attention_op: Optional xformers operator forwarded to each processor.
    """
    # Explicit stack instead of recursion. Siblings are pushed in reverse so
    # they are popped, and therefore configured, in their natural order.
    pending = [child for child in model.children() if isinstance(child, torch.nn.Module)]
    pending.reverse()
    while pending:
        current = pending.pop()
        if hasattr(current, "set_processor"):
            current.set_processor(
                XFormersAttnProcessor(
                    attention_op=attention_op,
                    num_frame_conditioning=num_frame_conditioning,
                    num_frames=num_frames,
                    attention_mask_params=attention_mask_params,
                )
            )
        # Children are listed only after set_processor runs, matching the
        # original recursive walk.
        pending.extend(reversed(list(current.children())))
class XFormersAttnProcessor:
    """Attention processor built on ``xformers.ops.memory_efficient_attention``.

    Besides standard attention, it implements two variants that are aware of
    conditioning frames. The variant is selected per call from
    ``attention_mask_params`` and from attributes of the ``Attention`` module:

    * temporal attention with ``temp_attend_on_neighborhood_of_condition_frames``
      (see ``_temporal_neighborhood_attention``);
    * spatial attention with ``spatial_attend_on_condition_frames``
      (see ``_spatial_condition_attention``).
    """

    def __init__(self,
                 attention_mask_params: AttentionMaskParams,
                 attention_op: Optional[Callable] = None,
                 num_frame_conditioning: Optional[int] = None,
                 num_frames: Optional[int] = None,
                 use_image_embedding: bool = False,
                 ):
        """
        Args:
            attention_mask_params: Provides the two flags
                ``temp_attend_on_neighborhood_of_condition_frames`` and
                ``spatial_attend_on_condition_frames``, which are copied onto
                ``self``.
            attention_op: Optional xformers operator, forwarded as ``op=`` to
                ``memory_efficient_attention``.
            num_frame_conditioning: Number of leading frames treated as
                conditioning frames.
            num_frames: Total frame count. Used to unflatten the ``(B F)``
                batch axis in the spatial path.
            use_image_embedding: Stored on the instance but not read by
                ``__call__``, which consults ``attn.use_image_embedding``
                instead.
        """
        self.attention_op = attention_op
        self.num_frame_conditioning = num_frame_conditioning
        self.num_frames = num_frames
        self.temp_attend_on_neighborhood_of_condition_frames = attention_mask_params.temp_attend_on_neighborhood_of_condition_frames
        self.spatial_attend_on_condition_frames = attention_mask_params.spatial_attend_on_condition_frames
        self.use_image_embedding = use_image_embedding

    def __call__(self, attn: Attention, hidden_states, hidden_state_height=None, hidden_state_width=None, encoder_hidden_states=None, attention_mask=None):
        """Apply ``attn`` to ``hidden_states`` using xformers.

        Args:
            attn: Attention module supplying ``to_q``/``to_k``/``to_v``/``to_out``,
                ``head_to_batch_dim``/``batch_to_head_dim`` and ``scale``.
                The optional attributes ``is_spatial_attention``,
                ``use_image_embedding`` and ``cross_attention_mode`` select
                the special paths.
            hidden_states: 3-D query-side input (unpacked as batch, sequence,
                channels).
            hidden_state_height, hidden_state_width: Spatial size. Only the
                temporal neighbourhood path uses them, and it requires them.
            encoder_hidden_states: Optional key/value input. Defaults to
                ``hidden_states``.
            attention_mask: Optional 3-D mask with leading dimension 1. It is
                repeated over the batch and passed through
                ``attn.prepare_attention_mask``.

        Returns:
            The attention output after ``attn.to_out[0]`` (linear projection)
            and ``attn.to_out[1]`` (dropout).
        """
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            # The mask is shared by the whole batch: broadcast it before the
            # Attention module reshapes it for its heads.
            attention_mask = repeat(attention_mask, "1 F D -> B F D", B=batch_size)
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # Modules without ``is_spatial_attention`` are plain attention layers
        # and cannot honour the condition-frame masking flags.
        default_attention = not hasattr(attn, "is_spatial_attention")
        if default_attention:
            assert not self.temp_attend_on_neighborhood_of_condition_frames, "special attention must be implemented with new interface"
            assert not self.spatial_attend_on_condition_frames, "special attention must be implemented with new interface"

        is_spatial_attention = getattr(attn, "is_spatial_attention", False)
        use_image_embedding = getattr(attn, "use_image_embedding", False)

        if is_spatial_attention and use_image_embedding and attn.cross_attention_mode:
            assert not self.spatial_attend_on_condition_frames, "Not implemented together with image embedding"
            encoder_hidden_states = self._mix_image_embedding(attn, encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if (not default_attention and not is_spatial_attention
                and self.temp_attend_on_neighborhood_of_condition_frames
                and not attn.cross_attention_mode):
            hidden_states = self._temporal_neighborhood_attention(
                attn, query, key, value, hidden_state_height, hidden_state_width, attention_mask)
        elif (not default_attention and is_spatial_attention
                and self.spatial_attend_on_condition_frames
                and not attn.cross_attention_mode):
            hidden_states = self._spatial_condition_attention(attn, query, key, value)
        else:
            hidden_states = self._attend(attn, query, key, value, attention_mask)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states

    @staticmethod
    def _mix_image_embedding(attn, encoder_hidden_states):
        """Add a gated, conv-mixed version of all encoder tokens to the first 77 tokens.

        The first 77 tokens form the text part. This is presumably CLIP's
        77-token context; confirm. ``attn.conv`` followed by ``attn.conv_ln``
        is applied to the full sequence. The result is scaled by
        ``silu(attn.alpha)`` and added to the text part.
        """
        alpha = attn.alpha
        text_states = encoder_hidden_states[:, :77, :]
        mixed = attn.conv(encoder_hidden_states)
        mixed = attn.conv_ln(mixed)
        return text_states + mixed * F.silu(alpha)

    def _attend(self, attn, query, key, value, attn_bias=None):
        """Run xformers memory-efficient attention on projected 3-D q/k/v.

        The inputs are reshaped with ``attn.head_to_batch_dim``. Attention uses
        ``attn.scale`` and ``self.attention_op``. The output is cast back to
        the query dtype and reshaped with ``attn.batch_to_head_dim``.
        """
        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()
        out = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attn_bias, op=self.attention_op, scale=attn.scale
        )
        out = out.to(query.dtype)
        return attn.batch_to_head_dim(out)

    def _temporal_neighborhood_attention(self, attn, query, key, value, height, width, attention_mask):
        """Temporal attention in which non-conditioning frames only see conditioning frames.

        The rearrange patterns below require ``key``/``value`` laid out as
        ``(B W H) F C``: spatial positions folded into the batch axis and
        frames on axis 1. ``height`` and ``width`` must be provided.

        * Queries of the first ``num_frame_conditioning`` frames attend to the
          keys/values of all frames, without a mask.
        * Queries of the remaining frames attend only to the conditioning
          frames' keys/values. These are gathered from a 3x3 spatial
          neighbourhood via ``torch.roll`` by -1/0/+1 along W and H, so
          borders wrap around. ``attention_mask`` is used as the bias.

        NOTE(review): ``attention_mask`` was prepared for the full key length,
        but here the keys have ``9 * num_frame_conditioning`` entries; a
        non-None mask probably does not match. Confirm it is always None on
        this path.

        NOTE(review): the layout is assumed to be ``(B W H)``. If callers
        flatten as ``(B H W)`` and height != width, neighbours will be wrong.
        Confirm.
        """
        nfc = self.num_frame_conditioning

        # Conditioning frames: ordinary attention over every frame.
        hidden_states_condition = self._attend(attn, query[:, :nfc], key, value)

        # Remaining frames: attend to the conditioning frames' 3x3 spatial neighbourhood.
        query_uncondition = query[:, nfc:]
        key = rearrange(key[:, :nfc], "(B W H) F C -> B W H F C", H=height, W=width)
        value = rearrange(value[:, :nfc], "(B W H) F C -> B W H F C", H=height, W=width)
        shifts = [(shift_w, shift_h) for shift_w in (-1, 0, 1) for shift_h in (-1, 0, 1)]
        keys = [torch.roll(key, shifts=shift, dims=(1, 2)) for shift in shifts]
        values = [torch.roll(value, shifts=shift, dims=(1, 2)) for shift in shifts]
        key = rearrange(torch.cat(keys, dim=3), "B W H F C -> (B W H) F C")
        value = rearrange(torch.cat(values, dim=3), "B W H F C -> (B W H) F C")
        hidden_states_uncondition = self._attend(attn, query_uncondition, key, value, attention_mask)

        return torch.cat([hidden_states_condition, hidden_states_uncondition], dim=1)

    def _spatial_condition_attention(self, attn, query, key, value):
        """Spatial attention in which non-conditioning frames look only at the last conditioning frame.

        ``query``/``key``/``value`` are laid out as ``(B F) S C`` with
        ``F = num_frames``.

        * Each of the first ``num_frame_conditioning`` frames attends to its
          own spatial tokens.
        * Each remaining frame attends only to the spatial tokens of frame
          ``num_frame_conditioning - 1``, not to its own tokens.

        No attention mask is applied on this path.

        NOTE(review): with ``num_frame_conditioning == 0`` the anchor index is
        -1, i.e. the last frame. Confirm callers never pass 0.
        """
        nf = self.num_frames
        nfc = self.num_frame_conditioning
        n_uncond = nf - nfc

        def per_frame(t):
            return rearrange(t, "(B F) S C -> B F S C", F=nf)

        def flatten(t):
            return rearrange(t, "B F S C -> (B F) S C")

        query_frames = per_frame(query)
        key_frames = per_frame(key)
        value_frames = per_frame(value)

        hidden_states_condition = self._attend(
            attn,
            flatten(query_frames[:, :nfc]),
            flatten(key_frames[:, :nfc]),
            flatten(value_frames[:, :nfc]),
        )

        # Tokens of the last conditioning frame, shared by every remaining frame.
        anchor_key = rearrange(key_frames[:, nfc - 1, None], "B F S C -> B (F S) C")
        anchor_value = rearrange(value_frames[:, nfc - 1, None], "B F S C -> B (F S) C")
        hidden_states_uncondition = self._attend(
            attn,
            flatten(query_frames[:, nfc:]),
            repeat(anchor_key, "B T C -> (B F) T C", F=n_uncond),
            repeat(anchor_value, "B T C -> (B F) T C", F=n_uncond),
        )

        hidden_states = torch.cat(
            [
                rearrange(hidden_states_condition, "(B F) S C -> B F S C", F=nfc),
                rearrange(hidden_states_uncondition, "(B F) S C -> B F S C", F=n_uncond),
            ],
            dim=1,
        )
        return flatten(hidden_states)