import dataclasses
import warnings

warnings.filterwarnings("ignore")

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import uuid
import torch.nn.functional as F
from PIL import Image
from pathlib import Path
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor, Attention
from rich import traceback
from torchvision.transforms.functional import to_tensor
from transformers import CLIPTokenizer, CLIPTextModel
from tqdm import tqdm

MODEL_ID = "CompVis/stable-diffusion-v1-4"
SEED = 1117
UNET_TIMESTEP = 1

traceback.install()
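# Container for the tensors captured from one attention layer during a single
# UNet forward pass: the query/key/value projections and the softmax attention map.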
@dataclasses.dataclass
class AttentionStore:
    index: int
    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    attention_probs: torch.Tensor
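# Drop-in replacement for the default AttnProcessor that performs the usual
# attention computation while also recording the maps it produces.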
class NewAttnProcessor(AttnProcessor):
    def __init__(
        self,
        save_uncond_attention: bool = True,
        save_cond_attention: bool = True,
        max_cross_attention_maps: int = 64,
        max_self_attention_maps: int = 64,
    ):
        super().__init__()
        self.save_uncond_attn = save_uncond_attention
        self.save_cond_attn = save_cond_attention
        self.max_cross_size = max_cross_attention_maps
        self.max_self_size = max_self_attention_maps
        self.cross_attention_stores = []
        self.self_attention_stores = []
    def _save_attention_store(
        self,
        is_cross: bool,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_probs: torch.Tensor,
    ) -> None:
        # Split a tensor into its unconditional and conditional halves along the batch dimension
        def split_tensors(tensor):
            half_size = tensor.shape[0] // 2
            return tensor[:half_size], tensor[half_size:]

        # Split attention probabilities and q, k, v tensors
        uncond_attn_probs, cond_attn_probs = split_tensors(attn_probs)
        uncond_q, cond_q = split_tensors(q)
        uncond_k, cond_k = split_tensors(k)
        uncond_v, cond_v = split_tensors(v)

        # Select tensors based on flags
        if self.save_cond_attn and self.save_uncond_attn:
            selected_probs, selected_q, selected_k, selected_v = attn_probs, q, k, v
        elif self.save_cond_attn:
            selected_probs, selected_q, selected_k, selected_v = cond_attn_probs, cond_q, cond_k, cond_v
        elif self.save_uncond_attn:
            selected_probs, selected_q, selected_k, selected_v = uncond_attn_probs, uncond_q, uncond_k, uncond_v
        else:
            return

        # Determine max size based on attention type (cross or self)
        max_size = self.max_cross_size if is_cross else self.max_self_size

        # Skip attention maps larger than the configured maximum resolution
        if selected_probs.shape[1] > max_size ** 2:
            return

        # Create and append attention store object
        store = AttentionStore(
            index=len(self.cross_attention_stores) if is_cross else len(self.self_attention_stores),
            query=selected_q,
            key=selected_k,
            value=selected_v,
            attention_probs=selected_probs,
        )
        target_store = self.cross_attention_stores if is_cross else self.self_attention_stores
        target_store.append(store)
        return
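    # Mirrors the default AttnProcessor forward pass, with one extra step: the
    # computed query/key/value and attention probabilities are handed to
    # _save_attention_store before the output projection.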
    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        temb: torch.FloatTensor = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        is_cross_attention = encoder_hidden_states is not None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        # Save attention maps
        self._save_attention_store(is_cross=is_cross_attention, q=query, k=key, v=value, attn_probs=attention_probs)

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
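    # Clear the saved maps so the stores do not accumulate across Gradio requests.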
    def reset_attention_stores(self) -> None:
        self.cross_attention_stores = []
        self.self_attention_stores = []
        return


device = "cuda" if torch.cuda.is_available() else "cpu"
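# Load the Stable Diffusion v1.4 components individually (tokenizer, text encoder,
# UNet, VAE) rather than through a pipeline, then attach the custom attention
# processor to every attention layer of the UNet.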
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder").to(device)
unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet").to(device)
vae: AutoencoderKL = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae").to(device)

unet.set_attn_processor(
    NewAttnProcessor(
        save_uncond_attention=False,
        save_cond_attention=True,
    )
)
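# Single-step visualization: encode the input image into latents, run one UNet
# forward pass at a fixed timestep, and turn the saved cross-attention maps into
# per-token heatmap grids (one output image per cross-attention layer).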
@torch.no_grad()  # inference only; gradients are never needed here
def inference(
    image_path: str,
    prompt: str,
    has_include_special_tokens: bool = False,
    progress=gr.Progress(track_tqdm=False),
):
    progress(0, "Initializing...")

    image = Image.open(image_path)
    image = image.convert("RGB").resize((512, 512))
    image = to_tensor(image).unsqueeze(0).to(device)
    progress(0.1, "Generating text embeddings...")

    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
    ).input_ids.to(device)

    # Number of prompt tokens including <bos> and <eos> (no padding)
    n_cond_tokens = len(
        tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
        ).input_ids[0]
    )

    cond_text_embeddings = text_encoder(input_ids).last_hidden_state[0].to(device)

    uncond_input_ids = tokenizer(
        "",
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
    ).input_ids.to(device)
    uncond_text_embeddings = text_encoder(uncond_input_ids).last_hidden_state[0].to(device)

    text_embeddings = torch.stack([uncond_text_embeddings, cond_text_embeddings], dim=0)
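    # text_embeddings now holds [unconditional, conditional] rows; the attention
    # processor splits its inputs along this batch dimension when saving maps,
    # and only the conditional half is kept (save_cond_attention=True above).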
    progress(0.2, "Encoding the input image...")
    init_image = image.to(device)
    init_latent_dist = vae.encode(init_image).latent_dist

    # Fix the random seed for reproducibility
    progress(0.3, "Generating the latents...")
    generator = torch.Generator(device=device).manual_seed(SEED)
    latent = init_latent_dist.sample(generator=generator)
    latent = latent * vae.config.scaling_factor  # scaling_factor = 0.18215
    latents = latent.expand(len(image), unet.config.in_channels, 512 // 8, 512 // 8)
    latents_input = torch.cat([latents] * 2).to(device)

    progress(0.5, "Forwarding the UNet model...")
    _ = unet(latents_input, UNET_TIMESTEP, encoder_hidden_states=text_embeddings)

    attn_processor = next(iter(unet.attn_processors.values()))
    cross_attention_stores = attn_processor.cross_attention_stores
    progress(0.7, "Processing the cross attention maps...")
    cross_attention_probs_list = []

    # Go through the outputs of every cross-attention layer saved during the forward pass
    for i, cross_attn_store in enumerate(cross_attention_stores):
        cross_attn_probs = cross_attn_store.attention_probs  # (8, 8x8~64x64, 77)
        n_heads, scale_pow, n_tokens = cross_attn_probs.shape
        # scale: 8, 16, 32, 64
        scale = int(np.sqrt(scale_pow))

        # Average over the attention heads to obtain a single attention map
        mean_cross_attn_probs = (
            cross_attn_probs
            .permute(0, 2, 1)                          # (8, 77, 8x8~64x64)
            .reshape(n_heads, n_tokens, scale, scale)  # (8, 77, 8~64, 8~64)
            .mean(dim=0)                               # (77, 8~64, 8~64)
        )

        # Upsample every map to a common 512x512 resolution
        mean_cross_attn_probs = F.interpolate(
            mean_cross_attn_probs.unsqueeze(0),
            size=(512, 512),
            mode='bilinear',
            align_corners=True
        ).squeeze(0)  # (77, 512, 512)

        # Keep only the prompt tokens, optionally including <bos> and <eos>
        if has_include_special_tokens:
            mean_cross_attn_probs = mean_cross_attn_probs[:n_cond_tokens, ...]  # (n_tokens, 512, 512)
        else:
            mean_cross_attn_probs = mean_cross_attn_probs[1:n_cond_tokens - 1, ...]  # (n_tokens-2, 512, 512)

        cross_attention_probs_list.append(mean_cross_attn_probs)

    # list -> torch.Tensor
    cross_attention_probs = torch.stack(cross_attention_probs_list)  # (16, n_classes, 512, 512)
    n_layers, n_cond_tokens, _, _ = cross_attention_probs.shape
    progress(0.9, "Post-processing the attention maps...")
    image_list = []

    # Build and save one figure per cross-attention layer (one row of per-token maps)
    for i in tqdm(range(cross_attention_probs.shape[0]), desc="Saving images..."):
        fig, ax = plt.subplots(1, n_cond_tokens, figsize=(16, 4))
        for j in range(cross_attention_probs.shape[1]):
            # Min-max normalize each token's attention map to the range [0, 1]
            min_val = cross_attention_probs[i, j].min()
            max_val = cross_attention_probs[i, j].max()
            cross_attention_probs[i, j] = (cross_attention_probs[i, j] - min_val) / (max_val - min_val)
            attn_probs = cross_attention_probs[i, j].cpu().detach().numpy()

            # When special tokens are excluded, map j back to its original token position
            token_idx = j if has_include_special_tokens else j + 1
            ax[j].imshow(attn_probs, alpha=0.9)
            ax[j].axis('off')
            ax[j].set_title(tokenizer.decode(input_ids[0, token_idx].item()))

        # Save the figure for this layer
        out_dir = Path("output")
        out_dir.mkdir(exist_ok=True)
        # Generate a unique random filename
        unique_filename = str(uuid.uuid4())
        filepath = out_dir / f"{unique_filename}.png"
        plt.savefig(filepath, bbox_inches='tight', pad_inches=0)
        plt.close(fig)

        # Reload the saved image with PIL and append it to the output list
        image_list.append(Image.open(filepath))

    attn_processor.reset_attention_stores()
    return image_list
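# The 16 cross-attention layers of the SD v1.4 UNet, in forward order, labelled
# with their block (Down / Mid / Up) and the latent resolution they operate on.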
if __name__ == '__main__':
    unet_mapping = [
        "0: Down 64",
        "1: Down 64",
        "2: Down 32",
        "3: Down 32",
        "4: Down 16",
        "5: Down 16",
        "6: Mid 8",
        "7: Up 16",
        "8: Up 16",
        "9: Up 16",
        "10: Up 32",
        "11: Up 32",
        "12: Up 32",
        "13: Up 64",
        "14: Up 64",
        "15: Up 64",
    ]

    ca_output = [gr.Image(type="pil", label=unet_mapping[i]) for i in range(16)]

    iface = gr.Interface(
        title="Stable Diffusion Attention Visualizer",
        description="A visualizer for the cross-attention maps of the Stable Diffusion model.",
        fn=inference,
        inputs=[
            gr.Image(type="filepath", label="Input", width=512, height=512),
            gr.Textbox(label="Prompt", placeholder="e.g., a photo of a dog"),
            gr.Checkbox(label="Include Special Tokens", value=False),
        ],
        outputs=ca_output,
        cache_examples=True,
        examples=[
            ["assets/aeroplane.png", "plane background", False],
            ["assets/dogcat.png", "a photo of dog", False],
        ],
    )
    iface.launch()