|
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with |
|
the License. You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on |
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the |
|
specific language governing permissions and limitations under the License. |
|
--> |
|
|
|
[[open-in-colab]] |
|
|
|
|
|
# Train a diffusion model
|
|
|
Unconditional image generation is a popular application of diffusion models that generates images resembling those in the dataset used for training. Typically, the best results are obtained by finetuning a pretrained model on a specific dataset. You can find many of these checkpoints on the [Hub](https://huggingface.co/search/full-text?q=unconditional-image-generation&type=model), but if you can't find one you like, you can always train your own!
|
|
|
This tutorial will teach you how to train a [`UNet2DModel`] on a subset of the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset to generate your own 🦋 butterflies 🦋.
|
|
|
<Tip> |
|
|
|
💡 This training tutorial is based on the [Training with 🧨 Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) notebook. For additional details about diffusion models and how they work, check out the notebook!
|
|
|
</Tip> |
|
|
|
Before you begin, make sure you have 🤗 Datasets installed to load and preprocess the dataset, and 🤗 Accelerate to simplify training on a GPU. Then install [TensorBoard](https://www.tensorflow.org/tensorboard) to visualize training metrics (you can also use [Weights & Biases](https://docs.wandb.ai/) to track your training).
|
|
|
```bash
!pip install diffusers[training]
```
|
|
|
We encourage you to share your model with the community. To do so, you'll need to log in to your Hugging Face account (create one [here](https://hf.co/join) if you don't already have one). You can log in from a notebook and enter your token when prompted.
|
|
|
```py
>>> from huggingface_hub import notebook_login

>>> notebook_login()
```
|
|
|
Or log in from the terminal:
|
|
|
```bash
huggingface-cli login
```
|
|
|
Since the model checkpoints are quite large, install [Git-LFS](https://git-lfs.com/) to version these large files:
|
|
|
```bash
!sudo apt -qq install git-lfs
!git config --global credential.helper store
```
|
|
|
|
|
## Training configuration
|
|
|
For convenience, create a `TrainingConfig` class containing the training hyperparameters (feel free to adjust them):
|
|
|
```py
>>> from dataclasses import dataclass


>>> @dataclass
... class TrainingConfig:
...     image_size = 128  # the generated image resolution
...     train_batch_size = 16
...     eval_batch_size = 16  # how many images to sample during evaluation
...     num_epochs = 50
...     gradient_accumulation_steps = 1
...     learning_rate = 1e-4
...     lr_warmup_steps = 500
...     save_image_epochs = 10
...     save_model_epochs = 30
...     mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
...     output_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hub
...     push_to_hub = True  # whether to upload the saved model to the HF Hub
...     hub_model_id = "<your-username>/<my-awesome-model>"  # the name of the repository to create on the HF Hub (used by the training loop below)
...     hub_private_repo = False
...     overwrite_output_dir = True  # overwrite the old model when re-running the notebook
...     seed = 0


>>> config = TrainingConfig()
```
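
Each of these values is a plain class attribute, so you can override any of them on the `config` instance before launching training. For example, a quick smoke-test run might look like the sketch below (hypothetical values, not part of the recipe; restore the defaults for a real run):

```py
>>> config.num_epochs = 1  # hypothetical: train for a single epoch just to verify the setup
>>> config.save_image_epochs = 1
```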
|
|
|
|
|
## Load the dataset
|
|
|
You can easily load the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset with the 🤗 Datasets library:
|
|
|
```py
>>> from datasets import load_dataset

>>> config.dataset_name = "huggan/smithsonian_butterflies_subset"
>>> dataset = load_dataset(config.dataset_name, split="train")
```
|
|
|
💡 You can find additional datasets from the [HugGan Community Event](https://huggingface.co/huggan), or use your own dataset by creating a local [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder). Set `config.dataset_name` to the repository id of the dataset if it comes from the HugGan Community Event, or to `imagefolder` if you're using your own images.
|
|
|
🤗 Datasets uses the [`~datasets.Image`] feature to automatically decode the image data and load it as a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html). Let's visualize a few of the images:
|
|
|
```py
>>> import matplotlib.pyplot as plt

>>> fig, axs = plt.subplots(1, 4, figsize=(16, 4))
>>> for i, image in enumerate(dataset[:4]["image"]):
...     axs[i].imshow(image)
...     axs[i].set_axis_off()
>>> fig.show()
```
|
|
|
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_ds.png) |
|
|
|
The images are all different sizes, so you'll need to preprocess them first:
|
|
|
- `Resize` changes the image size to the one defined in `config.image_size`.
- `RandomHorizontalFlip` augments the dataset by randomly mirroring the images.
- `Normalize` is important for rescaling the pixel values into the [-1, 1] range the model expects.
|
|
|
```py
>>> from torchvision import transforms

>>> preprocess = transforms.Compose(
...     [
...         transforms.Resize((config.image_size, config.image_size)),
...         transforms.RandomHorizontalFlip(),
...         transforms.ToTensor(),
...         transforms.Normalize([0.5], [0.5]),
...     ]
... )
```
|
|
|
Use 🤗 Datasets' [`~datasets.Dataset.set_transform`] method to apply the `preprocess` function on the fly during training:
|
|
|
```py
>>> def transform(examples):
...     images = [preprocess(image.convert("RGB")) for image in examples["image"]]
...     return {"images": images}


>>> dataset.set_transform(transform)
```
|
|
|
Feel free to visualize the images again to confirm that they've been resized. Now you're ready to wrap the dataset in a [DataLoader](https://pytorch.org/docs/stable/data#torch.utils.data.DataLoader) for training!
|
|
|
```py
>>> import torch

>>> train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
```
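
As a quick, optional sanity check, you can pull one batch from the dataloader and confirm that preprocessing produced tensors of the expected `[batch_size, 3, image_size, image_size]` shape (with the settings above, that should be `[16, 3, 128, 128]`):

```py
>>> batch = next(iter(train_dataloader))
>>> batch["images"].shape
torch.Size([16, 3, 128, 128])
```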
|
|
|
|
|
## Create a UNet2DModel
|
|
|
Models in 🧨 Diffusers are easily created from their model class with the parameters you want. For example, to create a [`UNet2DModel`]:
|
|
|
```py
>>> from diffusers import UNet2DModel

>>> model = UNet2DModel(
...     sample_size=config.image_size,  # the target image resolution
...     in_channels=3,  # the number of input channels, 3 for RGB images
...     out_channels=3,  # the number of output channels
...     layers_per_block=2,  # how many ResNet layers to use per UNet block
...     block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
...     down_block_types=(
...         "DownBlock2D",  # a regular ResNet downsampling block
...         "DownBlock2D",
...         "DownBlock2D",
...         "DownBlock2D",
...         "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
...         "DownBlock2D",
...     ),
...     up_block_types=(
...         "UpBlock2D",  # a regular ResNet upsampling block
...         "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
...         "UpBlock2D",
...         "UpBlock2D",
...         "UpBlock2D",
...         "UpBlock2D",
...     ),
... )
```
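
If you're curious how big this UNet is, a quick way to check is to count its parameters (a small sketch; the exact number depends on the configuration above):

```py
>>> num_params = sum(p.numel() for p in model.parameters())
>>> print(f"Number of parameters: {num_params / 1e6:.1f}M")
```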
|
|
|
It's often a good idea to quickly check that the sample image shape matches the model output shape:
|
|
|
```py
>>> sample_image = dataset[0]["images"].unsqueeze(0)
>>> print("Input shape:", sample_image.shape)
Input shape: torch.Size([1, 3, 128, 128])

>>> print("Output shape:", model(sample_image, timestep=0).sample.shape)
Output shape: torch.Size([1, 3, 128, 128])
```
|
|
|
Great! Next, you'll need a scheduler to add some noise to the image.
|
|
|
|
|
## Create a scheduler
|
|
|
The scheduler behaves differently depending on whether you're using the model for training or inference. During inference, the scheduler generates an image from noise. During training, the scheduler takes a model output, or a sample, from a specific point in the diffusion process and applies noise to the image according to a *noise schedule* and an *update rule*.
|
|
|
Let's take a look at the `DDPMScheduler` and use its `add_noise` method to add some random noise to the `sample_image` from before:
|
|
|
```py
>>> import torch
>>> from PIL import Image
>>> from diffusers import DDPMScheduler

>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
>>> noise = torch.randn(sample_image.shape)
>>> timesteps = torch.LongTensor([50])
>>> noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)

>>> Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).type(torch.uint8).numpy()[0])
```
|
|
|
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/noisy_butterfly.png) |
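
For reference, the `add_noise` step above corresponds to the standard DDPM forward process: given a clean image \\(x_0\\), Gaussian noise \\(\epsilon\\), and a timestep \\(t\\), the noisy sample is \\(x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon\\), where \\(\bar{\alpha}_t\\) is determined by the scheduler's noise schedule.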
|
|
|
The training objective of the model is to predict the noise added to the image. The loss at this step can be calculated by:
|
|
|
```py
>>> import torch.nn.functional as F

>>> noise_pred = model(noisy_image, timesteps).sample
>>> loss = F.mse_loss(noise_pred, noise)
```
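
Put differently, training minimizes the usual simplified DDPM objective, \\(\mathbb{E}_{x_0, \epsilon, t}\big[\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2\big]\\), where \\(\epsilon_\theta\\) is the noise predicted by the UNet.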
|
|
|
## Train the model
|
|
|
By now, you have most of the pieces needed to start training the model, and all that's left is putting everything together.
|
|
|
First, you'll need an optimizer and a learning rate scheduler:
|
|
|
```py
>>> from diffusers.optimization import get_cosine_schedule_with_warmup

>>> optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
>>> lr_scheduler = get_cosine_schedule_with_warmup(
...     optimizer=optimizer,
...     num_warmup_steps=config.lr_warmup_steps,
...     num_training_steps=(len(train_dataloader) * config.num_epochs),
... )
```
|
|
|
Then you'll need a way to evaluate the model. For evaluation, you can use the `DDPMPipeline` to generate a batch of sample images and save them as a grid:
|
|
|
```py
>>> from diffusers import DDPMPipeline
>>> import math
>>> import os


>>> def make_grid(images, rows, cols):
...     w, h = images[0].size
...     grid = Image.new("RGB", size=(cols * w, rows * h))
...     for i, image in enumerate(images):
...         grid.paste(image, box=(i % cols * w, i // cols * h))
...     return grid


>>> def evaluate(config, epoch, pipeline):
...     # Sample some images from random noise (this is the backward diffusion process).
...     # The default pipeline output type is `List[PIL.Image]`
...     images = pipeline(
...         batch_size=config.eval_batch_size,
...         generator=torch.manual_seed(config.seed),
...     ).images

...     # Make a grid out of the images
...     image_grid = make_grid(images, rows=4, cols=4)

...     # Save the images
...     test_dir = os.path.join(config.output_dir, "samples")
...     os.makedirs(test_dir, exist_ok=True)
...     image_grid.save(f"{test_dir}/{epoch:04d}.png")
```
|
|
|
Now you can wrap all of these components together in a training loop with 🤗 Accelerate for easy TensorBoard logging, gradient accumulation, and mixed precision training. To upload the model to the Hub, write a function that gets your repository name and information and then pushes it to the Hub.
|
|
|
💡 The training loop below may look intimidating and long, but it'll be worth it later when you launch training in just one line of code! If you can't wait and already want to start generating images, feel free to copy and run the code below. 🤗
|
|
|
```py
>>> from accelerate import Accelerator
>>> from huggingface_hub import create_repo, upload_folder
>>> from tqdm.auto import tqdm
>>> from pathlib import Path
>>> import os


>>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
...     # Initialize accelerator and tensorboard logging
...     accelerator = Accelerator(
...         mixed_precision=config.mixed_precision,
...         gradient_accumulation_steps=config.gradient_accumulation_steps,
...         log_with="tensorboard",
...         project_dir=os.path.join(config.output_dir, "logs"),
...     )
...     if accelerator.is_main_process:
...         if config.output_dir is not None:
...             os.makedirs(config.output_dir, exist_ok=True)
...         if config.push_to_hub:
...             repo_id = create_repo(
...                 repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
...             ).repo_id
...         accelerator.init_trackers("train_example")

...     # Prepare everything
...     # There is no specific order to remember, you just need to unpack the
...     # objects in the same order you gave them to the prepare method.
...     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
...         model, optimizer, train_dataloader, lr_scheduler
...     )

...     global_step = 0

...     # Now you train the model
...     for epoch in range(config.num_epochs):
...         progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
...         progress_bar.set_description(f"Epoch {epoch}")

...         for step, batch in enumerate(train_dataloader):
...             clean_images = batch["images"]
...             # Sample noise to add to the images
...             noise = torch.randn(clean_images.shape, device=clean_images.device)
...             bs = clean_images.shape[0]

...             # Sample a random timestep for each image
...             timesteps = torch.randint(
...                 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,
...                 dtype=torch.int64
...             )

...             # Add noise to the clean images according to the noise magnitude at each timestep
...             # (this is the forward diffusion process)
...             noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

...             with accelerator.accumulate(model):
...                 # Predict the noise residual
...                 noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
...                 loss = F.mse_loss(noise_pred, noise)
...                 accelerator.backward(loss)

...                 accelerator.clip_grad_norm_(model.parameters(), 1.0)
...                 optimizer.step()
...                 lr_scheduler.step()
...                 optimizer.zero_grad()

...             progress_bar.update(1)
...             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
...             progress_bar.set_postfix(**logs)
...             accelerator.log(logs, step=global_step)
...             global_step += 1

...         # After each epoch you optionally sample some demo images with evaluate() and save the model
...         if accelerator.is_main_process:
...             pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

...             if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
...                 evaluate(config, epoch, pipeline)

...             if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
...                 if config.push_to_hub:
...                     upload_folder(
...                         repo_id=repo_id,
...                         folder_path=config.output_dir,
...                         commit_message=f"Epoch {epoch}",
...                         ignore_patterns=["step_*", "epoch_*"],
...                     )
...                 else:
...                     pipeline.save_pretrained(config.output_dir)
```
|
|
|
Phew, that was quite a bit of code! But now you're ready to launch training with 🤗 Accelerate's [`~accelerate.notebook_launcher`] function. Pass it the training loop, all the training arguments, and the number of processes to use for training (you can change this value to the number of GPUs available to you):
|
|
|
```py
>>> from accelerate import notebook_launcher

>>> args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)

>>> notebook_launcher(train_loop, args, num_processes=1)
```
|
|
|
Once training is complete, take a look at the final 🦋 images 🦋 generated by your diffusion model!
|
|
|
```py
>>> import glob

>>> sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
>>> Image.open(sample_images[-1])
```
|
|
|
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_final.png) |
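
You can also reload the trained pipeline and sample fresh images from it. The sketch below assumes the pipeline was saved to `config.output_dir` (or pushed to the Hub under the same name) and that a CUDA GPU is available:

```py
>>> from diffusers import DDPMPipeline

>>> pipeline = DDPMPipeline.from_pretrained(config.output_dir).to("cuda")
>>> image = pipeline(batch_size=1, generator=torch.manual_seed(config.seed)).images[0]
>>> image
```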
|
|
|
## Next steps
|
|
|
Unconditional image generation is just one example of a task you can train. You can explore other tasks and training techniques on the [🧨 Diffusers Training Examples](../training/overview) page. Here are some examples of what you can learn:
|
|
|
- [Textual Inversion](../training/text_inversion), an algorithm that teaches a model a specific visual concept and integrates it into the generated images.
- [DreamBooth](../training/dreambooth), a technique for generating personalized images of a subject given several input images of that subject.
- [Guide](../training/text2image) to finetuning a Stable Diffusion model on your own dataset.
- [Guide](../training/lora) to using LoRA, a memory-efficient technique for finetuning really large models faster.
|
|