# Prgckwb's picture
# Update app.py
# 6e60570 verified
import dataclasses
import warnings
warnings.filterwarnings("ignore")
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import uuid
import torch.nn.functional as F
from PIL import Image
from pathlib import Path
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor, Attention
from rich import traceback
from torchvision.transforms.functional import to_tensor
from transformers import CLIPTokenizer, CLIPTextModel
from tqdm import tqdm
import spaces
# Enable rich's traceback rendering for uncaught exceptions.
traceback.install()

# Hugging Face Hub repository every model component below is loaded from.
MODEL_ID: str = "CompVis/stable-diffusion-v1-4"
# Seed for the torch.Generator that samples the VAE latent in `inference`.
SEED: int = 1117
# Timestep value passed to the single UNet forward pass.
UNET_TIMESTEP: int = 1
@dataclasses.dataclass
class AttentionStore:
    """Record of the tensors captured from a single attention call.

    Field order is significant: it defines the positional order of the
    generated ``__init__``.
    """

    # Position of this record inside the list it was appended to.
    index: int
    # Projected query / key / value tensors handed to the recorder.
    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    # Attention probabilities produced for this call.
    attention_probs: torch.Tensor
class NewAttnProcessor(AttnProcessor):
    """Attention processor that computes plain (bmm-based) attention and records it.

    Every call appends an ``AttentionStore`` to ``cross_attention_stores`` (when
    ``encoder_hidden_states`` is given) or ``self_attention_stores`` (otherwise),
    subject to the save flags and the size limits passed to ``__init__``.

    The unconditional / conditional selection splits the leading (batch * heads)
    dimension in half, first half = unconditional, second half = conditional.
    NOTE(review): this assumes an even classifier-free-guidance batch laid out as
    [uncond, cond]; with a batch of 1 the split would cut across heads instead.
    """

    def __init__(
        self,
        save_uncond_attention: bool = True,
        save_cond_attention: bool = True,
        max_cross_attention_maps: int = 64,
        max_self_attention_maps: int = 64,
    ):
        """
        Args:
            save_uncond_attention: record the first half of the batch.
            save_cond_attention: record the second half of the batch.
            max_cross_attention_maps: cross-attention calls whose query length
                exceeds ``max_cross_attention_maps ** 2`` are not recorded.
            max_self_attention_maps: same limit for self-attention calls.
        """
        super().__init__()
        self.save_uncond_attn = save_uncond_attention
        self.save_cond_attn = save_cond_attention
        self.max_cross_size = max_cross_attention_maps
        self.max_self_size = max_self_attention_maps
        # Lists of AttentionStore records, in call order.
        self.cross_attention_stores = []
        self.self_attention_stores = []

    def _save_attention_store(
        self,
        is_cross: bool,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_probs: torch.Tensor
    ) -> None:
        """Append a record for one attention call, honouring the flags and size limit.

        Args:
            is_cross: True for cross-attention, False for self-attention.
            q, k, v: projected tensors after ``head_to_batch_dim``.
            attn_probs: output of ``attn.get_attention_scores``.
        """
        keep_uncond = self.save_uncond_attn
        keep_cond = self.save_cond_attn
        if not (keep_uncond or keep_cond):
            return

        # Filter out large attention maps before doing any slicing or copying.
        # Dim 1 is the query length, which the batch split below does not change.
        max_size = self.max_cross_size if is_cross else self.max_self_size
        if attn_probs.shape[1] > max_size ** 2:
            return

        if keep_uncond and keep_cond:
            selected_probs, selected_q, selected_k, selected_v = attn_probs, q, k, v
        else:
            def take_half(tensor: torch.Tensor) -> torch.Tensor:
                half_size = tensor.shape[0] // 2
                part = tensor[half_size:] if keep_cond else tensor[:half_size]
                # A slice shares storage with the full tensor; clone it so the
                # discarded half is not kept alive by the stored record.
                return part.clone()

            selected_probs, selected_q, selected_k, selected_v = (
                take_half(t) for t in (attn_probs, q, k, v)
            )

        target_store = self.cross_attention_stores if is_cross else self.self_attention_stores
        target_store.append(
            AttentionStore(
                index=len(target_store),
                query=selected_q,
                key=selected_k,
                value=selected_v,
                attention_probs=selected_probs,
            )
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        temb: torch.FloatTensor = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Compute attention for ``attn`` and record q/k/v/probabilities on the way.

        Extra positional/keyword arguments are accepted and ignored.
        """
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # (B, C, H, W) -> (B, H*W, C)
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        # Decide cross vs. self attention before encoder_hidden_states is overwritten.
        is_cross_attention = encoder_hidden_states is not None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        # Save attention maps
        self._save_attention_store(is_cross=is_cross_attention, q=query, k=key, v=value, attn_probs=attention_probs)

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states

    def reset_attention_stores(self) -> None:
        """Replace both store lists with fresh empty lists.

        Lists previously obtained from the attributes are left untouched.
        """
        self.cross_attention_stores = []
        self.self_attention_stores = []
# Run on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load each Stable Diffusion component from its sub-folder of MODEL_ID.
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder").to(device)
unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet").to(device)
vae: AutoencoderKL = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae").to(device)

# Route every UNet attention layer through one shared recording processor
# (`inference` later reads it back via the first entry of unet.attn_processors).
# Only the conditional half of the batch is kept. `inference` consumes only
# cross-attention maps, so max_self_attention_maps=0 makes the size filter skip
# every self-attention call (any query length exceeds 0 ** 2). Otherwise their
# (h*w x h*w per head) probability tensors would be retained for nothing.
unet.set_attn_processor(
    NewAttnProcessor(
        save_uncond_attention=False,
        save_cond_attention=True,
        max_self_attention_maps=0,
    )
)
def _mean_token_maps(attention_probs, start, stop, size=512):
    """Average one layer's cross-attention over heads; upsample tokens [start, stop) to size x size.

    Args:
        attention_probs: stored probabilities, reshaped below as
            (n_heads, h*w, n_tokens) over a square h == w latent grid.
        start, stop: slice of the text-token axis to keep.
        size: output side length in pixels.

    Returns:
        Tensor of shape (stop - start, size, size).
    """
    n_heads, n_pixels, n_tokens = attention_probs.shape
    side = int(np.sqrt(n_pixels))  # 8, 16, 32 or 64 for a 512x512 input
    head_mean = (
        attention_probs
        .permute(0, 2, 1)                        # (heads, tokens, h*w)
        .reshape(n_heads, n_tokens, side, side)  # (heads, tokens, h, w)
        .mean(dim=0)                             # (tokens, h, w)
    )
    # Bilinear interpolation treats each channel independently, so slicing the
    # tokens first gives the same result while upsampling fewer maps.
    return F.interpolate(
        head_mean[start:stop].unsqueeze(0),
        size=(size, size),
        mode='bilinear',
        align_corners=True,
    ).squeeze(0)


def _render_attention_row(maps, titles):
    """Plot per-token maps side by side; return the figure as an in-memory PIL image.

    Args:
        maps: tensor of shape (n_tokens, H, W).
        titles: one title string per map.
    """
    import io

    n_maps = maps.shape[0]
    # squeeze=False keeps `axes` 2-D even for a single map; plt.subplots(1, 1)
    # would otherwise return a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(1, n_maps, figsize=(16, 4), squeeze=False)
    try:
        for j in range(n_maps):
            token_map = maps[j]
            lo, hi = token_map.min(), token_map.max()
            span = hi - lo
            # Min-max normalise to [0, 1]; a constant map would divide by zero
            # (NaN), so draw it as zeros instead.
            if span > 0:
                token_map = (token_map - lo) / span
            else:
                token_map = torch.zeros_like(token_map)
            ax = axes[0, j]
            ax.imshow(token_map.detach().cpu().numpy(), alpha=0.9)
            ax.axis('off')
            ax.set_title(titles[j])
        # Render in memory instead of writing a new file to disk on every call.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    buffer.seek(0)
    rendered = Image.open(buffer)
    rendered.load()
    return rendered


@spaces.GPU()
@torch.inference_mode()
def inference(
    image_path: str,
    prompt: str,
    has_include_special_tokens: bool = False,
    progress=gr.Progress(track_tqdm=False)):
    """Visualise per-token cross-attention maps of every recorded UNet layer.

    The input image is encoded by the VAE. A single UNet forward pass is then run
    at UNET_TIMESTEP with an [unconditional, conditional] text batch. The
    cross-attention maps recorded by the installed processor are averaged over
    heads, upsampled to 512x512 and plotted as one row image per layer.

    Args:
        image_path: path of the input image (converted to RGB, resized to 512x512).
        prompt: text prompt whose tokens are visualised.
        has_include_special_tokens: also show the <bos>/<eos> token maps.
        progress: Gradio progress tracker.

    Returns:
        A list with one PIL image per recorded cross-attention layer.

    Raises:
        gr.Error: if the prompt leaves no token to display.
    """
    progress(0, "Initializing...")
    with Image.open(image_path) as source:
        image = source.convert("RGB").resize((512, 512))
    # to_tensor yields values in [0, 1]; the Stable Diffusion VAE expects [-1, 1]
    # (the same normalisation diffusers' VaeImageProcessor applies).
    image = to_tensor(image).unsqueeze(0).to(device) * 2.0 - 1.0

    progress(0.1, "Generating text embeddings...")
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
    ).input_ids.to(device)
    # Number of real (unpadded) tokens, including <bos> and <eos>.
    n_cond_tokens = len(
        tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
        ).input_ids[0]
    )

    # Token positions to show: all real tokens, or only those between <bos> and <eos>.
    if has_include_special_tokens:
        token_start, token_stop = 0, n_cond_tokens
    else:
        token_start, token_stop = 1, n_cond_tokens - 1
    if token_stop <= token_start:
        raise gr.Error("The prompt has no tokens to visualize. Please enter a non-empty prompt.")

    cond_text_embeddings = text_encoder(input_ids).last_hidden_state[0]
    uncond_input_ids = tokenizer(
        "",
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
    ).input_ids.to(device)
    uncond_text_embeddings = text_encoder(uncond_input_ids).last_hidden_state[0]
    text_embeddings = torch.stack([uncond_text_embeddings, cond_text_embeddings], dim=0)

    # The same processor instance is installed on every attention layer.
    attn_processor = next(iter(unet.attn_processors.values()))
    # Start from empty stores in case an earlier call failed before clearing them.
    attn_processor.reset_attention_stores()
    try:
        progress(0.2, "Encoding the input image...")
        init_latent_dist = vae.encode(image).latent_dist

        progress(0.3, "Generating the latents...")
        # Fix the random seed for reproducibility
        generator = torch.Generator(device=device).manual_seed(SEED)
        latent = init_latent_dist.sample(generator=generator)
        latent = latent * vae.config['scaling_factor']
        # Duplicate the latent so it pairs with the [uncond, cond] text batch.
        latents_input = torch.cat([latent, latent])

        progress(0.5, "Forwarding the UNet model...")
        unet(latents_input, UNET_TIMESTEP, encoder_hidden_states=text_embeddings)

        progress(0.7, "Processing the cross attention maps...")
        layer_maps = [
            _mean_token_maps(store.attention_probs, token_start, token_stop)
            for store in attn_processor.cross_attention_stores
        ]
    finally:
        # Always clear, so a failure cannot leak maps into the next request.
        attn_processor.reset_attention_stores()

    progress(0.9, "Post-processing the attention maps...")
    titles = [tokenizer.decode(input_ids[0, idx].item()) for idx in range(token_start, token_stop)]
    image_list = [
        _render_attention_row(maps, titles)
        for maps in tqdm(layer_maps, desc="Saving images...")
    ]
    return image_list
if __name__ == '__main__':
    # One label per output image. NOTE(review): the labels assume the recorded
    # cross-attention maps arrive in Down -> Mid -> Up order; confirm.
    unet_mapping = [
        "0: Down 64",
        "1: Down 64",
        "2: Down 32",
        "3: Down 32",
        "4: Down 16",
        "5: Down 16",
        "6: Mid 8",
        "7: Up 16",
        "8: Up 16",
        "9: Up 16",
        "10: Up 32",
        "11: Up 32",
        "12: Up 32",
        "13: Up 64",
        "14: Up 64",
        "15: Up 64",
    ]
    ca_output = [gr.Image(type="pil", label=layer_label) for layer_label in unet_mapping]

    # Input widgets, matching inference(image_path, prompt, has_include_special_tokens).
    input_components = [
        gr.Image(type="filepath", label="Input", width=512, height=512),
        gr.Textbox(label="Prompt", placeholder="e.g.) A photo of dog..."),
        gr.Checkbox(label="Include Special Tokens", value=False),
    ]
    example_rows = [
        ["assets/aeroplane.png", "plane background", False],
        ["assets/dogcat.png", "a photo of dog", False],
    ]

    iface = gr.Interface(
        fn=inference,
        inputs=input_components,
        outputs=ca_output,
        examples=example_rows,
        cache_examples=True,
        title="Stable Diffusion Attention Visualizer",
        description="This is a visualizer for the attention maps of the Stable Diffusion model. ",
    )
    iface.launch()