# Hugging Face Spaces page header (scraped): the Space status was "Runtime error".
import dataclasses | |
import warnings | |
warnings.filterwarnings("ignore") | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
import uuid | |
import torch.nn.functional as F | |
from PIL import Image | |
from pathlib import Path | |
from diffusers import AutoencoderKL, UNet2DConditionModel | |
from diffusers.models.attention_processor import AttnProcessor, Attention | |
from rich import traceback | |
from torchvision.transforms.functional import to_tensor | |
from transformers import CLIPTokenizer, CLIPTextModel | |
from tqdm import tqdm | |
import spaces | |
# Hugging Face Hub repository holding every Stable Diffusion component loaded below.
MODEL_ID: str = "CompVis/stable-diffusion-v1-4"
# Seed for the torch.Generator used when sampling the VAE latent distribution.
SEED: int = 1117
# Timestep value handed to the UNet for its single forward pass.
UNET_TIMESTEP: int = 1

# Use rich's pretty-printed tracebacks for uncaught exceptions.
traceback.install()
@dataclasses.dataclass
class AttentionStore:
    """Tensors captured from a single call of an attention layer.

    Entries are built with keyword arguments by
    ``NewAttnProcessor._save_attention_store``. The class therefore has to be
    a dataclass: a plain class that only declares annotations accepts no
    constructor arguments and raises ``TypeError`` on the first UNet forward.
    """

    index: int  # position of this entry in its store list at append time
    query: torch.Tensor  # query after attn.head_to_batch_dim
    key: torch.Tensor  # key after attn.head_to_batch_dim
    value: torch.Tensor  # value after attn.head_to_batch_dim
    attention_probs: torch.Tensor  # output of attn.get_attention_scores
