diffusers-sdxl-controlnet / docs /source /ko /training /unconditional_training.md
svjack's picture
Upload 1392 files
43b7e92 verified
|
raw
history blame
5.47 kB

Unconditional ์ด๋ฏธ์ง€ ์ƒ์„ฑ

unconditional ์ด๋ฏธ์ง€ ์ƒ์„ฑ์€ text-to-image ๋˜๋Š” image-to-image ๋ชจ๋ธ๊ณผ ๋‹ฌ๋ฆฌ ํ…์ŠคํŠธ๋‚˜ ์ด๋ฏธ์ง€์— ๋Œ€ํ•œ ์กฐ๊ฑด์ด ์—†์ด ํ•™์Šต ๋ฐ์ดํ„ฐ ๋ถ„ํฌ์™€ ์œ ์‚ฌํ•œ ์ด๋ฏธ์ง€๋งŒ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.

์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” ๊ธฐ์กด์— ์กด์žฌํ•˜๋˜ ๋ฐ์ดํ„ฐ์…‹๊ณผ ์ž์‹ ๋งŒ์˜ ์ปค์Šคํ…€ ๋ฐ์ดํ„ฐ์…‹์— ๋Œ€ํ•ด unconditional image generation ๋ชจ๋ธ์„ ํ›ˆ๋ จํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค. ํ›ˆ๋ จ ์„ธ๋ถ€ ์‚ฌํ•ญ์— ๋Œ€ํ•ด ๋” ์ž์„ธํžˆ ์•Œ๊ณ  ์‹ถ๋‹ค๋ฉด unconditional image generation์„ ์œ„ํ•œ ๋ชจ๋“  ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์—ฌ๊ธฐ์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹คํ–‰ํ•˜๊ธฐ ์ „, ๋จผ์ € ์˜์กด์„ฑ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋“ค์„ ์„ค์น˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

pip install diffusers[training] accelerate datasets

๊ทธ ๋‹ค์Œ ๐Ÿค— Accelerate ํ™˜๊ฒฝ์„ ์ดˆ๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค.

accelerate config

๋ณ„๋„์˜ ์„ค์ • ์—†์ด ๊ธฐ๋ณธ ์„ค์ •์œผ๋กœ ๐Ÿค— Accelerate ํ™˜๊ฒฝ์„ ์ดˆ๊ธฐํ™”ํ•ด๋ด…์‹œ๋‹ค.

accelerate config default

๋…ธํŠธ๋ถ๊ณผ ๊ฐ™์€ ๋Œ€ํ™”ํ˜• ์‰˜์„ ์ง€์›ํ•˜์ง€ ์•Š๋Š” ํ™˜๊ฒฝ์˜ ๊ฒฝ์šฐ, ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์‚ฌ์šฉํ•ด๋ณผ ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.

from accelerate.utils import write_basic_config

write_basic_config()

๋ชจ๋ธ์„ ํ—ˆ๋ธŒ์— ์—…๋กœ๋“œํ•˜๊ธฐ

ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๋‹ค์Œ ์ธ์ž๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ ํ—ˆ๋ธŒ์— ๋ชจ๋ธ์„ ์—…๋กœ๋“œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

--push_to_hub

์ฒดํฌํฌ์ธํŠธ ์ €์žฅํ•˜๊ณ  ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

ํ›ˆ๋ จ ์ค‘ ๋ฌธ์ œ๊ฐ€ ๋ฐœ์ƒํ•  ๊ฒฝ์šฐ๋ฅผ ๋Œ€๋น„ํ•˜์—ฌ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ •๊ธฐ์ ์œผ๋กœ ์ €์žฅํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค. ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ €์žฅํ•˜๋ ค๋ฉด ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๋‹ค์Œ ์ธ์ž๋ฅผ ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค:

--checkpointing_steps=500

์ „์ฒด ํ›ˆ๋ จ ์ƒํƒœ๋Š” 500์Šคํ…๋งˆ๋‹ค output_dir์˜ ํ•˜์œ„ ํด๋”์— ์ €์žฅ๋˜๋ฉฐ, ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— --resume_from_checkpoint ์ธ์ž๋ฅผ ์ „๋‹ฌํ•จ์œผ๋กœ์จ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ๋ถˆ๋Ÿฌ์˜ค๊ณ  ํ›ˆ๋ จ์„ ์žฌ๊ฐœํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

