import math

import numpy as np
import torch
import xformers
import xformers.ops


class DummyController:
    # Fallback controller: passes attention through unchanged and only records
    # how many attention layers were patched.

    def __call__(self, *args):
        return args[0]

    def __init__(self):
        self.num_att_layers = 0


class GroupedCAController:
    # Runs cross-attention once per (mask, prompt) group and sums the spatially
    # masked results, so each group only writes into its own image region.

    def __init__(self, mask_list=None):
        self.mask_list = mask_list
        # Grouped (decomposed) cross-attention is only active when masks are given.
        self.is_decom = self.mask_list is not None

    def mask_img_to_mask_vec(self, mask, length):
        # Downsample the mask to the attention resolution (length x length) and
        # flatten it to one weight per query position.
        mask_vec = torch.nn.functional.interpolate(
            mask.unsqueeze(0).unsqueeze(0), (length, length)
        ).squeeze()
        mask_vec = mask_vec.flatten()
        return mask_vec

    def ca_forward_decom(self, q, k_list, v_list, scale, place_in_unet):
        # q has shape (batch * heads, N, head_dim); the latent grid is assumed
        # square, so each group's mask is resized to sqrt(N) x sqrt(N).
        N = q.shape[1]
        mask_vec_list = []
        for mask in self.mask_list:
            mask_vec = self.mask_img_to_mask_vec(mask, int(math.sqrt(N)))
            mask_vec = mask_vec.unsqueeze(0).unsqueeze(-1)  # (1, N, 1), broadcasts over keys
            mask_vec_list.append(mask_vec)
        out = 0
        for mask_vec, k, v in zip(mask_vec_list, k_list, v_list):
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * scale
            attn = sim.softmax(dim=-1)
            # Zero the attention of query positions that fall outside this group's mask.
            attn = attn.masked_fill(mask_vec == 0, 0)
            masked_out = torch.einsum("b i j, b j d -> b i d", attn, v)
            out += masked_out
        return out
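
    # Shape walk-through for the grouped cross-attention above (illustrative
    # comment only; the concrete numbers assume an SD-style 16x16 latent and a
    # 77-token prompt, neither of which is fixed anywhere in this file):
    #   q:        (batch*heads, 256, head_dim)   256 = 16*16 query positions
    #   mask_vec: (1, 256, 1)                    one binary weight per position
    #   k, v:     (batch*heads, 77, head_dim)    one pair per prompt group
    #   out:      (batch*heads, 256, head_dim)   sum of the per-group results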

    def reshape_heads_to_batch_dim(self):
        # Returns a helper that splits (B, L, H*D) into (B*H, L, D).
        # self.num_heads is expected to be set on the controller by its user.
        def func(tensor):
            batch_size, seq_len, dim = tensor.shape
            head_size = self.num_heads
            tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
            tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
            return tensor
        return func

    def reshape_batch_dim_to_heads(self):
        # Returns the inverse helper: merges (B*H, L, D) back into (B, L, H*D).
        def func(tensor):
            batch_size, seq_len, dim = tensor.shape
            head_size = self.num_heads
            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
            tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
            return tensor
        return func


def register_attention_disentangled_control(unet, controller):
    def ca_forward(self, place_in_unet):
        # diffusers keeps the output projection (and its dropout) in a ModuleList;
        # only the linear projection is applied here.
        to_out = self.to_out
        if isinstance(to_out, torch.nn.ModuleList):
            to_out = self.to_out[0]

        def forward(x, encoder_hidden_states=None, attention_mask=None):
            if isinstance(controller, DummyController):
                # No decomposition: run plain memory-efficient attention.
                q = self.to_q(x)
                is_cross = encoder_hidden_states is not None
                encoder_hidden_states = encoder_hidden_states if is_cross else x
                k = self.to_k(encoder_hidden_states)
                v = self.to_v(encoder_hidden_states)
                q = self.head_to_batch_dim(q)
                k = self.head_to_batch_dim(k)
                v = self.head_to_batch_dim(v)
                out = xformers.ops.memory_efficient_attention(
                    q, k, v, attn_bias=None, op=None, scale=self.scale
                )
                out = self.batch_to_head_dim(out)
            else:
                is_cross = encoder_hidden_states is not None
                encoder_hidden_states_list = encoder_hidden_states if is_cross else x
                q = self.to_q(x)
                q = self.head_to_batch_dim(q)
                if is_cross:
                    # Decomposed cross-attention: one (k, v) pair per prompt group,
                    # combined by the controller under its region masks.
                    k_list = []
                    v_list = []
                    assert isinstance(encoder_hidden_states_list, list)
                    for encoder_hidden_states in encoder_hidden_states_list:
                        k = self.to_k(encoder_hidden_states)
                        k = self.head_to_batch_dim(k)
                        k_list.append(k)
                        v = self.to_v(encoder_hidden_states)
                        v = self.head_to_batch_dim(v)
                        v_list.append(v)
                    out = controller.ca_forward_decom(q, k_list, v_list, self.scale, place_in_unet)
                    out = self.batch_to_head_dim(out)
                else:
                    # Self-attention decomposition is not handled by this controller;
                    # the exit() guard stops execution here, so the rest of this
                    # branch is unreachable as written.
                    exit("decomposing SA!")
                    k = self.to_k(x)
                    v = self.to_v(x)
                    k = self.head_to_batch_dim(k)
                    v = self.head_to_batch_dim(v)
                    if k.shape[1] <= 1024 ** 2:
                        out = controller.sa_forward(q, k, v, self.scale, place_in_unet)
                    else:
                        print("warning: key sequence longer than 1024**2, using decomposed self-attention")
                        out = controller.sa_forward_decom(q, k, v, self.scale, place_in_unet)
                    out = self.batch_to_head_dim(out)

            return to_out(out)

        return forward

    if controller is None:
        controller = DummyController()

    def register_recr(net_, count, place_in_unet):
        # Patch every cross-attention module, identified by the in_features of its
        # key projection matching unet.ca_dim (set by the caller).
        if net_.__class__.__name__ == 'Attention' and net_.to_k.in_features == unet.ca_dim:
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
            return count
        return count

    cross_att_count = 0
    sub_nets = unet.named_children()

    for net in sub_nets:
        if "down" in net[0]:
            down_count = register_recr(net[1], 0, "down")
            cross_att_count += down_count
        elif "up" in net[0]:
            up_count = register_recr(net[1], 0, "up")
            cross_att_count += up_count
        elif "mid" in net[0]:
            mid_count = register_recr(net[1], 0, "mid")
            cross_att_count += mid_count
    controller.num_att_layers = cross_att_count
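

if __name__ == "__main__":
    # Minimal, self-contained shape check for GroupedCAController (a sketch, not
    # part of the original pipeline): the 64x64 masks, the 16x16 latent size, the
    # 77-token prompt length and the single attention head are illustrative
    # assumptions, not values fixed by this module.
    torch.manual_seed(0)
    mask_a = torch.zeros(64, 64)
    mask_a[:, :32] = 1.0   # left half of the image
    mask_b = 1.0 - mask_a  # right half of the image
    controller = GroupedCAController(mask_list=[mask_a, mask_b])

    heads, n_query, n_tokens, head_dim = 1, 16 * 16, 77, 40
    q = torch.randn(heads, n_query, head_dim)
    k_list = [torch.randn(heads, n_tokens, head_dim) for _ in range(2)]
    v_list = [torch.randn(heads, n_tokens, head_dim) for _ in range(2)]
    out = controller.ca_forward_decom(q, k_list, v_list, scale=head_dim ** -0.5, place_in_unet="down")
    print(out.shape)  # expected: torch.Size([1, 256, 40])

    # Hooking the controller into a real UNet would look roughly like the two
    # commented lines below; `unet` and its `ca_dim` attribute must be provided
    # by the surrounding pipeline and are not defined in this file.
    # unet.ca_dim = 768  # e.g. the cross-attention text width for SD 1.x
    # register_attention_disentangled_control(unet, controller)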