# wooyeolbaek's picture
# Add app.py, utils.py
# 0c1540a
import os
import math
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from diffusers.utils import deprecate
from diffusers.models.attention_processor import (
Attention,
AttnProcessor,
AttnProcessor2_0,
LoRAAttnProcessor,
LoRAAttnProcessor2_0
)
# Module-level registry of captured attention maps. The forward hooks built by
# `hook_fn` store `module.processor.attn_map` here under the hooked module's name,
# and `get_net_attn_map` reads the stored values back.
attn_maps = {}
def attn_call(
    self,
    attn: Attention,
    hidden_states,
    encoder_hidden_states=None,
    attention_mask=None,
    temb=None,
    scale=1.0,
):
    """Run one attention layer with explicit ``bmm`` and optionally keep its probabilities.

    Meant to be installed as the ``__call__`` of an attention processor
    (``cross_attn_init`` does exactly that). If the processor instance carries
    a ``store_attn_map`` attribute, the attention probabilities are saved on
    ``self.attn_map`` before they are applied to the values.

    Args:
        attn: Attention module providing the projections, norms and helpers.
        hidden_states: Query-side input; a 4-D input is flattened to a sequence
            and restored to 4-D at the end.
        encoder_hidden_states: Key/value-side input; ``hidden_states`` is used
            when this is ``None``.
        attention_mask: Optional mask, passed through ``attn.prepare_attention_mask``.
        temb: Forwarded to ``attn.spatial_norm`` when that norm is configured.
        scale: Forwarded to the q/k/v and output projections.

    Returns:
        The attended hidden states, with the residual added if
        ``attn.residual_connection`` is set, divided by ``attn.rescale_output_factor``.
    """
    skip = hidden_states  # kept for the optional residual connection

    if attn.spatial_norm is not None:
        hidden_states = attn.spatial_norm(hidden_states, temb)

    input_ndim = hidden_states.ndim
    was_4d = input_ndim == 4
    if was_4d:
        # Flatten the two trailing axes into one sequence axis.
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    # The mask is sized after the key/value sequence.
    kv_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    batch_size, sequence_length, _ = kv_shape
    attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

    if attn.group_norm is not None:
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    q = attn.to_q(hidden_states, scale=scale)

    if encoder_hidden_states is None:
        context = hidden_states
    elif attn.norm_cross:
        context = attn.norm_encoder_hidden_states(encoder_hidden_states)
    else:
        context = encoder_hidden_states

    k = attn.to_k(context, scale=scale)
    v = attn.to_v(context, scale=scale)

    q = attn.head_to_batch_dim(q)
    k = attn.head_to_batch_dim(k)
    v = attn.head_to_batch_dim(v)

    attention_probs = attn.get_attention_scores(q, k, attention_mask)

    # Capture point for the attention map; the original author's note gives
    # example shapes of (20, 4096, 77) or (40, 1024, 77).
    if hasattr(self, "store_attn_map"):
        self.attn_map = attention_probs

    out = torch.bmm(attention_probs, v)
    out = attn.batch_to_head_dim(out)

    out = attn.to_out[0](out, scale=scale)  # linear projection
    out = attn.to_out[1](out)  # dropout

    if was_4d:
        out = out.transpose(-1, -2).reshape(batch_size, channel, height, width)

    if attn.residual_connection:
        out = out + skip

    return out / attn.rescale_output_factor
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> "tuple[torch.Tensor, torch.Tensor]":
    """Explicit scaled dot-product attention that also returns the attention weights.

    Args:
        query: Tensor whose last two axes are (L, E).
        key: Tensor whose last two axes are (S, E).
        value: Tensor whose last two axes are (S, Ev).
        attn_mask: Optional mask broadcastable to the attention weights. A
            boolean mask marks the positions that may be attended to
            (``False`` positions get ``-inf``); any other dtype is added to the
            scores. The mask is never modified.
        dropout_p: Dropout probability applied to the weights before they are
            multiplied with ``value``.
        is_causal: If true, apply a lower-triangular mask; ``attn_mask`` must
            then be ``None``.
        scale: Score scaling factor; defaults to ``1 / sqrt(E)``.

    Returns:
        ``(output, attn_weight)`` where ``attn_weight`` is the softmax over the
        key axis (before dropout) and ``output`` is the weighted sum of ``value``.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

    # Build the additive bias next to the scores so no device mismatch can occur.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)

    if is_causal:
        assert attn_mask is None
        causal = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(causal.logical_not(), float("-inf"))

    if attn_mask is not None:
        mask = attn_mask.to(query.device)
        if mask.dtype == torch.bool:
            # Translate the boolean mask into an additive bias in a fresh
            # tensor; the caller's mask is left untouched.
            mask_bias = torch.zeros(mask.shape, dtype=query.dtype, device=query.device)
            mask_bias.masked_fill_(mask.logical_not(), float("-inf"))
        else:
            mask_bias = mask.to(query.dtype)
        # Out-of-place add so an (L, S) bias can broadcast against a
        # higher-rank mask such as (batch, heads, L, S).
        attn_bias = attn_bias + mask_bias

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    return torch.dropout(attn_weight, dropout_p, train=True) @ value, attn_weight
def attn_call2_0(
    self,
    attn: Attention,
    hidden_states,
    encoder_hidden_states=None,
    attention_mask=None,
    temb=None,
    scale: float = 1.0,
):
    """Multi-head attention forward pass with an optional attention-map capture.

    When the processor instance has a ``store_attn_map`` attribute, the
    module-level ``scaled_dot_product_attention`` is used and its weights are
    saved on ``self.attn_map``; otherwise ``F.scaled_dot_product_attention``
    is called and nothing is stored.

    Args:
        attn: Attention module providing the projections, norms and head count.
        hidden_states: Query-side input; a 4-D input is flattened to a sequence
            and restored to 4-D at the end.
        encoder_hidden_states: Key/value-side input; ``hidden_states`` is used
            when this is ``None``.
        attention_mask: Optional mask; prepared and reshaped per head when given.
        temb: Forwarded to ``attn.spatial_norm`` when that norm is configured.
        scale: Forwarded to the q/k/v and output projections.

    Returns:
        The attended hidden states, with the residual added if
        ``attn.residual_connection`` is set, divided by ``attn.rescale_output_factor``.
    """
    skip = hidden_states  # kept for the optional residual connection

    if attn.spatial_norm is not None:
        hidden_states = attn.spatial_norm(hidden_states, temb)

    input_ndim = hidden_states.ndim
    was_4d = input_ndim == 4
    if was_4d:
        batch_size, channel, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

    kv_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
    batch_size, sequence_length, _ = kv_shape

    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # The attention call below expects the mask as
        # (batch, heads, source_length, target_length).
        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

    if attn.group_norm is not None:
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

    query = attn.to_q(hidden_states, scale=scale)

    if encoder_hidden_states is None:
        context = hidden_states
    elif attn.norm_cross:
        context = attn.norm_encoder_hidden_states(encoder_hidden_states)
    else:
        context = encoder_hidden_states

    key = attn.to_k(context, scale=scale)
    value = attn.to_v(context, scale=scale)

    inner_dim = key.shape[-1]
    head_dim = inner_dim // attn.heads

    def split_heads(tensor):
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        return tensor.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

    query = split_heads(query)
    key = split_heads(key)
    value = split_heads(value)

    # The attention output is (batch, num_heads, seq_len, head_dim).
    # TODO: add support for attn.scale when we move to Torch 2.1
    if hasattr(self, "store_attn_map"):
        attended, attn_map = scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # The original author's note gives example shapes of
        # (2, 10, 4096, 77) or (2, 20, 1024, 77).
        self.attn_map = attn_map
    else:
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

    # Merge the heads back into the feature axis.
    attended = attended.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
    attended = attended.to(query.dtype)

    attended = attn.to_out[0](attended, scale=scale)  # linear projection
    attended = attn.to_out[1](attended)  # dropout

    if was_4d:
        attended = attended.transpose(-1, -2).reshape(batch_size, channel, height, width)

    if attn.residual_connection:
        attended = attended + skip

    return attended / attn.rescale_output_factor
def lora_attn_call(self, attn: Attention, hidden_states, *args, **kwargs):
    """Fold a LoRA processor into ``attn`` and delegate to a plain ``AttnProcessor``.

    Emits a deprecation notice, attaches this processor's LoRA layers to the
    q/k/v/out projections of ``attn``, replaces ``attn.processor`` with a new
    ``AttnProcessor`` (carrying over the ``store_attn_map`` flag when this
    processor has one), and returns the result of calling that processor.

    Args:
        attn: Attention module whose projections receive the LoRA layers.
        hidden_states: Input tensor; its device is where the LoRA layers are moved.
        *args: Forwarded unchanged to the replacement processor.
        **kwargs: Forwarded unchanged to the replacement processor.

    Returns:
        Whatever the replacement processor returns.
    """
    self_cls_name = self.__class__.__name__
    # NOTE(review): diffusers' `deprecate` is understood to raise once the
    # installed version reaches the given one; confirm against the pinned release.
    deprecate(
        self_cls_name,
        "0.26.0",
        (
            # The trailing space keeps "setting" and "LoRA" from running together.
            f"Make sure to use {self_cls_name[4:]} instead by setting "
            "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
            " `LoraLoaderMixin.load_lora_weights`"
        ),
    )

    device = hidden_states.device
    attn.to_q.lora_layer = self.to_q_lora.to(device)
    attn.to_k.lora_layer = self.to_k_lora.to(device)
    attn.to_v.lora_layer = self.to_v_lora.to(device)
    attn.to_out[0].lora_layer = self.to_out_lora.to(device)

    # Remove the registered "processor" sub-module entry before assigning the
    # replacement (presumably because the new processor is not an nn.Module; TODO confirm).
    attn._modules.pop("processor")
    attn.processor = AttnProcessor()

    # Keep attention-map capture enabled across the processor swap.
    if hasattr(self, "store_attn_map"):
        attn.processor.store_attn_map = True

    return attn.processor(attn, hidden_states, *args, **kwargs)
def lora_attn_call2_0(self, attn: Attention, hidden_states, *args, **kwargs):
    """Fold a LoRA processor into ``attn`` and delegate to an ``AttnProcessor2_0``.

    Emits a deprecation notice, attaches this processor's LoRA layers to the
    q/k/v/out projections of ``attn``, replaces ``attn.processor`` with a new
    ``AttnProcessor2_0`` (carrying over the ``store_attn_map`` flag when this
    processor has one), and returns the result of calling that processor.

    Args:
        attn: Attention module whose projections receive the LoRA layers.
        hidden_states: Input tensor; its device is where the LoRA layers are moved.
        *args: Forwarded unchanged to the replacement processor.
        **kwargs: Forwarded unchanged to the replacement processor.

    Returns:
        Whatever the replacement processor returns.
    """
    self_cls_name = self.__class__.__name__
    # NOTE(review): diffusers' `deprecate` is understood to raise once the
    # installed version reaches the given one; confirm against the pinned release.
    deprecate(
        self_cls_name,
        "0.26.0",
        (
            # The trailing space keeps "setting" and "LoRA" from running together.
            f"Make sure to use {self_cls_name[4:]} instead by setting "
            "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
            " `LoraLoaderMixin.load_lora_weights`"
        ),
    )

    device = hidden_states.device
    attn.to_q.lora_layer = self.to_q_lora.to(device)
    attn.to_k.lora_layer = self.to_k_lora.to(device)
    attn.to_v.lora_layer = self.to_v_lora.to(device)
    attn.to_out[0].lora_layer = self.to_out_lora.to(device)

    # Remove the registered "processor" sub-module entry before assigning the
    # replacement (presumably because the new processor is not an nn.Module; TODO confirm).
    attn._modules.pop("processor")
    attn.processor = AttnProcessor2_0()

    # Keep attention-map capture enabled across the processor swap.
    if hasattr(self, "store_attn_map"):
        attn.processor.store_attn_map = True

    return attn.processor(attn, hidden_states, *args, **kwargs)
def cross_attn_init():
    """Monkey-patch the diffusers attention processors with the capturing versions.

    Replaces ``__call__`` on the four processor classes at class level, so
    every existing and future instance is affected.
    """
    # Both 2.0 classes are deliberately routed to the non-2.0 implementations
    # (the original author notes that attn_call is faster); attn_call2_0 and
    # lora_attn_call2_0 remain available as drop-in alternatives.
    replacements = (
        (AttnProcessor, attn_call),
        (AttnProcessor2_0, attn_call),
        (LoRAAttnProcessor, lora_attn_call),
        (LoRAAttnProcessor2_0, lora_attn_call),
    )
    for processor_cls, call_impl in replacements:
        processor_cls.__call__ = call_impl
def reshape_attn_map(attn_map):
    """Average a 3-D attention map over its first axis and unflatten it per token.

    Args:
        attn_map: Tensor of shape (heads, positions, tokens),
            e.g. (20, 4096, 77).

    Returns:
        Tensor of shape (tokens, side, positions // side) with
        ``side = int(sqrt(positions))``, e.g. (77, 64, 64).
    """
    head_avg = torch.mean(attn_map, dim=0)  # (heads, positions, tokens) -> (positions, tokens)
    per_token = head_avg.permute(1, 0)  # -> (tokens, positions)
    num_tokens, num_positions = per_token.shape[0], per_token.shape[1]
    side = int(math.sqrt(num_positions))
    # If the input rows are softmax probabilities over tokens, summing the
    # result over axis 0 gives all ones (per the original author's note).
    return per_token.reshape((num_tokens, side, -1))
def hook_fn(name):
    """Build a forward hook that moves a processor's attention map into ``attn_maps``.

    Args:
        name: Key under which the captured map is stored in ``attn_maps``.

    Returns:
        A ``(module, input, output)`` forward-hook callable.
    """
    def forward_hook(module, inputs, outputs):
        processor = module.processor
        if not hasattr(processor, "attn_map"):
            return
        attn_maps[name] = processor.attn_map
        # Remove the attribute so the processor does not keep a second reference.
        del processor.attn_map

    return forward_hook
def register_cross_attention_hook(unet):
    """Enable attention-map capture on every ``attn2`` sub-module of ``unet``.

    For each sub-module whose last name component starts with ``attn2``, the
    processor is flagged with ``store_attn_map = True`` (when it is one of the
    four known processor classes) and a forward hook from ``hook_fn`` is
    registered on the module.

    Args:
        unet: Model exposing ``named_modules()``.

    Returns:
        The same ``unet`` object.
    """
    known_processors = (
        AttnProcessor,
        AttnProcessor2_0,
        LoRAAttnProcessor,
        LoRAAttnProcessor2_0,
    )
    for name, module in unet.named_modules():
        leaf_name = name.split('.')[-1]
        if leaf_name.startswith('attn2'):
            if isinstance(module.processor, known_processors):
                module.processor.store_attn_map = True
            # The returned hook handle is intentionally not kept.
            module.register_forward_hook(hook_fn(name))
    return unet
def prompt2tokens(tokenizer, prompt):
    """Tokenize ``prompt`` and return the token strings of the first sequence.

    The prompt is padded and truncated to ``tokenizer.model_max_length``, so
    the returned list has that many entries when the tokenizer honours
    those options.

    Args:
        tokenizer: Callable tokenizer exposing ``model_max_length`` and a
            ``decoder`` mapping from token id to token string.
        prompt: Text to tokenize.

    Returns:
        List of token strings looked up in ``tokenizer.decoder``.
    """
    encoded = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    first_sequence = encoded.input_ids[0]
    return [tokenizer.decoder[token_id.item()] for token_id in first_sequence]
def upscale(attn_map, target_size):
    """Average an attention map over heads and resample it to ``target_size``.

    Args:
        attn_map: Tensor of shape (heads, positions, tokens), where
            ``positions`` is a flattened 2-D grid.
        target_size: ``(height, width)`` of the grid to produce.

    Returns:
        Float32 tensor of shape (tokens, height * width), softmax-normalised
        over the token axis.

    Raises:
        ValueError: If the grid shape cannot be derived from ``positions``
            and ``target_size``.
    """
    attn_map = torch.mean(attn_map, dim=0)  # (heads, positions, tokens) -> (positions, tokens)
    attn_map = attn_map.permute(1, 0)  # -> (tokens, positions)
    num_tokens, num_positions = attn_map.shape
    target_h, target_w = target_size

    attn_map = attn_map.to(dtype=torch.float32)
    if target_h * target_w != num_positions:
        map_h, map_w = _infer_map_size(num_positions, target_size)
        attn_map = attn_map.reshape(1, num_tokens, map_h, map_w)
        attn_map = F.interpolate(
            attn_map,
            size=(target_h, target_w),
            mode='bilinear',
            align_corners=False,
        ).squeeze(0)  # drop only the batch axis so single-token maps keep their shape

    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map.reshape(num_tokens, -1)


def _infer_map_size(num_positions, target_size):
    """Return the (height, width) grid that ``num_positions`` cells form relative to ``target_size``."""
    target_h, target_w = target_size
    if num_positions > 0 and target_h > 0 and target_w > 0:
        # Coarser than the target: the first guess (half size) is what the
        # function historically assumed; keep halving until overshot.
        factor = 2
        while (target_h // factor) * (target_w // factor) >= num_positions:
            candidate = (target_h // factor, target_w // factor)
            if candidate[0] * candidate[1] == num_positions:
                return candidate
            factor *= 2
        # Finer than the target: try doubling instead.
        factor = 2
        while (target_h * factor) * (target_w * factor) <= num_positions:
            candidate = (target_h * factor, target_w * factor)
            if candidate[0] * candidate[1] == num_positions:
                return candidate
            factor *= 2
    raise ValueError(
        f"cannot derive a 2-D grid for {num_positions} positions from target size {tuple(target_size)}"
    )
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    """Combine every captured attention map into one per-token map.

    Each tensor in ``attn_maps`` is split into ``batch_size`` chunks along its
    first axis, one chunk is selected, resampled with ``upscale`` to
    ``image_size // 16`` and the results are averaged over all layers.

    Args:
        image_size: ``(height, width)``; the output grid is one sixteenth of
            it per side.
        batch_size: Number of chunks each stored map is split into (the
            original author's note mentions negative and positive CFG halves).
        instance_or_negative: If true, use chunk 0; otherwise chunk 1.
        detach: If true, each map is moved with ``.cpu()`` before processing
            (no ``.detach()`` call is made here).

    Returns:
        Tensor of shape (tokens, image_size[0] // 16, image_size[1] // 16).

    Raises:
        RuntimeError: If no attention map has been captured yet.
    """
    target_size = (image_size[0] // 16, image_size[1] // 16)
    idx = 0 if instance_or_negative else 1

    if not attn_maps:
        raise RuntimeError(
            "no attention maps captured; register the hooks and run a forward pass first"
        )

    per_layer = []
    for attn_map in attn_maps.values():
        attn_map = attn_map.cpu() if detach else attn_map
        attn_map = torch.chunk(attn_map, batch_size)[idx]
        if attn_map.ndim == 4:
            # Drop only the leading chunk axis; other size-1 axes must survive.
            attn_map = attn_map.squeeze(0)
        per_layer.append(upscale(attn_map, target_size))  # -> (tokens, h * w)

    net_attn_maps = torch.mean(torch.stack(per_layer, dim=0), dim=0)
    # Unflatten to the computed grid rather than a fixed 64x64.
    return net_attn_maps.reshape(net_attn_maps.shape[0], *target_size)
def save_net_attn_map(net_attn_maps, dir_name, tokenizer, prompt):
    """Write one PNG per prompt token showing that token's attention map.

    Files are named ``{index}_<{token}>:{score}.png`` inside ``dir_name``,
    where ``score`` is the map's mean value times 100, truncated to an int.
    The sum of all mean values is printed at the end.

    Args:
        net_attn_maps: Iterable of 2-D tensors, paired in order with the
            prompt's tokens.
        dir_name: Output directory; created if missing.
        tokenizer: Tokenizer passed to ``prompt2tokens``.
        prompt: Prompt whose tokens label the maps.
    """
    # exist_ok avoids the check-then-create race.
    os.makedirs(dir_name, exist_ok=True)

    tokens = prompt2tokens(tokenizer, prompt)
    total_attn_scores = 0
    for i, (token, attn_map) in enumerate(zip(tokens, net_attn_maps)):
        h, w = attn_map.shape
        attn_map_score = torch.sum(attn_map) / (h * w)
        total_attn_scores += attn_map_score

        token = token.replace('</w>', '')
        # A path separator inside the token would split the file name into directories.
        file_token = token.replace('/', '_').replace(os.sep, '_')
        save_attn_map(
            # detach() so maps that still track gradients can be converted.
            attn_map.detach().cpu().numpy(),
            f'{token}:{attn_map_score:.2f}',
            f"{dir_name}/{i}_<{file_token}>:{int(attn_map_score*100)}.png"
        )
    print(f'total_attn_scores: {total_attn_scores}')
def resize_net_attn_map(net_attn_maps, target_size):
    """Bilinearly resize a stack of 2-D maps to ``target_size``.

    Args:
        net_attn_maps: Tensor of shape (tokens, height, width).
        target_size: ``(height, width)`` of the output maps.

    Returns:
        Float32 tensor of shape (tokens, target_height, target_width).
    """
    batched = net_attn_maps.to(dtype=torch.float32).unsqueeze(0)  # add a batch axis for interpolate
    resized = F.interpolate(
        batched,
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )
    # Remove only the batch axis; a bare squeeze() would also drop a
    # single-token axis or a size-1 spatial axis.
    return resized.squeeze(0)
def save_attn_map(attn_map, title, save_path):
    """Min-max normalise a 2-D array to 0..255 and save it as a PNG.

    Args:
        attn_map: NumPy array holding the map values.
        title: Unused; kept so existing callers keep working.
        save_path: Destination file path.
    """
    lowest = np.min(attn_map)
    value_range = np.max(attn_map) - lowest
    if value_range > 0:
        normalized_attn_map = (attn_map - lowest) / value_range * 255
    else:
        # A constant map has no contrast; emit black instead of dividing by zero.
        normalized_attn_map = np.zeros_like(attn_map)
    normalized_attn_map = normalized_attn_map.astype(np.uint8)

    image = Image.fromarray(normalized_attn_map)
    # Pillow's PNG writer takes `compress_level` (0 = no compression).
    image.save(save_path, format='PNG', compress_level=0)
def return_net_attn_map(net_attn_maps, tokenizer, prompt):
    """Turn per-token attention maps into labelled greyscale images.

    Each map is min-max normalised to 0..255 and converted to a PIL image.
    The sum of all per-map mean values is printed at the end.

    Args:
        net_attn_maps: Iterable of 2-D tensors, paired in order with the
            prompt's tokens.
        tokenizer: Tokenizer passed to ``prompt2tokens``.
        prompt: Prompt whose tokens label the maps.

    Returns:
        List of ``(image, label)`` tuples, with labels of the form
        ``{index}_<{token}>``.
    """
    tokens = prompt2tokens(tokenizer, prompt)
    total_attn_scores = 0
    images = []
    for i, (token, attn_map) in enumerate(zip(tokens, net_attn_maps)):
        h, w = attn_map.shape
        total_attn_scores += torch.sum(attn_map) / (h * w)

        # detach() so maps that still track gradients can be converted.
        pixels = attn_map.detach().cpu().numpy()
        lowest = np.min(pixels)
        value_range = np.max(pixels) - lowest
        if value_range > 0:
            normalized_attn_map = (pixels - lowest) / value_range * 255
        else:
            # A constant map has no contrast; emit black instead of dividing by zero.
            normalized_attn_map = np.zeros_like(pixels)
        image = Image.fromarray(normalized_attn_map.astype(np.uint8))

        token = token.replace('</w>', '')
        images.append((image, f"{i}_<{token}>"))
    print(f'total_attn_scores: {total_attn_scores}')
    return images