--resume_from_checkpoint="checkpoint-1500"

ํŒŒ์ธํŠœ๋‹

์ด์ œ ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹œ์ž‘ํ•  ์ค€๋น„๊ฐ€ ๋˜์—ˆ์Šต๋‹ˆ๋‹ค! --dataset_name ์ธ์ž์— ํŒŒ์ธํŠœ๋‹ํ•  ๋ฐ์ดํ„ฐ์…‹ ์ด๋ฆ„์„ ์ง€์ •ํ•œ ๋‹ค์Œ, --output_dir ์ธ์ž์— ์ง€์ •๋œ ๊ฒฝ๋กœ๋กœ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค. ๋ณธ์ธ๋งŒ์˜ ๋ฐ์ดํ„ฐ์…‹๋ฅผ ์‚ฌ์šฉํ•˜๋ ค๋ฉด, ํ•™์Šต์šฉ ๋ฐ์ดํ„ฐ์…‹ ๋งŒ๋“ค๊ธฐ ๊ฐ€์ด๋“œ๋ฅผ ์ฐธ์กฐํ•˜์„ธ์š”.

ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ๋Š” diffusion_pytorch_model.bin ํŒŒ์ผ์„ ์ƒ์„ฑํ•˜๊ณ , ๊ทธ๊ฒƒ์„ ๋‹น์‹ ์˜ ๋ฆฌํฌ์ง€ํ† ๋ฆฌ์— ์ €์žฅํ•ฉ๋‹ˆ๋‹ค.

๐Ÿ’ก ์ „์ฒด ํ•™์Šต์€ V100 GPU 4๊ฐœ๋ฅผ ์‚ฌ์šฉํ•  ๊ฒฝ์šฐ, 2์‹œ๊ฐ„์ด ์†Œ์š”๋ฉ๋‹ˆ๋‹ค.

์˜ˆ๋ฅผ ๋“ค์–ด, Oxford Flowers ๋ฐ์ดํ„ฐ์…‹์„ ์‚ฌ์šฉํ•ด ํŒŒ์ธํŠœ๋‹ํ•  ๊ฒฝ์šฐ:

accelerate launch train_unconditional.py \
  --dataset_name="huggan/flowers-102-categories" \
  --resolution=64 \
  --output_dir="ddpm-ema-flowers-64" \
  --train_batch_size=16 \
  --num_epochs=100 \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-4 \
  --lr_warmup_steps=500 \
  --mixed_precision=no \
  --push_to_hub
[Naruto](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) ๋ฐ์ดํ„ฐ์…‹์„ ์‚ฌ์šฉํ•  ๊ฒฝ์šฐ:
accelerate launch train_unconditional.py \
  --dataset_name="lambdalabs/naruto-blip-captions" \
  --resolution=64 \
  --output_dir="ddpm-ema-naruto-64" \
  --train_batch_size=16 \
  --num_epochs=100 \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-4 \
  --lr_warmup_steps=500 \
  --mixed_precision=no \
  --push_to_hub

์—ฌ๋Ÿฌ๊ฐœ์˜ GPU๋กœ ํ›ˆ๋ จํ•˜๊ธฐ

accelerate์„ ์‚ฌ์šฉํ•˜๋ฉด ์›ํ™œํ•œ ๋‹ค์ค‘ GPU ํ›ˆ๋ จ์ด ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค. accelerate์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ถ„์‚ฐ ํ›ˆ๋ จ์„ ์‹คํ–‰ํ•˜๋ ค๋ฉด ์—ฌ๊ธฐ ์ง€์นจ์„ ๋”ฐ๋ฅด์„ธ์š”. ๋‹ค์Œ์€ ๋ช…๋ น์–ด ์˜ˆ์ œ์ž…๋‹ˆ๋‹ค.

accelerate launch --mixed_precision="fp16" --multi_gpu train_unconditional.py \
  --dataset_name="lambdalabs/naruto-blip-captions" \
  --resolution=64 --center_crop --random_flip \
  --output_dir="ddpm-ema-naruto-64" \
  --train_batch_size=16 \
  --num_epochs=100 \
  --gradient_accumulation_steps=1 \
  --use_ema \
  --learning_rate=1e-4 \
  --lr_warmup_steps=500 \
  --mixed_precision="fp16" \
  --logger="wandb" \
  --push_to_hub