Token Merging
Token Merging (introduced in Token Merging: Your ViT But Faster) works by progressively merging redundant tokens/patches in the forward pass of a Transformer-based network. This can speed up the inference latency of the underlying network.

After the release of Token Merging (ToMe), the authors published Token Merging for Fast Stable Diffusion, which introduced a version of ToMe that is more compatible with Stable Diffusion. With ToMe, you can smoothly speed up the inference latency of a [`DiffusionPipeline`]. This document covers how to apply ToMe to the [`StableDiffusionPipeline`], the expected speedups, and the qualitative aspects of using ToMe with the [`StableDiffusionPipeline`].
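To build some intuition for what "merging redundant tokens" means, here is a toy sketch. It is not ToMe's actual algorithm (ToMe uses bipartite soft matching); the function name and the adjacent-pair heuristic below are purely illustrative:

```python
import torch

def toy_merge(tokens: torch.Tensor, r: int) -> torch.Tensor:
    """Toy illustration: merge the r most similar adjacent token pairs by averaging."""
    # tokens: (N, C) sequence of token embeddings.
    sim = torch.nn.functional.cosine_similarity(tokens[:-1], tokens[1:], dim=-1)
    idx = sim.topk(r).indices  # positions whose right neighbour is most redundant
    keep = torch.ones(tokens.shape[0], dtype=torch.bool)
    for i in idx.tolist():
        tokens[i] = (tokens[i] + tokens[i + 1]) / 2  # average the pair into one token
        keep[i + 1] = False                          # drop the merged-away token
    return tokens[keep]

x = torch.randn(16, 64)            # 16 tokens with 64 channels
print(toy_merge(x, r=4).shape)     # roughly 12 tokens remain
```

ToMe applies this kind of reduction progressively inside the network blocks, so later layers operate on fewer tokens and the forward pass gets cheaper.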
Using ToMe
The authors of ToMe released a convenient Python library called `tomesd` that lets you apply ToMe to a [`DiffusionPipeline`] like so:
```python
import torch
import tomesd
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Patch the pipeline with ToMe; `ratio` controls how many tokens get merged.
tomesd.apply_patch(pipeline, ratio=0.5)

image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
```
And that's it!
`tomesd.apply_patch()` exposes a number of arguments to help strike a balance between the pipeline inference speed and the quality of the generated tokens. The most important of these is `ratio`, which controls the number of tokens that are merged during the forward pass. For more details on `tomesd`, refer to its repository (https://github.com/dbolya/tomesd) and the paper.
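As a rough sketch of how you might explore this speed/quality trade-off, the snippet below tries a few `ratio` values and uses `tomesd.remove_patch()` to restore the original pipeline between runs. The specific ratios and output file names are arbitrary choices for illustration:

```python
import torch
import tomesd
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

# Higher ratios merge more tokens: inference gets faster,
# but image quality may start to degrade.
for ratio in (0.3, 0.5, 0.75):
    tomesd.apply_patch(pipeline, ratio=ratio)
    image = pipeline(prompt).images[0]
    image.save(f"astronaut_ratio_{ratio}.png")
    tomesd.remove_patch(pipeline)  # undo the patch before trying the next ratio
```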
Benchmarking `tomesd` with `StableDiffusionPipeline`
We benchmarked the impact of using `tomesd` on the [`StableDiffusionPipeline`] along with xFormers across different image resolutions. We used A100 and V100 as our test GPU devices, with the following development environment (with Python 3.8.5):
- `diffusers` version: 0.15.1
- Python version: 3.8.16
- PyTorch version (GPU?): 1.13.1+cu116 (True)
- Huggingface_hub version: 0.13.2
- Transformers version: 4.27.2
- Accelerate version: 0.18.0
- xFormers version: 0.0.16
- tomesd version: 0.1.2
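For reference, the latencies below can be measured with a loop along the following lines; this is only a sketch (the warmup count, run count, and use of wall-clock timing are illustrative assumptions), and the exact code we used is in the script linked below:

```python
import time
import torch

def measure_latency(pipeline, prompt, batch_size=1, num_warmup=2, num_runs=3):
    # Warm up so one-time setup costs do not skew the measurement.
    for _ in range(num_warmup):
        pipeline(prompt, num_images_per_prompt=batch_size)

    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_runs):
        pipeline(prompt, num_images_per_prompt=batch_size)
    torch.cuda.synchronize()

    # Average wall-clock time per pipeline call, in seconds.
    return (time.time() - start) / num_runs
```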
We used the following script for benchmarking: https://gist.github.com/sayakpaul/27aec6bca7eb7b0e0aa4112205850335. The results are as follows:
A100
| Resolution | Batch size | Vanilla | ToMe | ToMe + xFormers | ToMe speedup (%) | ToMe + xFormers speedup (%) |
| --- | --- | --- | --- | --- | --- | --- |
| 512 | 10 | 6.88 | 5.26 | 4.69 | 23.54651163 | 31.83139535 |
| 768 | 10 | OOM | 14.71 | 11 | | |
| | 8 | OOM | 11.56 | 8.84 | | |
| | 4 | OOM | 5.98 | 4.66 | | |
| | 2 | 4.99 | 3.24 | 3.1 | 35.07014028 | 37.8757515 |
| | 1 | 3.29 | 2.24 | 2.03 | 31.91489362 | 38.29787234 |
| 1024 | 10 | OOM | OOM | OOM | | |
| | 8 | OOM | OOM | OOM | | |
| | 4 | OOM | 12.51 | 9.09 | | |
| | 2 | OOM | 6.52 | 4.96 | | |
| | 1 | 6.4 | 3.61 | 2.81 | 43.59375 | 56.09375 |
The results are reported in seconds. Speedups are calculated with respect to the `Vanilla` timings.
V100
| Resolution | Batch size | Vanilla | ToMe | ToMe + xFormers | ToMe speedup (%) | ToMe + xFormers speedup (%) |
| --- | --- | --- | --- | --- | --- | --- |
| 512 | 10 | OOM | 10.03 | 9.29 | | |
| | 8 | OOM | 8.05 | 7.47 | | |
| | 4 | 5.7 | 4.3 | 3.98 | 24.56140351 | 30.1754386 |
| | 2 | 3.14 | 2.43 | 2.27 | 22.61146497 | 27.70700637 |
| | 1 | 1.88 | 1.57 | 1.57 | 16.4893617 | 16.4893617 |
| 768 | 10 | OOM | OOM | 23.67 | | |
| | 8 | OOM | OOM | 18.81 | | |
| | 4 | OOM | 11.81 | 9.7 | | |
| | 2 | OOM | 6.27 | 5.2 | | |
| | 1 | 5.43 | 3.38 | 2.82 | 37.75322284 | 48.06629834 |
| 1024 | 10 | OOM | OOM | OOM | | |
| | 8 | OOM | OOM | OOM | | |
| | 4 | OOM | OOM | 19.35 | | |
| | 2 | OOM | 13 | 10.78 | | |
| | 1 | OOM | 6.66 | 5.54 | | |
As seen in the tables above, the speedup from `tomesd` becomes more pronounced at higher image resolutions. It is also interesting to note that with `tomesd`, it becomes possible to run the pipeline at higher resolutions such as 1024x1024. You may be able to speed up inference even further with `torch.compile()`.
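A minimal sketch of combining the two, assuming a PyTorch 2.x environment where `torch.compile` is available (whether the patched UNet compiles cleanly may depend on your library versions):

```python
import torch
import tomesd
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Apply ToMe first, then compile the patched UNet, which dominates the runtime.
tomesd.apply_patch(pipeline, ratio=0.5)
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead")

image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
```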
Quality
As reported in the paper, ToMe can preserve the quality of the generated images to a great extent while speeding up inference. By increasing the `ratio`, it is possible to further speed up inference, but that might come at the cost of a deterioration in image quality.

To test the quality of the generated samples using our setup, we sampled a few prompts from the "Parti Prompts" (introduced in Parti) and performed inference with the [`StableDiffusionPipeline`] in the following settings:
- Vanilla [`StableDiffusionPipeline`]
- [`StableDiffusionPipeline`] + ToMe
- [`StableDiffusionPipeline`] + ToMe + xformers
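The third setting can be reproduced with something along these lines; this is a sketch that assumes xFormers is installed so that `enable_xformers_memory_efficient_attention()` is usable:

```python
import torch
import tomesd
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# StableDiffusionPipeline + ToMe + xformers
tomesd.apply_patch(pipeline, ratio=0.5)
pipeline.enable_xformers_memory_efficient_attention()

image = pipeline("a photo of an astronaut riding a horse on mars").images[0]
```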
We didn't notice any significant decrease in the quality of the generated samples. Here are some samples.

You can check out the generated samples here. We used this script for conducting this experiment.