Spaces: Build error

aningineer committed
Commit • 5c4b5eb
1 Parent(s): 6f5b8d4

Upload folder using huggingface_hub
Browse files
- .gitattributes +2 -0
- README.md +41 -8
- __pycache__/merge.cpython-310.pyc +0 -0
- __pycache__/utils.cpython-310.pyc +0 -0
- app.py +124 -0
- compare3.png +3 -0
- merge.py +385 -0
- requirements.txt +3 -0
- test_notebook.ipynb +3 -0
- utils.py +80 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+compare3.png filter=lfs diff=lfs merge=lfs -text
+test_notebook.ipynb filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,13 +1,46 @@
 ---
 title: ToDo
-emoji: 🏃
-colorFrom: yellow
-colorTo: pink
-sdk: gradio
-sdk_version: 4.19.2
 app_file: app.py
-
-
+sdk: gradio
+sdk_version: 3.50.2
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# ImprovedTokenMerge
+![compare3.png](compare3.png)
+![GEuoFn1bMAABQqD](https://github.com/ethansmith2000/ImprovedTokenMerge/assets/98723285/82e03423-81e6-47da-afa4-9c1b2c1c4aeb)
+
+twitter thread explanation: https://twitter.com/Ethan_smith_20/status/1750533558509433137
+
+heavily inspired by https://github.com/dbolya/tomesd by @dbolya, a big thanks to the original authors.
+
+This project aims to address some of the shortcomings of Token Merging for Stable Diffusion: namely, consistently faster inference without quality loss.
+I found with the original that you would have to use a high merging ratio to get any real speedup at all, and by then quality was tarnished. Benchmarks here: https://github.com/dbolya/tomesd/issues/19#issuecomment-1507593483
+
+I propose two changes to the original to solve this.
+1. Merging Method
+- the original calculates a similarity matrix of the input tokens and merges those with the highest similarity
+- an issue here is that the similarity calculation is O(n²) time; for ViT, where token merging was proposed, you only had to do this a few times, so it was quite efficient
+- here it needs to be done at every step, and the computation ends up being nearly as costly as attention itself
+- we can leverage the simple observation that nearby tokens tend to be similar to each other
+- therefore we can merge tokens via downsampling, which is very cheap and seems to be a good approximation
+- this can be analogized to grid-based subsampling of an image with a nearest-neighbor downsample method; it is similar to what DiNAT (dilated neighborhood attention) does, except that we still make use of global context
+2. Merge Targets
+- the original merges the input tokens to attention, and then "unmerges" the resulting tokens back to the original size
+- this operation seems to be quite lossy
+- instead I propose simply downsampling the keys/values of the attention operation; both the QK calculation and QK * V can still be drastically reduced from the typical O(n²) scaling of attention, without needing to unmerge anything
+- queries are left fully intact; they just attend more sparsely to the image
+- attention for images, especially at larger resolutions, seems to be very sparse in general (the QK matrix is low rank), so it does not appear that we lose too much from this
+
+putting this all together, we can get tangible speedups of ~1.5x at typical sizes like 768-1024, and up to 3x and beyond in the 1536-2048 range, in combination with flash attention
+
+# Setup 🛠
+```
+pip install -r requirements.txt
+```
+
+# Inference 🚀
+See the provided notebook, or the gradio demo, which you can run with python app.py
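To make the second change above concrete, here is a minimal, self-contained sketch of full-resolution queries attending to downsampled keys/values. The reshape/interpolate pattern mirrors up_or_downsample in merge.py below; the tensor sizes and the plain softmax attention are illustrative stand-ins for the flash/xformers kernels the actual code dispatches to.

```
import torch
import torch.nn.functional as F

def downsample_tokens(x, h, w, factor, method="nearest"):
    # x: [B, N, C] tokens laid out as an h x w grid (N == h * w)
    b, n, c = x.shape
    x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)        # -> [B, C, h, w]
    x = F.interpolate(x, size=(h // factor, w // factor), mode=method)
    return x.permute(0, 2, 3, 1).reshape(b, -1, c)       # -> [B, N // factor**2, C]

b, h, w, c = 2, 64, 64, 320
x = torch.randn(b, h * w, c)
q = x                                  # queries stay full resolution: [2, 4096, 320]
k = v = downsample_tokens(x, h, w, 2)  # keys/values merged 4x: [2, 1024, 320]

attn = (q @ k.transpose(-1, -2) / c ** 0.5).softmax(dim=-1)  # [2, 4096, 1024]
out = attn @ v                         # [2, 4096, 320], no unmerge step needed
```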
__pycache__/merge.cpython-310.pyc
ADDED
Binary file (10.2 kB)

__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.66 kB)
app.py
ADDED
@@ -0,0 +1,124 @@
import time
import gradio as gr
import torch
import diffusers
from utils import patch_attention_proc
import math
import numpy as np
from PIL import Image

pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to("cuda", torch.float16)
pipe.enable_xformers_memory_efficient_attention()
pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.safety_checker = None

with gr.Blocks() as demo:
    prompt = gr.Textbox(interactive=True, label="prompt")
    negative_prompt = gr.Textbox(interactive=True, label="negative_prompt")
    method = gr.Dropdown(["todo", "tome"], value="todo", label="method", info="Choose Your Desired Method (Default: todo)")
    height_width = gr.Dropdown([1024, 1536, 2048], value=1024, label="height/width", info="Choose Your Desired Height/Width (Default: 1024)")
    # height = gr.Number(label="height", value=1024, precision=0)
    # width = gr.Number(label="width", value=1024, precision=0)
    guidance_scale = gr.Number(label="guidance_scale", value=7.5, precision=1)
    steps = gr.Number(label="steps", value=20, precision=0)
    seed = gr.Number(label="seed", value=1, precision=0)
    result = gr.Textbox(label="Result")

    output_image = gr.Image(label="output_image", type="pil", interactive=False)

    gen = gr.Button("generate")

    def which_image(img, target_val=253, width=1024):
        # locate the marker pixel (alpha == target_val) to report which half is the merged image
        npimg = np.array(img)
        loc = np.where(npimg[:, :, 3] == target_val)[1].item()
        if loc > width:
            print("Right Image is merged!")
        else:
            print("Left Image is merged!")

    def generate(prompt, seed, steps, height_width, negative_prompt, guidance_scale, method):
        # reset to the stock attention processors for the baseline run
        pipe.enable_xformers_memory_efficient_attention()

        downsample_factor = 2
        ratio = 0.38
        merge_method = "downsample" if method == "todo" else "similarity"
        merge_tokens = "keys/values" if method == "todo" else "all"

        if height_width == 1024:
            downsample_factor = 2
            ratio = 0.75
            downsample_factor_level_2 = 1
            ratio_level_2 = 0.0
        elif height_width == 1536:
            downsample_factor = 3
            ratio = 0.89
            downsample_factor_level_2 = 1
            ratio_level_2 = 0.0
        elif height_width == 2048:
            downsample_factor = 4
            ratio = 0.9375
            downsample_factor_level_2 = 2
            ratio_level_2 = 0.75

        token_merge_args = {"ratio": ratio,
                            "merge_tokens": merge_tokens,
                            "merge_method": merge_method,
                            "downsample_method": "nearest",
                            "downsample_factor": downsample_factor,
                            "timestep_threshold_switch": 0.0,
                            "timestep_threshold_stop": 0.0,
                            "downsample_factor_level_2": downsample_factor_level_2,
                            "ratio_level_2": ratio_level_2
                            }

        # random left/right placement so the viewer can't tell which image is merged
        l_r = torch.rand(1).item()
        torch.manual_seed(seed)
        start_time_base = time.time()
        base_img = pipe(prompt,
                        num_inference_steps=steps, height=height_width, width=height_width,
                        negative_prompt=negative_prompt,
                        guidance_scale=guidance_scale).images[0]
        end_time_base = time.time()

        patch_attention_proc(pipe.unet, token_merge_args=token_merge_args)

        torch.manual_seed(seed)
        start_time_merge = time.time()
        merged_img = pipe(prompt,
                          num_inference_steps=steps, height=height_width, width=height_width,
                          negative_prompt=negative_prompt,
                          guidance_scale=guidance_scale).images[0]
        end_time_merge = time.time()

        base_img = base_img.convert("RGBA")
        merged_img = merged_img.convert("RGBA")
        merged_img = np.array(merged_img)
        halfh, halfw = height_width // 2, height_width // 2
        merged_img[halfh, halfw, 3] = 253  # set the center pixel of the merged image to be ever so slightly below 255 in the alpha channel
        merged_img = Image.fromarray(merged_img)
        final_img = Image.new(size=(height_width * 2, height_width), mode="RGBA")

        if l_r > 0.5:
            left_img = base_img
            right_img = merged_img
        else:
            left_img = merged_img
            right_img = base_img

        final_img.paste(left_img, (0, 0))
        final_img.paste(right_img, (height_width, 0))

        which_image(final_img, width=height_width)

        result = f"Baseline image: {end_time_base-start_time_base:.2f} sec | {'ToDo' if method == 'todo' else 'ToMe'} image: {end_time_merge-start_time_merge:.2f} sec"

        return final_img, result

    gen.click(generate, inputs=[prompt, seed, steps, height_width, negative_prompt,
                                guidance_scale, method], outputs=[output_image, result])

demo.launch(share=True)
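For use outside the gradio UI, a minimal sketch of patching a pipeline directly, mirroring the calls app.py makes at the 1024px setting (patch_attention_proc and remove_patch come from utils.py below; the prompt is a placeholder):

```
import torch
import diffusers
from utils import patch_attention_proc, remove_patch

pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to("cuda", torch.float16)

# ToDo settings the demo uses at 1024x1024: downsample keys/values 2x at the first unet level
patch_attention_proc(pipe.unet, token_merge_args={"merge_method": "downsample",
                                                  "merge_tokens": "keys/values",
                                                  "downsample_method": "nearest",
                                                  "downsample_factor": 2,
                                                  "ratio": 0.75,
                                                  "timestep_threshold_switch": 0.0,
                                                  "timestep_threshold_stop": 0.0})

image = pipe("a photo of an astronaut riding a horse",
             num_inference_steps=20, height=1024, width=1024).images[0]

remove_patch(pipe)  # restore the stock attention processors
```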
compare3.png
ADDED
Git LFS Details
merge.py
ADDED
@@ -0,0 +1,385 @@
import torch
from typing import Tuple, Callable
from diffusers.models.attention_processor import XFormersAttnProcessor, Attention
from typing import Optional
import math
import torch.nn.functional as F
from diffusers.utils import USE_PEFT_BACKEND
from diffusers.utils.import_utils import is_xformers_available

if is_xformers_available():
    import xformers
    import xformers.ops
    xformers_is_available = True
else:
    xformers_is_available = False


if hasattr(F, "scaled_dot_product_attention"):
    torch2_is_available = True
else:
    torch2_is_available = False


def init_generator(device: torch.device, fallback: torch.Generator = None):
    """
    Forks the current default random generator given device.
    """
    if device.type == "cpu":
        return torch.Generator(device="cpu").set_state(torch.get_rng_state())
    elif device.type == "cuda":
        return torch.Generator(device=device).set_state(torch.cuda.get_rng_state())
    else:
        if fallback is None:
            return init_generator(torch.device("cpu"))
        else:
            return fallback


def do_nothing(x: torch.Tensor, mode: str = None):
    return x


def mps_gather_workaround(input, dim, index):
    if input.shape[-1] == 1:
        return torch.gather(
            input.unsqueeze(-1),
            dim - 1 if dim < 0 else dim,
            index.unsqueeze(-1)
        ).squeeze(-1)
    else:
        return torch.gather(input, dim, index)


def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method):
    batch_size = item.shape[0]

    item = item.reshape(batch_size, cur_h, cur_w, -1)
    item = item.permute(0, 3, 1, 2)
    df = cur_h // new_h
    if method == "max_pool":
        item = F.max_pool2d(item, kernel_size=df, stride=df, padding=0)
    elif method == "avg_pool":
        item = F.avg_pool2d(item, kernel_size=df, stride=df, padding=0)
    else:
        item = F.interpolate(item, size=(new_h, new_w), mode=method)
    item = item.permute(0, 2, 3, 1)
    item = item.reshape(batch_size, new_h * new_w, -1)

    return item


def compute_merge(x: torch.Tensor, tome_info):
    original_h, original_w = tome_info["size"]
    original_tokens = original_h * original_w
    downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
    dim = x.shape[-1]
    if dim == 320:
        cur_level = "level_1"
        downsample_factor = tome_info['args']['downsample_factor']
        ratio = tome_info['args']['ratio']
    elif dim == 640:
        cur_level = "level_2"
        downsample_factor = tome_info['args']['downsample_factor_level_2']
        ratio = tome_info['args']['ratio_level_2']
    else:
        cur_level = "other"
        downsample_factor = 1
        ratio = 0.0

    args = tome_info["args"]

    cur_h, cur_w = original_h // downsample, original_w // downsample
    new_h, new_w = cur_h // downsample_factor, cur_w // downsample_factor

    if tome_info['timestep'] / 1000 > tome_info['args']['timestep_threshold_switch']:
        merge_method = args["merge_method"]
    else:
        merge_method = args["secondary_merge_method"]

    if cur_level != "other" and tome_info['timestep'] / 1000 > tome_info['args']['timestep_threshold_stop']:
        if merge_method == "downsample" and downsample_factor > 1:
            m = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h, args["downsample_method"])
            u = lambda x: up_or_downsample(x, new_w, new_h, cur_w, cur_h, args["downsample_method"])
        elif merge_method == "similarity" and ratio > 0.0:
            w = int(math.ceil(original_w / downsample))
            h = int(math.ceil(original_h / downsample))
            r = int(x.shape[1] * ratio)

            # Re-init the generator if it hasn't already been initialized or device has changed.
            if args["generator"] is None:
                args["generator"] = init_generator(x.device)
            elif args["generator"].device != x.device:
                args["generator"] = init_generator(x.device, fallback=args["generator"])

            # If the batch size is odd, then it's not possible for prompted and unprompted images to be in the same
            # batch, which causes artifacts with use_rand, so force it to be off.
            use_rand = False if x.shape[0] % 2 == 1 else args["use_rand"]
            m, u = bipartite_soft_matching_random2d(x, w, h, args["sx"], args["sy"], r,
                                                    no_rand=not use_rand, generator=args["generator"])
        else:
            m, u = (do_nothing, do_nothing)
    else:
        m, u = (do_nothing, do_nothing)

    merge_fn, unmerge_fn = (m, u)

    return merge_fn, unmerge_fn


def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int,
                                     h: int,
                                     sx: int,
                                     sy: int,
                                     r: int,
                                     no_rand: bool = False,
                                     generator: torch.Generator = None) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
    - metric [B, N, C]: metric to use for similarity
    - w: image width in tokens
    - h: image height in tokens
    - sx: stride in the x dimension for dst, must divide w
    - sy: stride in the y dimension for dst, must divide h
    - r: number of tokens to remove (by merging)
    - no_rand: if true, disable randomness (use top left corner only)
    - generator: if no_rand is false and not None, the generator used for randomness
    """
    B, N, _ = metric.shape

    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, randomly assign one token to be dst and the rest src
        if no_rand:
            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        else:
            rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(
                metric.device)

        # The image might not divide evenly by sx and sy, so work on a view of the top left of the idx buffer instead
        idx_buffer_view = torch.zeros(hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # Image is not divisible by sx or sy, so we need to move it into a new buffer
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            C = x.shape[-1]
            src = torch.gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = torch.gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between A and B
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the # tokens in src
        r = min(a.shape[1], r)

        # Find the most similar greedily
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = torch.gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = split(x)
        n, t1, c = src.shape

        unm = torch.gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = torch.gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        src = torch.gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Combine back to the original shape
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2,
                     index=torch.gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c),
                     src=unm)
        out.scatter_(dim=-2,
                     index=torch.gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c),
                     src=src)

        return out

    return merge, unmerge


class TokenMergeAttentionProcessor:
    def __init__(self):
        # prioritize torch2's flash attention; if unavailable, fall back to xformers, then regular attention
        if torch2_is_available:
            self.attn_method = "torch2"
        elif xformers_is_available:
            self.attn_method = "xformers"
        else:
            self.attn_method = "regular"

    def torch2_attention(self, attn, query, key, value, attention_mask, batch_size):
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        return hidden_states

    def xformers_attention(self, attn, query, key, value, attention_mask, batch_size):
        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(batch_size * attn.heads, -1, attention_mask.shape[-1])

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, scale=attn.scale
        )

        hidden_states = attn.batch_to_head_dim(hidden_states)

        return hidden_states

    def regular_attention(self, attn, query, key, value, attention_mask, batch_size):
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(batch_size * attn.heads, -1, attention_mask.shape[-1])

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        return hidden_states

    def __call__(
            self,
            attn: Attention,
            hidden_states: torch.FloatTensor,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            temb: Optional[torch.FloatTensor] = None,
            scale: float = 1.0,
    ) -> torch.FloatTensor:
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        args = () if USE_PEFT_BACKEND else (scale,)

        if self._tome_info['args']['merge_tokens'] == "all":
            merge_fn, unmerge_fn = compute_merge(hidden_states, self._tome_info)
            hidden_states = merge_fn(hidden_states)

        query = attn.to_q(hidden_states, *args)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self._tome_info['args']['merge_tokens'] == "keys/values":
            merge_fn, _ = compute_merge(encoder_hidden_states, self._tome_info)
            encoder_hidden_states = merge_fn(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        if self.attn_method == "torch2":
            hidden_states = self.torch2_attention(attn, query, key, value, attention_mask, batch_size)
        elif self.attn_method == "xformers":
            hidden_states = self.xformers_attention(attn, query, key, value, attention_mask, batch_size)
        else:
            hidden_states = self.regular_attention(attn, query, key, value, attention_mask, batch_size)

        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if self._tome_info['args']['merge_tokens'] == "all":
            hidden_states = unmerge_fn(hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
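As a quick, illustrative sanity check on the claim that nearest-neighbor downsampling of the token grid amounts to grid-based subsampling (the approximation up_or_downsample relies on when downsample_method is "nearest"):

```
import torch
import torch.nn.functional as F

h = w = 8
tokens = torch.randn(1, h * w, 320)                       # [B, N, C] token sequence
grid = tokens.reshape(1, h, w, 320).permute(0, 3, 1, 2)   # [B, C, h, w]

down = F.interpolate(grid, size=(h // 2, w // 2), mode="nearest")
strided = grid[:, :, ::2, ::2]                            # keep every other token

print(torch.equal(down, strided))  # True: merged keys/values are a strided subset of the originals
```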
requirements.txt
ADDED
@@ -0,0 +1,3 @@
diffusers
transformers
accelerate
test_notebook.ipynb
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db39c8f9f9eea913cade16c3cf45ce0d9a13cc050b5e2564896b100042cdc86b
size 17164306
utils.py
ADDED
@@ -0,0 +1,80 @@
import torch
from merge import TokenMergeAttentionProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor, AttnProcessor
import torch.nn.functional as F

if is_xformers_available():
    xformers_is_available = True
else:
    xformers_is_available = False

if hasattr(F, "scaled_dot_product_attention"):
    torch2_is_available = True
else:
    torch2_is_available = False


def hook_tome_model(model: torch.nn.Module):
    """ Adds a forward pre hook to get the image size. This hook can be removed with remove_patch. """

    def hook(module, args):
        module._tome_info["size"] = (args[0].shape[2], args[0].shape[3])
        module._tome_info["timestep"] = args[1].item()
        return None

    model._tome_info["hooks"].append(model.register_forward_pre_hook(hook))


def patch_attention_proc(unet, token_merge_args={}):
    unet._tome_info = {
        "size": None,
        "timestep": None,
        "hooks": [],
        "args": {
            "ratio": token_merge_args.get("ratio", 0.5),  # ratio of tokens to merge
            "sx": token_merge_args.get("sx", 2),  # stride x for sim calculation
            "sy": token_merge_args.get("sy", 2),  # stride y for sim calculation
            "use_rand": token_merge_args.get("use_rand", True),
            "generator": None,

            "merge_tokens": token_merge_args.get("merge_tokens", "keys/values"),  # ["all", "keys/values"]
            "merge_method": token_merge_args.get("merge_method", "downsample"),  # ["none", "similarity", "downsample"]
            "downsample_method": token_merge_args.get("downsample_method", "nearest-exact"),
            # native torch interpolation methods ["nearest", "linear", "bilinear", "bicubic", "nearest-exact"]
            "downsample_factor": token_merge_args.get("downsample_factor", 2),  # amount to downsample by
            "timestep_threshold_switch": token_merge_args.get("timestep_threshold_switch", 0.2),
            # timestep to switch to secondary method, 0.2 means 20% steps remaining
            "timestep_threshold_stop": token_merge_args.get("timestep_threshold_stop", 0.0),
            # timestep to stop merging, 0.0 means stop at 0 steps remaining
            "secondary_merge_method": token_merge_args.get("secondary_merge_method", "similarity"),
            # ["none", "similarity", "downsample"]

            "downsample_factor_level_2": token_merge_args.get("downsample_factor_level_2", 1),  # amount to downsample by at the 2nd down block of unet
            "ratio_level_2": token_merge_args.get("ratio_level_2", 0.5),  # ratio of tokens to merge at the 2nd down block of unet
        }
    }
    hook_tome_model(unet)
    attn_modules = [module for name, module in unet.named_modules() if module.__class__.__name__ == 'BasicTransformerBlock']

    for i, module in enumerate(attn_modules):
        module.attn1.processor = TokenMergeAttentionProcessor()
        module.attn1.processor._tome_info = unet._tome_info


def remove_patch(pipe: torch.nn.Module):
    """ Removes a patch from a ToMe Diffusion module if it was already patched. """

    # this will remove our custom class
    if torch2_is_available:
        for n, m in pipe.unet.named_modules():
            if hasattr(m, "processor"):
                m.processor = AttnProcessor2_0()

    elif xformers_is_available:
        pipe.enable_xformers_memory_efficient_attention()

    else:
        for n, m in pipe.unet.named_modules():
            if hasattr(m, "processor"):
                m.processor = AttnProcessor()