์ปค์คํ Diffusion ํ์ต ์์
์ปค์คํ Diffusion์ ํผ์ฌ์ฒด์ ์ด๋ฏธ์ง ๋ช ์ฅ(4~5์ฅ)๋ง ์ฃผ์ด์ง๋ฉด Stable Diffusion์ฒ๋ผ text-to-image ๋ชจ๋ธ์ ์ปค์คํฐ๋ง์ด์งํ๋ ๋ฐฉ๋ฒ์ ๋๋ค. 'train_custom_diffusion.py' ์คํฌ๋ฆฝํธ๋ ํ์ต ๊ณผ์ ์ ๊ตฌํํ๊ณ ์ด๋ฅผ Stable Diffusion์ ๋ง๊ฒ ์กฐ์ ํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค.
์ด ๊ต์ก ์ฌ๋ก๋ Nupur Kumari๊ฐ ์ ๊ณตํ์์ต๋๋ค. (Custom Diffusion์ ์ ์ ์ค ํ๋ช ).
๋ก์ปฌ์์ PyTorch๋ก ์คํํ๊ธฐ
Dependencies ์ค์นํ๊ธฐ
์คํฌ๋ฆฝํธ๋ฅผ ์คํํ๊ธฐ ์ ์ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ํ์ต dependencies๋ฅผ ์ค์นํด์ผ ํฉ๋๋ค:
์ค์
์์ ์คํฌ๋ฆฝํธ์ ์ต์ ๋ฒ์ ์ ์ฑ๊ณต์ ์ผ๋ก ์คํํ๋ ค๋ฉด ์์ค๋ก๋ถํฐ ์ค์นํ๋ ๊ฒ์ ๋งค์ฐ ๊ถ์ฅํ๋ฉฐ, ์์ ์คํฌ๋ฆฝํธ๋ฅผ ์์ฃผ ์ ๋ฐ์ดํธํ๋ ๋งํผ ์ผ๋ถ ์์ ๋ณ ์๊ตฌ ์ฌํญ์ ์ค์นํ๊ณ ์ค์น๋ฅผ ์ต์ ์ํ๋ก ์ ์งํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ์ด๋ฅผ ์ํด ์ ๊ฐ์ ํ๊ฒฝ์์ ๋ค์ ๋จ๊ณ๋ฅผ ์คํํ์ธ์:
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
example folder๋ก cdํ์ฌ ์ด๋ํ์ธ์.
cd examples/custom_diffusion
์ด์ ์คํ
pip install -r requirements.txt
pip install clip-retrieval
๊ทธ๋ฆฌ๊ณ ๐คAccelerate ํ๊ฒฝ์ ์ด๊ธฐํ:
accelerate config
๋๋ ์ฌ์ฉ์ ํ๊ฒฝ์ ๋ํ ์ง๋ฌธ์ ๋ตํ์ง ์๊ณ ๊ธฐ๋ณธ ๊ฐ์ ๊ตฌ์ฑ์ ์ฌ์ฉํ๋ ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด ํ์ธ์.
accelerate config default
๋๋ ์ฌ์ฉ ์ค์ธ ํ๊ฒฝ์ด ๋ํํ ์ ธ์ ์ง์ํ์ง ์๋ ๊ฒฝ์ฐ(์: jupyter notebook)
from accelerate.utils import write_basic_config
write_basic_config()
๊ณ ์์ด ์์ ๐บ
์ด์ ๋ฐ์ดํฐ์ ์ ๊ฐ์ ธ์ต๋๋ค. ์ฌ๊ธฐ์์ ๋ฐ์ดํฐ์ ์ ๋ค์ด๋ก๋ํ๊ณ ์์ถ์ ํ๋๋ค. ์ง์ ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ๋ ค๋ฉด ํ์ต์ฉ ๋ฐ์ดํฐ์ ์์ฑํ๊ธฐ ๊ฐ์ด๋๋ฅผ ์ฐธ๊ณ ํ์ธ์.
๋ํ 'clip-retrieval'์ ์ฌ์ฉํ์ฌ 200๊ฐ์ ์ค์ ์ด๋ฏธ์ง๋ฅผ ์์งํ๊ณ , regularization์ผ๋ก์ ์ด๋ฅผ ํ์ต ๋ฐ์ดํฐ์
์ ํ๊ฒ ์ด๋ฏธ์ง์ ๊ฒฐํฉํฉ๋๋ค. ์ด๋ ๊ฒ ํ๋ฉด ์ฃผ์ด์ง ํ๊ฒ ์ด๋ฏธ์ง์ ๋ํ ๊ณผ์ ํฉ์ ๋ฐฉ์งํ ์ ์์ต๋๋ค. ๋ค์ ํ๋๊ทธ๋ฅผ ์ฌ์ฉํ๋ฉด prior_loss_weight=1.
๋ก prior_preservation
, real_prior
regularization์ ํ์ฑํํ ์ ์์ต๋๋ค.
ํด๋์ค_ํ๋กฌํํธ๋ ๋์ ์ด๋ฏธ์ง์ ๋์ผํ ์นดํ
๊ณ ๋ฆฌ ์ด๋ฆ์ด์ด์ผ ํฉ๋๋ค. ์์ง๋ ์ค์ ์ด๋ฏธ์ง์๋
class_prompt์ ์ ์ฌํ ํ
์คํธ ์บก์
์ด ์์ต๋๋ค. ๊ฒ์๋ ์ด๋ฏธ์ง๋
class_data_dir์ ์ ์ฅ๋ฉ๋๋ค. ์์ฑ๋ ์ด๋ฏธ์ง๋ฅผ regularization์ผ๋ก ์ฌ์ฉํ๊ธฐ ์ํด
real_prior`๋ฅผ ๋นํ์ฑํํ ์ ์์ต๋๋ค. ์ค์ ์ด๋ฏธ์ง๋ฅผ ์์งํ๋ ค๋ฉด ํ๋ จ ์ ์ ์ด ๋ช
๋ น์ ๋จผ์ ์ฌ์ฉํ์ญ์์ค.
pip install clip-retrieval
python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200
์ฐธ๊ณ : stable-diffusion-2 768x768 ๋ชจ๋ธ์ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ 'ํด์๋'๋ฅผ 768๋ก ๋ณ๊ฒฝํ์ธ์.
์คํฌ๋ฆฝํธ๋ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ์ pytorch_custom_diffusion_weights.bin
ํ์ผ์ ์์ฑํ์ฌ ์ ์ฅ์์ ์ ์ฅํฉ๋๋ค.
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export OUTPUT_DIR="path-to-save-model"
export INSTANCE_DIR="./data/cat"
accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--class_data_dir=./real_reg/samples_cat/ \
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
--class_prompt="cat" --num_class_images=200 \
--instance_prompt="photo of a <new1> cat" \
--resolution=512 \
--train_batch_size=2 \
--learning_rate=1e-5 \
--lr_warmup_steps=0 \
--max_train_steps=250 \
--scale_lr --hflip \
--modifier_token "<new1>" \
--push_to_hub
๋ ๋ฎ์ VRAM ์๊ตฌ ์ฌํญ(GPU๋น 16GB)์ผ๋ก ๋ ๋น ๋ฅด๊ฒ ํ๋ จํ๋ ค๋ฉด --enable_xformers_memory_efficient_attention
์ ์ฌ์ฉํ์ธ์. ์ค์น ๋ฐฉ๋ฒ์ ๊ฐ์ด๋๋ฅผ ๋ฐ๋ฅด์ธ์.
๊ฐ์ค์น ๋ฐ ํธํฅ(wandb
)์ ์ฌ์ฉํ์ฌ ์คํ์ ์ถ์ ํ๊ณ ์ค๊ฐ ๊ฒฐ๊ณผ๋ฅผ ์ ์ฅํ๋ ค๋ฉด(๊ฐ๋ ฅํ ๊ถ์ฅํฉ๋๋ค) ๋ค์ ๋จ๊ณ๋ฅผ ๋ฐ๋ฅด์ธ์:
wandb
์ค์น:pip install wandb
.- ๋ก๊ทธ์ธ :
wandb login
. - ๊ทธ๋ฐ ๋ค์ ํธ๋ ์ด๋์ ์์ํ๋ ๋์
validation_prompt
๋ฅผ ์ง์ ํ๊ณreport_to
๋ฅผwandb
๋ก ์ค์ ํฉ๋๋ค. ๋ค์๊ณผ ๊ฐ์ ๊ด๋ จ ์ธ์๋ฅผ ๊ตฌ์ฑํ ์๋ ์์ต๋๋ค:num_validation_images
validation_steps
accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--class_data_dir=./real_reg/samples_cat/ \
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
--class_prompt="cat" --num_class_images=200 \
--instance_prompt="photo of a <new1> cat" \
--resolution=512 \
--train_batch_size=2 \
--learning_rate=1e-5 \
--lr_warmup_steps=0 \
--max_train_steps=250 \
--scale_lr --hflip \
--modifier_token "<new1>" \
--validation_prompt="<new1> cat sitting in a bucket" \
--report_to="wandb" \
--push_to_hub
๋ค์์ Weights and Biases page์ ์์์ด๋ฉฐ, ์ฌ๋ฌ ํ์ต ์ธ๋ถ ์ ๋ณด์ ํจ๊ป ์ค๊ฐ ๊ฒฐ๊ณผ๋ค์ ํ์ธํ ์ ์์ต๋๋ค.
--push_to_hub
๋ฅผ ์ง์ ํ๋ฉด ํ์ต๋ ํ๋ผ๋ฏธํฐ๊ฐ ํ๊น
ํ์ด์ค ํ๋ธ์ ๋ฆฌํฌ์งํ ๋ฆฌ์ ํธ์๋ฉ๋๋ค. ๋ค์์ ์์ ๋ฆฌํฌ์งํ ๋ฆฌ์
๋๋ค.
๋ฉํฐ ์ปจ์ ์ ๋ํ ํ์ต ๐ฑ๐ชต
this์ ์ ์ฌํ๊ฒ ๊ฐ ์ปจ์ ์ ๋ํ ์ ๋ณด๊ฐ ํฌํจ๋ json ํ์ผ์ ์ ๊ณตํฉ๋๋ค.
์ค์ ์ด๋ฏธ์ง๋ฅผ ์์งํ๋ ค๋ฉด json ํ์ผ์ ๊ฐ ์ปจ์ ์ ๋ํด ์ด ๋ช ๋ น์ ์คํํฉ๋๋ค.
pip install clip-retrieval
python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200
๊ทธ๋ผ ์ฐ๋ฆฌ๋ ํ์ต์ํฌ ์ค๋น๊ฐ ๋์์ต๋๋ค!
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export OUTPUT_DIR="path-to-save-model"
accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--output_dir=$OUTPUT_DIR \
--concepts_list=./concept_list.json \
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
--resolution=512 \
--train_batch_size=2 \
--learning_rate=1e-5 \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--num_class_images=200 \
--scale_lr --hflip \
--modifier_token "<new1>+<new2>" \
--push_to_hub
๋ค์์ Weights and Biases page์ ์์์ด๋ฉฐ, ๋ค๋ฅธ ํ์ต ์ธ๋ถ ์ ๋ณด์ ํจ๊ป ์ค๊ฐ ๊ฒฐ๊ณผ๋ค์ ํ์ธํ ์ ์์ต๋๋ค.
์ฌ๋ ์ผ๊ตด์ ๋ํ ํ์ต
์ฌ๋ ์ผ๊ตด์ ๋ํ ํ์ธํ๋์ ์ํด ๋ค์๊ณผ ๊ฐ์ ์ค์ ์ด ๋ ํจ๊ณผ์ ์ด๋ผ๋ ๊ฒ์ ํ์ธํ์ต๋๋ค: learning_rate=5e-6
, max_train_steps=1000 to 2000
, freeze_model=crossattn
์ ์ต์ 15~20๊ฐ์ ์ด๋ฏธ์ง๋ก ์ค์ ํฉ๋๋ค.
์ค์ ์ด๋ฏธ์ง๋ฅผ ์์งํ๋ ค๋ฉด ํ๋ จ ์ ์ ์ด ๋ช ๋ น์ ๋จผ์ ์ฌ์ฉํ์ญ์์ค.
pip install clip-retrieval
python retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200
์ด์ ํ์ต์ ์์ํ์ธ์!
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export OUTPUT_DIR="path-to-save-model"
export INSTANCE_DIR="path-to-images"
accelerate launch train_custom_diffusion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--class_data_dir=./real_reg/samples_person/ \
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
--class_prompt="person" --num_class_images=200 \
--instance_prompt="photo of a <new1> person" \
--resolution=512 \
--train_batch_size=2 \
--learning_rate=5e-6 \
--lr_warmup_steps=0 \
--max_train_steps=1000 \
--scale_lr --hflip --noaug \
--freeze_model crossattn \
--modifier_token "<new1>" \
--enable_xformers_memory_efficient_attention \
--push_to_hub
์ถ๋ก
์ ํ๋กฌํํธ๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ํ์ต์ํจ ํ์๋ ์๋ ํ๋กฌํํธ๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๋ก ์ ์คํํ ์ ์์ต๋๋ค. ํ๋กฌํํธ์ 'modifier token'(์: ์ ์์ ์์๋ <new1>)์ ๋ฐ๋์ ํฌํจํด์ผ ํฉ๋๋ค.
import torch
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")
pipe.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
pipe.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")
image = pipe(
"<new1> cat sitting in a bucket",
num_inference_steps=100,
guidance_scale=6.0,
eta=1.0,
).images[0]
image.save("cat.png")
ํ๋ธ ๋ฆฌํฌ์งํ ๋ฆฌ์์ ์ด๋ฌํ ๋งค๊ฐ๋ณ์๋ฅผ ์ง์ ๋ก๋ํ ์ ์์ต๋๋ค:
import torch
from huggingface_hub.repocard import RepoCard
from diffusers import DiffusionPipeline
model_id = "sayakpaul/custom-diffusion-cat"
card = RepoCard.load(model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
image = pipe(
"<new1> cat sitting in a bucket",
num_inference_steps=100,
guidance_scale=6.0,
eta=1.0,
).images[0]
image.save("cat.png")
๋ค์์ ์ฌ๋ฌ ์ปจ์ ์ผ๋ก ์ถ๋ก ์ ์ํํ๋ ์์ ์ ๋๋ค:
import torch
from huggingface_hub.repocard import RepoCard
from diffusers import DiffusionPipeline
model_id = "sayakpaul/custom-diffusion-cat-wooden-pot"
card = RepoCard.load(model_id)
base_model_id = card.data.to_dict()["base_model"]
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
pipe.load_textual_inversion(model_id, weight_name="<new2>.bin")
image = pipe(
"the <new1> cat sculpture in the style of a <new2> wooden pot",
num_inference_steps=100,
guidance_scale=6.0,
eta=1.0,
).images[0]
image.save("multi-subject.png")
์ฌ๊ธฐ์ '๊ณ ์์ด'์ '๋๋ฌด ๋๋น'๋ ์ฌ๋ฌ ์ปจ์ ์ ๋งํฉ๋๋ค.
ํ์ต๋ ์ฒดํฌํฌ์ธํธ์์ ์ถ๋ก ํ๊ธฐ
--checkpointing_steps
์ธ์๋ฅผ ์ฌ์ฉํ ๊ฒฝ์ฐ ํ์ต ๊ณผ์ ์์ ์ ์ฅ๋ ์ ์ฒด ์ฒดํฌํฌ์ธํธ ์ค ํ๋์์ ์ถ๋ก ์ ์ํํ ์๋ ์์ต๋๋ค.
Grads๋ฅผ None์ผ๋ก ์ค์
๋ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ ์ฝํ๋ ค๋ฉด ์คํฌ๋ฆฝํธ์ --set_grads_to_none
์ธ์๋ฅผ ์ ๋ฌํ์ธ์. ์ด๋ ๊ฒ ํ๋ฉด ์ฑ์ ์ด 0์ด ์๋ ์์์ผ๋ก ์ค์ ๋ฉ๋๋ค. ๊ทธ๋ฌ๋ ํน์ ๋์์ด ๋ณ๊ฒฝ๋๋ฏ๋ก ๋ฌธ์ ๊ฐ ๋ฐ์ํ๋ฉด ์ด ์ธ์๋ฅผ ์ ๊ฑฐํ์ธ์.
์์ธํ ์ ๋ณด: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
์คํ ๊ฒฐ๊ณผ
์คํ์ ๋ํ ์์ธํ ๋ด์ฉ์ ๋น์ฌ ์นํ์ด์ง๋ฅผ ์ฐธ์กฐํ์ธ์.