<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Token Merging

Token Merging (introduced in [Token Merging: Your ViT But Faster](https://arxiv.org/abs/2210.09461)) works by progressively merging redundant tokens or patches in the forward pass of a Transformer-based network. This can reduce the inference latency of the underlying network.

After Token Merging (ToMe) was released, the authors published [Token Merging for Fast Stable Diffusion](https://arxiv.org/abs/2303.17604), which introduced a version of ToMe that is more compatible with Stable Diffusion. ToMe can be used to gracefully reduce the inference latency of a [`DiffusionPipeline`]. This document covers how to apply ToMe to the [`StableDiffusionPipeline`], the expected speedups, and the qualitative aspects of using ToMe with the [`StableDiffusionPipeline`].

## Using ToMe

The authors of ToMe released a convenient Python library called [`tomesd`](https://github.com/dbolya/tomesd), which lets you apply ToMe to a [`DiffusionPipeline`] as follows:

```diff
import torch
from diffusers import StableDiffusionPipeline
import tomesd

pipeline = StableDiffusionPipeline.from_pretrained(
      "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
+ tomesd.apply_patch(pipeline, ratio=0.5)

image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
```

That's all there is to it!

`tomesd.apply_patch()` exposes [a number of arguments](https://github.com/dbolya/tomesd#usage) that let you balance pipeline inference speed against the quality of the generated tokens. The most important of these is `ratio`, which controls the number of tokens merged during the forward pass. For more details on `tomesd`, refer to the [original repository](https://github.com/dbolya/tomesd) and [the paper](https://arxiv.org/abs/2303.17604).

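
To build intuition for what `ratio` means in terms of token counts, here is a rough sketch. It is not taken from the `tomesd` source. It assumes the 8x latent downsampling used by Stable Diffusion v1, one token per latent position in the highest-resolution attention layers, and that `ratio` is the fraction of tokens merged away. `tokens_after_merge` is a hypothetical helper, not part of `tomesd`.

```python
def tokens_after_merge(height: int, width: int, ratio: float, vae_scale: int = 8) -> int:
    """Estimate how many tokens remain in an attention layer after merging.

    Assumptions (illustrative only): the latent is `vae_scale` times smaller
    than the image along each side, each latent position is one token, and
    `ratio` is the fraction of tokens that gets merged away.
    """
    if not 0.0 <= ratio < 1.0:
        raise ValueError("ratio must be in [0, 1)")
    total = (height // vae_scale) * (width // vae_scale)
    merged = int(total * ratio)
    return total - merged


# Token counts at ratio=0.5 for the resolutions benchmarked in this document.
for size in (512, 768, 1024):
    print(size, tokens_after_merge(size, size, ratio=0.5))
```

Under these assumptions, the number of tokens grows quadratically with the image side length, so the same `ratio` removes far more tokens at higher resolutions.
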
## Benchmarking `tomesd` with `StableDiffusionPipeline`

We benchmarked the impact of using `tomesd` on [`StableDiffusionPipeline`] along with [xformers](https://huggingface.co/docs/diffusers/optimization/xformers) across different image resolutions. We used A100 and V100 GPUs as test devices, with the following development environment (with Python 3.8.5):

```bash
- `diffusers` version: 0.15.1
- Python version: 3.8.16
- PyTorch version (GPU?): 1.13.1+cu116 (True)
- Huggingface_hub version: 0.13.2
- Transformers version: 4.27.2
- Accelerate version: 0.18.0
- xFormers version: 0.0.16
- tomesd version: 0.1.2
```

We used this script for benchmarking: [https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335](https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335). The results are as follows:

### A100

| Resolution | Batch size | Vanilla | ToMe | ToMe + xFormers | ToMe speedup (%) | ToMe + xFormers speedup (%) |
| --- | --- | --- | --- | --- | --- | --- |
| 512 | 10 | 6.88 | 5.26 | 4.69 | 23.54651163 | 31.83139535 |
| | | | | | | |
| 768 | 10 | OOM | 14.71 | 11 | | |
| | 8 | OOM | 11.56 | 8.84 | | |
| | 4 | OOM | 5.98 | 4.66 | | |
| | 2 | 4.99 | 3.24 | 3.1 | 35.07014028 | 37.8757515 |
| | 1 | 3.29 | 2.24 | 2.03 | 31.91489362 | 38.29787234 |
| | | | | | | |
| 1024 | 10 | OOM | OOM | OOM | | |
| | 8 | OOM | OOM | OOM | | |
| | 4 | OOM | 12.51 | 9.09 | | |
| | 2 | OOM | 6.52 | 4.96 | | |
| | 1 | 6.4 | 3.61 | 2.81 | 43.59375 | 56.09375 |

***Timings are in seconds. Speedups are calculated relative to `Vanilla`.***

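
The speedup columns appear to be the relative reduction in latency with respect to the vanilla timing; this interpretation reproduces the reported values. A minimal sketch, where the helper name `speedup_percent` is ours and not taken from the benchmark script:

```python
def speedup_percent(vanilla_s: float, optimized_s: float) -> float:
    """Percentage reduction in latency relative to the vanilla timing."""
    return (vanilla_s - optimized_s) / vanilla_s * 100.0


# A100, resolution 512, batch size 10 (timings in seconds from the A100 table).
vanilla, tome, tome_xformers = 6.88, 5.26, 4.69
print(f"ToMe: {speedup_percent(vanilla, tome):.2f}%")
print(f"ToMe + xFormers: {speedup_percent(vanilla, tome_xformers):.2f}%")
```

Rows where the vanilla pipeline ran out of memory (OOM) have no baseline timing, which is why their speedup cells are empty.
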
### V100

| Resolution | Batch size | Vanilla | ToMe | ToMe + xFormers | ToMe speedup (%) | ToMe + xFormers speedup (%) |
| --- | --- | --- | --- | --- | --- | --- |
| 512 | 10 | OOM | 10.03 | 9.29 | | |
| | 8 | OOM | 8.05 | 7.47 | | |
| | 4 | 5.7 | 4.3 | 3.98 | 24.56140351 | 30.1754386 |
| | 2 | 3.14 | 2.43 | 2.27 | 22.61146497 | 27.70700637 |
| | 1 | 1.88 | 1.57 | 1.57 | 16.4893617 | 16.4893617 |
| | | | | | | |
| 768 | 10 | OOM | OOM | 23.67 | | |
| | 8 | OOM | OOM | 18.81 | | |
| | 4 | OOM | 11.81 | 9.7 | | |
| | 2 | OOM | 6.27 | 5.2 | | |
| | 1 | 5.43 | 3.38 | 2.82 | 37.75322284 | 48.06629834 |
| | | | | | | |
| 1024 | 10 | OOM | OOM | OOM | | |
| | 8 | OOM | OOM | OOM | | |
| | 4 | OOM | OOM | 19.35 | | |
| | 2 | OOM | 13 | 10.78 | | |
| | 1 | OOM | 6.66 | 5.54 | | |

As the tables above show, the speedup from `tomesd` becomes more pronounced as the image resolution increases. It is also interesting that `tomesd` makes it possible to run the pipeline at higher resolutions, such as 1024x1024, in configurations where the vanilla pipeline runs out of memory.

It might be possible to speed up inference even further with [`torch.compile()`](https://huggingface.co/docs/diffusers/optimization/torch2.0).

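
To take similar measurements on other hardware, a timing helper along the following lines could be used. This is a generic sketch, not the benchmark script linked above; `time_call` and its parameters are hypothetical names chosen for illustration.

```python
import statistics
import time
from typing import Callable, Optional


def time_call(fn: Callable[[], object], warmup: int = 1, runs: int = 3,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Return the median wall-clock time of `fn()` in seconds.

    `sync` is an optional hook called right before each clock read. On a CUDA
    device one would typically pass `torch.cuda.synchronize`, so that queued
    GPU work is included in the measurement.
    """
    for _ in range(warmup):  # warm-up calls are executed but not timed
        fn()
    timings = []
    for _ in range(runs):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


# Stand-in workload; with a real pipeline this could be `lambda: pipeline(prompt)`.
elapsed = time_call(lambda: sum(range(100_000)))
print(f"{elapsed:.4f} s")
```

Timing the same call once before and once after `tomesd.apply_patch(...)` would give the two numbers needed for a speedup comparison.
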
## Quality

As reported in [the paper](https://arxiv.org/abs/2303.17604), ToMe can preserve the quality of the generated images to a great extent while speeding up inference. Increasing the `ratio` speeds up inference further, but it may come at the cost of lower image quality.

To test the quality of the generated samples using our setup, we sampled a few prompts from the "Parti Prompts" (introduced in [Parti](https://parti.research.google/)) and performed inference with the [`StableDiffusionPipeline`] in the following settings:

- Vanilla [`StableDiffusionPipeline`]
- [`StableDiffusionPipeline`] + ToMe
- [`StableDiffusionPipeline`] + ToMe + xformers

We did not notice any significant decrease in the quality of the generated samples. Here are some samples:

![tome-samples](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/tome/tome_samples.png)

You can check out the generated samples [here](https://wandb.ai/sayakpaul/tomesd-results/runs/23j4bj3i?workspace=). We used [this script](https://gist.github.com/sayakpaul/8cac98d7f22399085a060992f411ecbd) to conduct this experiment.