|
<!--Copyright 2024 The HuggingFace Team. All rights reserved. |
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with |
|
the License. You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on |
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the |
|
specific language governing permissions and limitations under the License. |
|
--> |
|
|
|
# ์ปค๋ฎค๋ํฐ ํ์ดํ๋ผ์ธ์ ๊ธฐ์ฌํ๋ ๋ฐฉ๋ฒ |
|
|
|
<Tip> |
|
|
|
๐ก ๋ชจ๋ ์ฌ๋์ด ์๋ ์ ํ ์์ด ์ฝ๊ฒ ์์
์ ๊ณต์ ํ ์ ์๋๋ก ์ปค๋ฎค๋ํฐ ํ์ดํ๋ผ์ธ์ ์ถ๊ฐํ๋ ์ด์ ์ ๋ํ ์์ธํ ๋ด์ฉ์ GitHub ์ด์ [#841](https://github.com/huggingface/diffusers/issues/841)๋ฅผ ์ฐธ์กฐํ์ธ์. |
|
|
|
</Tip> |
|
|
|
์ปค๋ฎค๋ํฐ ํ์ดํ๋ผ์ธ์ ์ฌ์ฉํ๋ฉด [`DiffusionPipeline`] ์์ ์ํ๋ ์ถ๊ฐ ๊ธฐ๋ฅ์ ์ถ๊ฐํ ์ ์์ต๋๋ค. `DiffusionPipeline` ์์ ๊ตฌ์ถํ ๋์ ๊ฐ์ฅ ํฐ ์ฅ์ ์ ๋๊ตฌ๋ ์ธ์๋ฅผ ํ๋๋ง ์ถ๊ฐํ๋ฉด ํ์ดํ๋ผ์ธ์ ๋ก๋ํ๊ณ ์ฌ์ฉํ ์ ์์ด ์ปค๋ฎค๋ํฐ๊ฐ ๋งค์ฐ ์ฝ๊ฒ ์ ๊ทผํ ์ ์๋ค๋ ๊ฒ์
๋๋ค. |
|
|
|
์ด๋ฒ ๊ฐ์ด๋์์๋ ์ปค๋ฎค๋ํฐ ํ์ดํ๋ผ์ธ์ ์์ฑํ๋ ๋ฐฉ๋ฒ๊ณผ ์๋ ์๋ฆฌ๋ฅผ ์ค๋ช
ํฉ๋๋ค. |
|
๊ฐ๋จํ๊ฒ ์ค๋ช
ํ๊ธฐ ์ํด `UNet`์ด ๋จ์ผ forward pass๋ฅผ ์ํํ๊ณ ์ค์ผ์ค๋ฌ๋ฅผ ํ ๋ฒ ํธ์ถํ๋ "one-step" ํ์ดํ๋ผ์ธ์ ๋ง๋ค๊ฒ ์ต๋๋ค. |
|
|
|
## ํ์ดํ๋ผ์ธ ์ด๊ธฐํ |
|
|
|
์ปค๋ฎค๋ํฐ ํ์ดํ๋ผ์ธ์ ์ํ `one_step_unet.py` ํ์ผ์ ์์ฑํ๋ ๊ฒ์ผ๋ก ์์ํฉ๋๋ค. ์ด ํ์ผ์์, Hub์์ ๋ชจ๋ธ ๊ฐ์ค์น์ ์ค์ผ์ค๋ฌ ๊ตฌ์ฑ์ ๋ก๋ํ ์ ์๋๋ก [`DiffusionPipeline`]์ ์์ํ๋ ํ์ดํ๋ผ์ธ ํด๋์ค๋ฅผ ์์ฑํฉ๋๋ค. one-step ํ์ดํ๋ผ์ธ์๋ `UNet`๊ณผ ์ค์ผ์ค๋ฌ๊ฐ ํ์ํ๋ฏ๋ก ์ด๋ฅผ `__init__` ํจ์์ ์ธ์๋ก ์ถ๊ฐํด์ผํฉ๋๋ค: |
|
|
|
```python |
|
from diffusers import DiffusionPipeline |
|
import torch |
|
|
|
|
|
class UnetSchedulerOneForwardPipeline(DiffusionPipeline): |
|
def __init__(self, unet, scheduler): |
|
super().__init__() |
|
``` |
|
|
|
ํ์ดํ๋ผ์ธ๊ณผ ๊ทธ ๊ตฌ์ฑ์์(`unet` and `scheduler`)๋ฅผ [`~DiffusionPipeline.save_pretrained`]์ผ๋ก ์ ์ฅํ ์ ์๋๋ก ํ๋ ค๋ฉด `register_modules` ํจ์์ ์ถ๊ฐํ์ธ์: |
|
|
|
```diff |
|
from diffusers import DiffusionPipeline |
|
import torch |
|
|
|
class UnetSchedulerOneForwardPipeline(DiffusionPipeline): |
|
def __init__(self, unet, scheduler): |
|
super().__init__() |
|
|
|
+ self.register_modules(unet=unet, scheduler=scheduler) |
|
``` |
|
|
|
์ด์ '์ด๊ธฐํ' ๋จ๊ณ๊ฐ ์๋ฃ๋์์ผ๋ forward pass๋ก ์ด๋ํ ์ ์์ต๋๋ค! ๐ฅ |
|
|
|
## Forward pass ์ ์ |
|
|
|
Forward pass ์์๋(`__call__`๋ก ์ ์ํ๋ ๊ฒ์ด ์ข์ต๋๋ค) ์ํ๋ ๊ธฐ๋ฅ์ ์ถ๊ฐํ ์ ์๋ ์์ ํ ์ฐฝ์ ์์ ๊ฐ ์์ต๋๋ค. ์ฐ๋ฆฌ์ ๋๋ผ์ด one-step ํ์ดํ๋ผ์ธ์ ๊ฒฝ์ฐ, ์์์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๊ณ `timestep=1`์ ์ค์ ํ์ฌ `unet`๊ณผ `scheduler`๋ฅผ ํ ๋ฒ๋ง ํธ์ถํฉ๋๋ค: |
|
|
|
```diff |
|
from diffusers import DiffusionPipeline |
|
import torch |
|
|
|
|
|
class UnetSchedulerOneForwardPipeline(DiffusionPipeline): |
|
def __init__(self, unet, scheduler): |
|
super().__init__() |
|
|
|
self.register_modules(unet=unet, scheduler=scheduler) |
|
|
|
+ def __call__(self): |
|
+ image = torch.randn( |
|
+ (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size), |
|
+ ) |
|
+ timestep = 1 |
|
|
|
+ model_output = self.unet(image, timestep).sample |
|
+ scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample |
|
|
|
+ return scheduler_output |
|
``` |
|
|
|
๋๋ฌ์ต๋๋ค! ๐ ์ด์ ์ด ํ์ดํ๋ผ์ธ์ `unet`๊ณผ `scheduler`๋ฅผ ์ ๋ฌํ์ฌ ์คํํ ์ ์์ต๋๋ค: |
|
|
|
```python |
|
from diffusers import DDPMScheduler, UNet2DModel |
|
|
|
scheduler = DDPMScheduler() |
|
unet = UNet2DModel() |
|
|
|
pipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler) |
|
|
|
output = pipeline() |
|
``` |
|
|
|
ํ์ง๋ง ํ์ดํ๋ผ์ธ ๊ตฌ์กฐ๊ฐ ๋์ผํ ๊ฒฝ์ฐ ๊ธฐ์กด ๊ฐ์ค์น๋ฅผ ํ์ดํ๋ผ์ธ์ ๋ก๋ํ ์ ์๋ค๋ ์ฅ์ ์ด ์์ต๋๋ค. ์๋ฅผ ๋ค์ด one-step ํ์ดํ๋ผ์ธ์ [`google/ddpm-cifar10-32`](https://huggingface.co/google/ddpm-cifar10-32) ๊ฐ์ค์น๋ฅผ ๋ก๋ํ ์ ์์ต๋๋ค: |
|
|
|
```python |
|
pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-32") |
|
|
|
output = pipeline() |
|
``` |
|
|
|
## ํ์ดํ๋ผ์ธ ๊ณต์ |
|
|
|
๐งจDiffusers [๋ฆฌํฌ์งํ ๋ฆฌ](https://github.com/huggingface/diffusers)์์ Pull Request๋ฅผ ์ด์ด [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) ํ์ ํด๋์ `one_step_unet.py`์ ๋ฉ์ง ํ์ดํ๋ผ์ธ์ ์ถ๊ฐํ์ธ์. |
|
|
|
๋ณํฉ์ด ๋๋ฉด, `diffusers >= 0.4.0`์ด ์ค์น๋ ์ฌ์ฉ์๋ผ๋ฉด ๋๊ตฌ๋ `custom_pipeline` ์ธ์์ ์ง์ ํ์ฌ ์ด ํ์ดํ๋ผ์ธ์ ๋ง์ ์ฒ๋ผ ๐ช ์ฌ์ฉํ ์ ์์ต๋๋ค: |
|
|
|
```python |
|
from diffusers import DiffusionPipeline |
|
|
|
pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="one_step_unet") |
|
pipe() |
|
``` |
|
|
|
์ปค๋ฎค๋ํฐ ํ์ดํ๋ผ์ธ์ ๊ณต์ ํ๋ ๋ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ Hub ์์ ์ ํธํ๋ [๋ชจ๋ธ ๋ฆฌํฌ์งํ ๋ฆฌ](https://huggingface.co/docs/hub/models-uploading)์ ์ง์ `one_step_unet.py` ํ์ผ์ ์
๋ก๋ํ๋ ๊ฒ์
๋๋ค. `one_step_unet.py` ํ์ผ์ ์ง์ ํ๋ ๋์ ๋ชจ๋ธ ์ ์ฅ์ id๋ฅผ `custom_pipeline` ์ธ์์ ์ ๋ฌํ์ธ์: |
|
|
|
```python |
|
from diffusers import DiffusionPipeline |
|
|
|
pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="stevhliu/one_step_unet") |
|
``` |
|
|
|
๋ค์ ํ์์ ๋ ๊ฐ์ง ๊ณต์ ์ํฌํ๋ก์ฐ๋ฅผ ๋น๊ตํ์ฌ ์์ ์๊ฒ ๊ฐ์ฅ ์ ํฉํ ์ต์
์ ๊ฒฐ์ ํ๋ ๋ฐ ๋์์ด ๋๋ ์ ๋ณด๋ฅผ ํ์ธํ์ธ์: |
|
|
|
| | GitHub ์ปค๋ฎค๋ํฐ ํ์ดํ๋ผ์ธ | HF Hub ์ปค๋ฎค๋ํฐ ํ์ดํ๋ผ์ธ | |
|
|----------------|------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------| |
|
| ์ฌ์ฉ๋ฒ | ๋์ผ | ๋์ผ | |
|
| ๋ฆฌ๋ทฐ ๊ณผ์ | ๋ณํฉํ๊ธฐ ์ ์ GitHub์์ Pull Request๋ฅผ ์ด๊ณ Diffusers ํ์ ๊ฒํ ๊ณผ์ ์ ๊ฑฐ์นฉ๋๋ค. ์๋๊ฐ ๋๋ฆด ์ ์์ต๋๋ค. | ๊ฒํ ์์ด Hub ์ ์ฅ์์ ๋ฐ๋ก ์
๋ก๋ํฉ๋๋ค. ๊ฐ์ฅ ๋น ๋ฅธ ์ํฌํ๋ก์ฐ ์
๋๋ค. | |
|
| ๊ฐ์์ฑ | ๊ณต์ Diffusers ์ ์ฅ์ ๋ฐ ๋ฌธ์์ ํฌํจ๋์ด ์์ต๋๋ค. | HF ํ๋ธ ํ๋กํ์ ํฌํจ๋๋ฉฐ ๊ฐ์์ฑ์ ํ๋ณดํ๊ธฐ ์ํด ์์ ์ ์ฌ์ฉ๋/ํ๋ก๋ชจ์
์ ์์กดํฉ๋๋ค. | |
|
|
|
<Tip> |
|
|
|
๐ก ์ปค๋ฎค๋ํฐ ํ์ดํ๋ผ์ธ ํ์ผ์ ์ํ๋ ํจํค์ง๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ฌ์ฉ์๊ฐ ํจํค์ง๋ฅผ ์ค์นํ๊ธฐ๋ง ํ๋ฉด ๋ชจ๋ ๊ฒ์ด ์ ์์ ์ผ๋ก ์๋ํฉ๋๋ค. ํ์ดํ๋ผ์ธ์ด ์๋์ผ๋ก ๊ฐ์ง๋๋ฏ๋ก `DiffusionPipeline`์์ ์์ํ๋ ํ์ดํ๋ผ์ธ ํด๋์ค๊ฐ ํ๋๋ง ์๋์ง ํ์ธํ์ธ์. |
|
|
|
</Tip> |
|
|
|
## ์ปค๋ฎค๋ํฐ ํ์ดํ๋ผ์ธ์ ์ด๋ป๊ฒ ์๋ํ๋์? |
|
|
|
์ปค๋ฎค๋ํฐ ํ์ดํ๋ผ์ธ์ [`DiffusionPipeline`]์ ์์ํ๋ ํด๋์ค์
๋๋ค: |
|
|
|
- [`custom_pipeline`] ์ธ์๋ก ๋ก๋ํ ์ ์์ต๋๋ค. |
|
- ๋ชจ๋ธ ๊ฐ์ค์น ๋ฐ ์ค์ผ์ค๋ฌ ๊ตฌ์ฑ์ [`pretrained_model_name_or_path`]์์ ๋ก๋๋ฉ๋๋ค. |
|
- ์ปค๋ฎค๋ํฐ ํ์ดํ๋ผ์ธ์์ ๊ธฐ๋ฅ์ ๊ตฌํํ๋ ์ฝ๋๋ `pipeline.py` ํ์ผ์ ์ ์๋์ด ์์ต๋๋ค. |
|
|
|
๊ณต์ ์ ์ฅ์์์ ๋ชจ๋ ํ์ดํ๋ผ์ธ ๊ตฌ์ฑ ์์ ๊ฐ์ค์น๋ฅผ ๋ก๋ํ ์ ์๋ ๊ฒฝ์ฐ๊ฐ ์์ต๋๋ค. ์ด ๊ฒฝ์ฐ ๋ค๋ฅธ ๊ตฌ์ฑ ์์๋ ํ์ดํ๋ผ์ธ์ ์ง์ ์ ๋ฌํด์ผ ํฉ๋๋ค: |
|
|
|
```python |
|
from diffusers import DiffusionPipeline |
|
from transformers import CLIPFeatureExtractor, CLIPModel |
|
|
|
model_id = "CompVis/stable-diffusion-v1-4" |
|
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" |
|
|
|
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id) |
|
clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16) |
|
|
|
pipeline = DiffusionPipeline.from_pretrained( |
|
model_id, |
|
custom_pipeline="clip_guided_stable_diffusion", |
|
clip_model=clip_model, |
|
feature_extractor=feature_extractor, |
|
scheduler=scheduler, |
|
torch_dtype=torch.float16, |
|
) |
|
``` |
|
|
|
์ปค๋ฎค๋ํฐ ํ์ดํ๋ผ์ธ์ ๋ง๋ฒ์ ๋ค์ ์ฝ๋์ ๋ด๊ฒจ ์์ต๋๋ค. ์ด ์ฝ๋๋ฅผ ํตํด ์ปค๋ฎค๋ํฐ ํ์ดํ๋ผ์ธ์ GitHub ๋๋ Hub์์ ๋ก๋ํ ์ ์์ผ๋ฉฐ, ๋ชจ๋ ๐งจ Diffusers ํจํค์ง์์ ์ฌ์ฉํ ์ ์์ต๋๋ค. |
|
|
|
```python |
|
# 2. ํ์ดํ๋ผ์ธ ํด๋์ค๋ฅผ ๋ก๋ํฉ๋๋ค. ์ฌ์ฉ์ ์ง์ ๋ชจ๋์ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ Hub์์ ๋ก๋ํฉ๋๋ค |
|
# ๋ช
์์ ํด๋์ค์์ ๋ก๋ํ๋ ๊ฒฝ์ฐ, ์ด๋ฅผ ์ฌ์ฉํด ๋ณด๊ฒ ์ต๋๋ค. |
|
if custom_pipeline is not None: |
|
pipeline_class = get_class_from_dynamic_module( |
|
custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline |
|
) |
|
elif cls != DiffusionPipeline: |
|
pipeline_class = cls |
|
else: |
|
diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) |
|
pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) |
|
``` |
|
|