class NewAttnProcessor(AttnProcessor):
    """Attention processor that records q/k/v and attention probabilities.

    The forward pass is computed step by step (projections,
    ``attn.get_attention_scores``, ``bmm`` with the values, output projection)
    so that the per-head query, key, value and probabilities can be passed to
    ``_save_attention_store``. That method appends them to
    ``cross_attention_stores`` or ``self_attention_stores``.

    Args:
        save_uncond_attention: record the first half of the batch axis
            (assumed to be the unconditional branch of a ``[uncond, cond]`` batch).
        save_cond_attention: record the second half of the batch axis.
        max_cross_attention_maps: a cross-attention call is skipped when its
            probabilities' second dimension exceeds this value squared.
        max_self_attention_maps: the same limit for self-attention calls.
    """

    def __init__(
        self,
        save_uncond_attention: bool = True,
        save_cond_attention: bool = True,
        max_cross_attention_maps: int = 64,
        max_self_attention_maps: int = 64,
    ):
        super().__init__()
        self.save_uncond_attn = save_uncond_attention
        self.save_cond_attn = save_cond_attention
        self.max_cross_size = max_cross_attention_maps
        self.max_self_size = max_self_attention_maps
        # Start with empty cross- and self-attention store lists.
        self.reset_attention_stores()

    def _save_attention_store(
        self,
        is_cross: bool,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_probs: torch.Tensor
    ) -> None:
        """Append the selected batch half(s) of one attention call to a store.

        Nothing is recorded when neither half is enabled, or when the
        probabilities' second dimension exceeds the configured limit squared.
        """
        keep_uncond, keep_cond = self.save_uncond_attn, self.save_cond_attn
        if not (keep_uncond or keep_cond):
            return

        captured = (attn_probs, q, k, v)
        if keep_uncond and keep_cond:
            chosen = captured
        else:
            # Each tensor is cut at half its leading (batch * heads) axis:
            # the first half is kept for "uncond", the second for "cond".
            side = 1 if keep_cond else 0
            chosen = tuple(
                (t[: t.shape[0] // 2], t[t.shape[0] // 2:])[side] for t in captured
            )
        probs, query, key, value = chosen

        limit = self.max_cross_size if is_cross else self.max_self_size
        if probs.shape[1] > limit ** 2:
            return

        bucket = self.cross_attention_stores if is_cross else self.self_attention_stores
        bucket.append(
            AttentionStore(
                index=len(bucket),
                query=query,
                key=key,
                value=value,
                attention_probs=probs,
            )
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        temb: torch.FloatTensor = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run attention for ``attn`` and record its intermediate tensors."""
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        # 4-D inputs (batch, channel, height, width) are flattened to tokens.
        was_image = hidden_states.ndim == 4
        if was_image:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        context_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = context_shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Cross-attention exactly when an external context is supplied.
        is_cross_attention = encoder_hidden_states is not None
        if not is_cross_attention:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        query, key, value = [attn.head_to_batch_dim(t) for t in (query, key, value)]

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        self._save_attention_store(is_cross=is_cross_attention, q=query, k=key, v=value, attn_probs=attention_probs)

        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))
        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        hidden_states = out_dropout(out_proj(hidden_states))

        if was_image:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + shortcut
        return hidden_states / attn.rescale_output_factor

    def reset_attention_stores(self) -> None:
        """Discard every recorded entry by installing fresh empty lists."""
        self.cross_attention_stores, self.self_attention_stores = [], []
# Prefer the GPU whenever torch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Load each Stable Diffusion component from its subfolder of MODEL_ID.
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder").to(device)
unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet").to(device)
vae: AutoencoderKL = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae").to(device)

# One recording processor is shared by every attention layer of the UNet;
# only the conditional half of the batch is kept.
_recording_processor = NewAttnProcessor(
    save_uncond_attention=False,
    save_cond_attention=True,
)
unet.set_attn_processor(_recording_processor)
def _tokenize_padded(text: str) -> torch.Tensor:
    """Tokenize ``text`` padded/truncated to the tokenizer max length, on ``device``."""
    return tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
    ).input_ids.to(device)


def _token_slice(n_prompt_tokens: int, include_special_tokens: bool) -> slice:
    """Return the token-axis slice of the prompt tokens to visualize.

    ``n_prompt_tokens`` counts the unpadded tokenized prompt, including the
    start and end special tokens. Without special tokens, only the tokens
    between them are kept.
    """
    if include_special_tokens:
        return slice(0, n_prompt_tokens)
    return slice(1, n_prompt_tokens - 1)


def _mean_cross_attention_map(attention_probs: torch.Tensor) -> torch.Tensor:
    """Average one layer's cross-attention over heads and upsample to 512x512.

    ``attention_probs`` is reshaped as (n_heads, side * side, n_tokens) for a
    square map. The result has shape (n_tokens, 512, 512).
    """
    n_heads, n_pixels, n_tokens = attention_probs.shape
    side = int(np.sqrt(n_pixels))  # 8, 16, 32 or 64 for SD v1 latents
    per_token = (
        attention_probs
        .permute(0, 2, 1)                           # (heads, tokens, side*side)
        .reshape(n_heads, n_tokens, side, side)     # (heads, tokens, side, side)
        .mean(dim=0)                                # (tokens, side, side)
    )
    return F.interpolate(
        per_token.unsqueeze(0),
        size=(512, 512),
        mode='bilinear',
        align_corners=True,
    ).squeeze(0)


def _min_max_normalize(attention_map: torch.Tensor) -> torch.Tensor:
    """Rescale a map to [0, 1]; a constant map becomes zeros instead of NaN."""
    lo, hi = attention_map.min(), attention_map.max()
    span = hi - lo
    if span <= 0:
        return torch.zeros_like(attention_map)
    return (attention_map - lo) / span


def _render_layer(maps: torch.Tensor, titles: list, out_dir: Path) -> Image.Image:
    """Plot one layer's per-token maps side by side, save a PNG, return it as PIL."""
    # squeeze=False keeps a 2-D axes array even for a single token, so
    # indexing works (plain subplots(1, 1) returns a bare Axes).
    fig, axes = plt.subplots(1, maps.shape[0], figsize=(16, 4), squeeze=False)
    for ax, title, token_map in zip(axes[0], titles, maps):
        ax.imshow(_min_max_normalize(token_map).detach().cpu().numpy(), alpha=0.9)
        ax.axis('off')
        ax.set_title(title)
    filepath = out_dir / f"{uuid.uuid4()}.png"  # unique name per figure
    fig.savefig(filepath, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    # Copy into memory so the file handle is released immediately.
    with Image.open(filepath) as saved:
        return saved.copy()


@spaces.GPU
def inference(
    image_path: str,
    prompt: str,
    has_include_special_tokens: bool = False,
    progress=gr.Progress(track_tqdm=False)):
    """Visualize the UNet's per-token cross-attention maps for an image and prompt.

    The image is encoded with the VAE. One UNet forward pass is then run at
    ``UNET_TIMESTEP`` on an [unconditional, conditional] batch. For every
    recorded cross-attention layer, one figure is rendered containing a
    min-max normalized 512x512 map per prompt token.

    Args:
        image_path: path of the input image; it is converted to RGB and resized to 512x512.
        prompt: text prompt whose tokens are visualized.
        has_include_special_tokens: also show the start/end special tokens.
        progress: Gradio progress tracker.

    Returns:
        A list with one PIL image per recorded cross-attention layer.

    Raises:
        gr.Error: if no image is given or the prompt leaves no token to show.
    """
    progress(0, "Initializing...")
    if image_path is None:
        raise gr.Error("Please upload an input image.")

    attn_processor = next(iter(unet.attn_processors.values()))
    # Clear leftovers from an earlier call that may have failed midway.
    attn_processor.reset_attention_stores()
    try:
        # Pure inference: no autograd graph, so the stored tensors stay small.
        with torch.no_grad():
            with Image.open(image_path) as pil_image:
                pil_image = pil_image.convert("RGB").resize((512, 512))
            # The Stable Diffusion VAE expects pixels in [-1, 1]; to_tensor gives [0, 1].
            image = to_tensor(pil_image).unsqueeze(0).to(device) * 2.0 - 1.0

            progress(0.1, "Generating text embeddings...")
            input_ids = _tokenize_padded(prompt)
            n_prompt_tokens = len(
                tokenizer(prompt, return_tensors="pt", truncation=True).input_ids[0]
            )
            token_slice = _token_slice(n_prompt_tokens, has_include_special_tokens)
            token_positions = range(n_prompt_tokens)[token_slice]
            if len(token_positions) == 0:
                raise gr.Error("The prompt has no tokens to visualize; please enter a prompt.")
            titles = [tokenizer.decode(input_ids[0, pos].item()) for pos in token_positions]

            cond_text_embeddings = text_encoder(input_ids).last_hidden_state[0].to(device)
            uncond_text_embeddings = text_encoder(_tokenize_padded("")).last_hidden_state[0].to(device)
            text_embeddings = torch.stack([uncond_text_embeddings, cond_text_embeddings], dim=0)

            progress(0.2, "Encoding the input image...")
            init_latent_dist = vae.encode(image).latent_dist

            progress(0.3, "Generating the latents...")
            # Fix the random seed for reproducibility.
            generator = torch.Generator(device=device).manual_seed(SEED)
            latent = init_latent_dist.sample(generator=generator) * vae.config['scaling_factor']
            latents = latent.expand(len(image), unet.config['in_channels'], 512 // 8, 512 // 8)
            latents_input = torch.cat([latents] * 2).to(device)

            progress(0.5, "Forwarding the UNet model...")
            unet(latents_input, UNET_TIMESTEP, encoder_hidden_states=text_embeddings)

            progress(0.7, "Processing the cross attention maps...")
            layer_maps = [
                _mean_cross_attention_map(store.attention_probs)[token_slice]
                for store in attn_processor.cross_attention_stores
            ]

            progress(0.9, "Post-processing the attention maps...")
            out_dir = Path("output")
            out_dir.mkdir(exist_ok=True)
            return [
                _render_layer(maps, titles, out_dir)
                for maps in tqdm(layer_maps, desc="Saving images...")
            ]
    finally:
        attn_processor.reset_attention_stores()
if __name__ == '__main__':
    # One label per recorded cross-attention layer, in recording order:
    # "<index>: <UNet stage> <attention-map side length>".
    unet_mapping = [
        "0: Down 64",
        "1: Down 64",
        "2: Down 32",
        "3: Down 32",
        "4: Down 16",
        "5: Down 16",
        "6: Mid 8",
        "7: Up 16",
        "8: Up 16",
        "9: Up 16",
        "10: Up 32",
        "11: Up 32",
        "12: Up 32",
        "13: Up 64",
        "14: Up 64",
        "15: Up 64",
    ]
    # One output image component per label (the list holds exactly 16).
    ca_output = [gr.Image(type="pil", label=label) for label in unet_mapping]

    input_components = [
        gr.Image(type="filepath", label="Input", width=512, height=512),
        gr.Textbox(label="Prompt", placeholder="e.g.) A photo of dog..."),
        gr.Checkbox(label="Include Special Tokens", value=False),
    ]
    example_rows = [
        ["assets/aeroplane.png", "plane background", False],
        ["assets/dogcat.png", "a photo of dog", False],
    ]

    iface = gr.Interface(
        title="Stable Diffusion Attention Visualizer",
        description="This is a visualizer for the attention maps of the Stable Diffusion model. ",
        fn=inference,
        inputs=input_components,
        outputs=ca_output,
        cache_examples=True,
        examples=example_rows,
    )
    iface.launch()