svjack's picture
Upload 1392 files
43b7e92 verified
|
raw
history blame
14.7 kB

Textual-Inversion

[[open-in-colab]]

textual-inversion์€ ์†Œ์ˆ˜์˜ ์˜ˆ์‹œ ์ด๋ฏธ์ง€์—์„œ ์ƒˆ๋กœ์šด ์ฝ˜์…‰ํŠธ๋ฅผ ํฌ์ฐฉํ•˜๋Š” ๊ธฐ๋ฒ•์ž…๋‹ˆ๋‹ค. ์ด ๊ธฐ์ˆ ์€ ์›๋ž˜ Latent Diffusion์—์„œ ์‹œ์—ฐ๋˜์—ˆ์ง€๋งŒ, ์ดํ›„ Stable Diffusion๊ณผ ๊ฐ™์€ ์œ ์‚ฌํ•œ ๋‹ค๋ฅธ ๋ชจ๋ธ์—๋„ ์ ์šฉ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ํ•™์Šต๋œ ์ฝ˜์…‰ํŠธ๋Š” text-to-image ํŒŒ์ดํ”„๋ผ์ธ์—์„œ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€๋ฅผ ๋” ์ž˜ ์ œ์–ดํ•˜๋Š” ๋ฐ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด ๋ชจ๋ธ์€ ํ…์ŠคํŠธ ์ธ์ฝ”๋”์˜ ์ž„๋ฒ ๋”ฉ ๊ณต๊ฐ„์—์„œ ์ƒˆ๋กœ์šด '๋‹จ์–ด'๋ฅผ ํ•™์Šตํ•˜์—ฌ ๊ฐœ์ธํ™”๋œ ์ด๋ฏธ์ง€ ์ƒ์„ฑ์„ ์œ„ํ•œ ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ ๋‚ด์—์„œ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.

Textual Inversion example By using just 3-5 images you can teach new concepts to a model such as Stable Diffusion for personalized image generation (image source).

์ด ๊ฐ€์ด๋“œ์—์„œ๋Š” textual-inversion์œผ๋กœ runwayml/stable-diffusion-v1-5 ๋ชจ๋ธ์„ ํ•™์Šตํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค. ์ด ๊ฐ€์ด๋“œ์—์„œ ์‚ฌ์šฉ๋œ ๋ชจ๋“  textual-inversion ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ๋Š” ์—ฌ๊ธฐ์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋‚ด๋ถ€์ ์œผ๋กœ ์–ด๋–ป๊ฒŒ ์ž‘๋™ํ•˜๋Š”์ง€ ์ž์„ธํžˆ ์‚ดํŽด๋ณด๊ณ  ์‹ถ์œผ์‹œ๋‹ค๋ฉด ํ•ด๋‹น ๋งํฌ๋ฅผ ์ฐธ์กฐํ•ด์ฃผ์‹œ๊ธฐ ๋ฐ”๋ž๋‹ˆ๋‹ค.

Stable Diffusion Textual Inversion Concepts Library์—๋Š” ์ปค๋ฎค๋‹ˆํ‹ฐ์—์„œ ์ œ์ž‘ํ•œ ํ•™์Šต๋œ textual-inversion ๋ชจ๋ธ๋“ค์ด ์žˆ์Šต๋‹ˆ๋‹ค. ์‹œ๊ฐ„์ด ์ง€๋‚จ์— ๋”ฐ๋ผ ๋” ๋งŽ์€ ์ฝ˜์…‰ํŠธ๋“ค์ด ์ถ”๊ฐ€๋˜์–ด ์œ ์šฉํ•œ ๋ฆฌ์†Œ์Šค๋กœ ์„ฑ์žฅํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค!

์‹œ์ž‘ํ•˜๊ธฐ ์ „์— ํ•™์Šต์„ ์œ„ํ•œ ์˜์กด์„ฑ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋“ค์„ ์„ค์น˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค:

pip install diffusers accelerate transformers

์˜์กด์„ฑ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋“ค์˜ ์„ค์น˜๊ฐ€ ์™„๋ฃŒ๋˜๋ฉด, ๐Ÿค—Accelerate ํ™˜๊ฒฝ์„ ์ดˆ๊ธฐํ™”์‹œํ‚ต๋‹ˆ๋‹ค.

accelerate config

๋ณ„๋„์˜ ์„ค์ •์—†์ด, ๊ธฐ๋ณธ ๐Ÿค—Accelerate ํ™˜๊ฒฝ์„ ์„ค์ •ํ•˜๋ ค๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™์ด ํ•˜์„ธ์š”:

accelerate config default

๋˜๋Š” ์‚ฌ์šฉ ์ค‘์ธ ํ™˜๊ฒฝ์ด ๋…ธํŠธ๋ถ๊ณผ ๊ฐ™์€ ๋Œ€ํ™”ํ˜• ์…ธ์„ ์ง€์›ํ•˜์ง€ ์•Š๋Š”๋‹ค๋ฉด, ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

from accelerate.utils import write_basic_config

write_basic_config()

๋งˆ์ง€๋ง‰์œผ๋กœ, Memory-Efficient Attention์„ ํ†ตํ•ด ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์„ ์ค„์ด๊ธฐ ์œ„ํ•ด xFormers๋ฅผ ์„ค์น˜ํ•ฉ๋‹ˆ๋‹ค. xFormers๋ฅผ ์„ค์น˜ํ•œ ํ›„, ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— --enable_xformers_memory_efficient_attention ์ธ์ž๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค. xFormers๋Š” Flax์—์„œ ์ง€์›๋˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค.

ํ—ˆ๋ธŒ์— ๋ชจ๋ธ ์—…๋กœ๋“œํ•˜๊ธฐ

๋ชจ๋ธ์„ ํ—ˆ๋ธŒ์— ์ €์žฅํ•˜๋ ค๋ฉด, ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๋‹ค์Œ ์ธ์ž๋ฅผ ์ถ”๊ฐ€ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

--push_to_hub

์ฒดํฌํฌ์ธํŠธ ์ €์žฅ ๋ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ

ํ•™์Šต์ค‘์— ๋ชจ๋ธ์˜ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ •๊ธฐ์ ์œผ๋กœ ์ €์žฅํ•˜๋Š” ๊ฒƒ์ด ์ข‹์Šต๋‹ˆ๋‹ค. ์ด๋ ‡๊ฒŒ ํ•˜๋ฉด ์–ด๋–ค ์ด์œ ๋กœ๋“  ํ•™์Šต์ด ์ค‘๋‹จ๋œ ๊ฒฝ์šฐ ์ €์žฅ๋œ ์ฒดํฌํฌ์ธํŠธ์—์„œ ํ•™์Šต์„ ๋‹ค์‹œ ์‹œ์ž‘ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๋‹ค์Œ ์ธ์ž๋ฅผ ์ „๋‹ฌํ•˜๋ฉด 500๋‹จ๊ณ„๋งˆ๋‹ค ์ „์ฒด ํ•™์Šต ์ƒํƒœ๊ฐ€ output_dir์˜ ํ•˜์œ„ ํด๋”์— ์ฒดํฌํฌ์ธํŠธ๋กœ์„œ ์ €์žฅ๋ฉ๋‹ˆ๋‹ค.

--checkpointing_steps=500

์ €์žฅ๋œ ์ฒดํฌํฌ์ธํŠธ์—์„œ ํ•™์Šต์„ ์žฌ๊ฐœํ•˜๋ ค๋ฉด, ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์™€ ์žฌ๊ฐœํ•  ํŠน์ • ์ฒดํฌํฌ์ธํŠธ์— ๋‹ค์Œ ์ธ์ž๋ฅผ ์ „๋‹ฌํ•˜์„ธ์š”.

--resume_from_checkpoint="checkpoint-1500"

ํŒŒ์ธ ํŠœ๋‹

ํ•™์Šต์šฉ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ๊ณ ์–‘์ด ์žฅ๋‚œ๊ฐ ๋ฐ์ดํ„ฐ์…‹์„ ๋‹ค์šด๋กœ๋“œํ•˜์—ฌ ๋””๋ ‰ํ† ๋ฆฌ์— ์ €์žฅํ•˜์„ธ์š”. ์—ฌ๋Ÿฌ๋ถ„๋งŒ์˜ ๊ณ ์œ ํ•œ ๋ฐ์ดํ„ฐ์…‹์„ ์‚ฌ์šฉํ•˜๊ณ ์ž ํ•œ๋‹ค๋ฉด, ํ•™์Šต์šฉ ๋ฐ์ดํ„ฐ์…‹ ๋งŒ๋“ค๊ธฐ ๊ฐ€์ด๋“œ๋ฅผ ์‚ดํŽด๋ณด์‹œ๊ธฐ ๋ฐ”๋ž๋‹ˆ๋‹ค.

from huggingface_hub import snapshot_download

local_dir = "./cat"
snapshot_download(
    "diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes"
)

๋ชจ๋ธ์˜ ๋ฆฌํฌ์ง€ํ† ๋ฆฌ ID(๋˜๋Š” ๋ชจ๋ธ ๊ฐ€์ค‘์น˜๊ฐ€ ํฌํ•จ๋œ ๋””๋ ‰ํ„ฐ๋ฆฌ ๊ฒฝ๋กœ)๋ฅผ MODEL_NAME ํ™˜๊ฒฝ ๋ณ€์ˆ˜์— ํ• ๋‹นํ•˜๊ณ , ํ•ด๋‹น ๊ฐ’์„ pretrained_model_name_or_path ์ธ์ž์— ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋ฆฌ๊ณ  ์ด๋ฏธ์ง€๊ฐ€ ํฌํ•จ๋œ ๋””๋ ‰ํ„ฐ๋ฆฌ ๊ฒฝ๋กœ๋ฅผ DATA_DIR ํ™˜๊ฒฝ ๋ณ€์ˆ˜์— ํ• ๋‹นํ•ฉ๋‹ˆ๋‹ค.

์ด์ œ ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์Šคํฌ๋ฆฝํŠธ๋Š” ๋‹ค์Œ ํŒŒ์ผ์„ ์ƒ์„ฑํ•˜๊ณ  ๋ฆฌํฌ์ง€ํ† ๋ฆฌ์— ์ €์žฅํ•ฉ๋‹ˆ๋‹ค.

  • learned_embeds.bin
  • token_identifier.txt
  • type_of_concept.txt.

๐Ÿ’กV100 GPU 1๊ฐœ๋ฅผ ๊ธฐ์ค€์œผ๋กœ ์ „์ฒด ํ•™์Šต์—๋Š” ์ตœ๋Œ€ 1์‹œ๊ฐ„์ด ๊ฑธ๋ฆฝ๋‹ˆ๋‹ค. ํ•™์Šต์ด ์™„๋ฃŒ๋˜๊ธฐ๋ฅผ ๊ธฐ๋‹ค๋ฆฌ๋Š” ๋™์•ˆ ๊ถ๊ธˆํ•œ ์ ์ด ์žˆ์œผ๋ฉด ์•„๋ž˜ ์„น์…˜์—์„œ textual-inversion์ด ์–ด๋–ป๊ฒŒ ์ž‘๋™ํ•˜๋Š”์ง€ ์ž์œ ๋กญ๊ฒŒ ํ™•์ธํ•˜์„ธ์š” !

```bash export MODEL_NAME="runwayml/stable-diffusion-v1-5" export DATA_DIR="./cat"

accelerate launch textual_inversion.py
--pretrained_model_name_or_path=$MODEL_NAME
--train_data_dir=$DATA_DIR
--learnable_property="object"
--placeholder_token="" --initializer_token="toy"
--resolution=512
--train_batch_size=1
--gradient_accumulation_steps=4
--max_train_steps=3000
--learning_rate=5.0e-04 --scale_lr
--lr_scheduler="constant"
--lr_warmup_steps=0
--output_dir="textual_inversion_cat"
--push_to_hub


<Tip>

๐Ÿ’กํ•™์Šต ์„ฑ๋Šฅ์„ ์˜ฌ๋ฆฌ๊ธฐ ์œ„ํ•ด, ํ”Œ๋ ˆ์ด์Šคํ™€๋” ํ† ํฐ(`<cat-toy>`)์„ (๋‹จ์ผํ•œ ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ๊ฐ€ ์•„๋‹Œ) ๋ณต์ˆ˜์˜ ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ๋กœ ํ‘œํ˜„ํ•˜๋Š” ๊ฒƒ ์—ญ์‹œ ๊ณ ๋ คํ•  ์žˆ์Šต๋‹ˆ๋‹ค.  ์ด๋Ÿฌํ•œ ํŠธ๋ฆญ์ด ๋ชจ๋ธ์ด ๋ณด๋‹ค ๋ณต์žกํ•œ ์ด๋ฏธ์ง€์˜ ์Šคํƒ€์ผ(์•ž์„œ ๋งํ•œ ์ฝ˜์…‰ํŠธ)์„ ๋” ์ž˜ ์บก์ฒ˜ํ•˜๋Š” ๋ฐ ๋„์›€์ด ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋ณต์ˆ˜์˜ ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ ํ•™์Šต์„ ํ™œ์„ฑํ™”ํ•˜๋ ค๋ฉด ๋‹ค์Œ ์˜ต์…˜์„ ์ „๋‹ฌํ•˜์‹ญ์‹œ์˜ค.

```bash
--num_vectors=5

TPU์— ์•ก์„ธ์Šคํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒฝ์šฐ, Flax ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋” ๋น ๋ฅด๊ฒŒ ๋ชจ๋ธ์„ ํ•™์Šต์‹œ์ผœ๋ณด์„ธ์š”. (๋ฌผ๋ก  GPU์—์„œ๋„ ์ž‘๋™ํ•ฉ๋‹ˆ๋‹ค.) ๋™์ผํ•œ ์„ค์ •์—์„œ Flax ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ๋Š” PyTorch ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ๋ณด๋‹ค ์ตœ์†Œ 70% ๋” ๋นจ๋ผ์•ผ ํ•ฉ๋‹ˆ๋‹ค! โšก๏ธ

์‹œ์ž‘ํ•˜๊ธฐ ์•ž์„œ Flax์— ๋Œ€ํ•œ ์˜์กด์„ฑ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋“ค์„ ์„ค์น˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

pip install -U -r requirements_flax.txt

๋ชจ๋ธ์˜ ๋ฆฌํฌ์ง€ํ† ๋ฆฌ ID(๋˜๋Š” ๋ชจ๋ธ ๊ฐ€์ค‘์น˜๊ฐ€ ํฌํ•จ๋œ ๋””๋ ‰ํ„ฐ๋ฆฌ ๊ฒฝ๋กœ)๋ฅผ MODEL_NAME ํ™˜๊ฒฝ ๋ณ€์ˆ˜์— ํ• ๋‹นํ•˜๊ณ , ํ•ด๋‹น ๊ฐ’์„ pretrained_model_name_or_path ์ธ์ž์— ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค.

๊ทธ๋Ÿฐ ๋‹ค์Œ ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ๋ฅผ ์‹œ์ž‘ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export DATA_DIR="./cat"

python textual_inversion_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 --scale_lr \
  --output_dir="textual_inversion_cat" \
  --push_to_hub

์ค‘๊ฐ„ ๋กœ๊น…

๋ชจ๋ธ์˜ ํ•™์Šต ์ง„ํ–‰ ์ƒํ™ฉ์„ ์ถ”์ ํ•˜๋Š” ๋ฐ ๊ด€์‹ฌ์ด ์žˆ๋Š” ๊ฒฝ์šฐ, ํ•™์Šต ๊ณผ์ •์—์„œ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€๋ฅผ ์ €์žฅํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ํ•™์Šต ์Šคํฌ๋ฆฝํŠธ์— ๋‹ค์Œ ์ธ์ˆ˜๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ ์ค‘๊ฐ„ ๋กœ๊น…์„ ํ™œ์„ฑํ™”ํ•ฉ๋‹ˆ๋‹ค.

  • validation_prompt : ์ƒ˜ํ”Œ์„ ์ƒ์„ฑํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋˜๋Š” ํ”„๋กฌํ”„ํŠธ(๊ธฐ๋ณธ๊ฐ’์€ None์œผ๋กœ ์„ค์ •๋˜๋ฉฐ, ์ด ๋•Œ ์ค‘๊ฐ„ ๋กœ๊น…์€ ๋น„ํ™œ์„ฑํ™”๋จ)
  • num_validation_images : ์ƒ์„ฑํ•  ์ƒ˜ํ”Œ ์ด๋ฏธ์ง€ ์ˆ˜
  • validation_steps : validation_prompt๋กœ๋ถ€ํ„ฐ ์ƒ˜ํ”Œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๊ธฐ ์ „ ์Šคํ…์˜ ์ˆ˜
--validation_prompt="A <cat-toy> backpack"
--num_validation_images=4
--validation_steps=100

์ถ”๋ก 

๋ชจ๋ธ์„ ํ•™์Šตํ•œ ํ›„์—๋Š”, ํ•ด๋‹น ๋ชจ๋ธ์„ [StableDiffusionPipeline]์„ ์‚ฌ์šฉํ•˜์—ฌ ์ถ”๋ก ์— ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

textual-inversion ์Šคํฌ๋ฆฝํŠธ๋Š” ๊ธฐ๋ณธ์ ์œผ๋กœ textual-inversion์„ ํ†ตํ•ด ์–ป์–ด์ง„ ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ๋งŒ์„ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค. ํ•ด๋‹น ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ๋“ค์€ ํ…์ŠคํŠธ ์ธ์ฝ”๋”์˜ ์ž„๋ฒ ๋”ฉ ํ–‰๋ ฌ์— ์ถ”๊ฐ€๋˜์–ด ์žˆ์Šต์Šต๋‹ˆ๋‹ค.

๐Ÿ’ก ์ปค๋ฎค๋‹ˆํ‹ฐ๋Š” sd-concepts-library ๋ผ๋Š” ๋Œ€๊ทœ๋ชจ์˜ textual-inversion ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋ฅผ ๋งŒ๋“ค์—ˆ์Šต๋‹ˆ๋‹ค. textual-inversion ์ž„๋ฒ ๋”ฉ์„ ๋ฐ‘๋ฐ”๋‹ฅ๋ถ€ํ„ฐ ํ•™์Šตํ•˜๋Š” ๋Œ€์‹ , ํ•ด๋‹น ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์— ๋ณธ์ธ์ด ์ฐพ๋Š” textual-inversion ์ž„๋ฒ ๋”ฉ์ด ์ด๋ฏธ ์ถ”๊ฐ€๋˜์–ด ์žˆ์ง€ ์•Š์€์ง€๋ฅผ ํ™•์ธํ•˜๋Š” ๊ฒƒ๋„ ์ข‹์€ ๋ฐฉ๋ฒ•์ด ๋  ๊ฒƒ ๊ฐ™์Šต๋‹ˆ๋‹ค.

textual-inversion ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ์„ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ ์œ„ํ•ด์„œ๋Š”, ๋จผ์ € ํ•ด๋‹น ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ๋ฅผ ํ•™์Šตํ•  ๋•Œ ์‚ฌ์šฉํ•œ ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์™€์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์—ฌ๊ธฐ์„œ๋Š” runwayml/stable-diffusion-v1-5 ๋ชจ๋ธ์ด ์‚ฌ์šฉ๋˜์—ˆ๋‹ค๊ณ  ๊ฐ€์ •ํ•˜๊ณ  ๋ถˆ๋Ÿฌ์˜ค๊ฒ ์Šต๋‹ˆ๋‹ค.

from diffusers import StableDiffusionPipeline
import torch

model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

๋‹ค์Œ์œผ๋กœ TextualInversionLoaderMixin.load_textual_inversion ํ•จ์ˆ˜๋ฅผ ํ†ตํ•ด, textual-inversion ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ๋ฅผ ๋ถˆ๋Ÿฌ์™€์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์—ฌ๊ธฐ์„œ ์šฐ๋ฆฌ๋Š” ์ด์ „์˜ <cat-toy> ์˜ˆ์ œ์˜ ์ž„๋ฒ ๋”ฉ์„ ๋ถˆ๋Ÿฌ์˜ฌ ๊ฒƒ์ž…๋‹ˆ๋‹ค.

pipe.load_textual_inversion("sd-concepts-library/cat-toy")

์ด์ œ ํ”Œ๋ ˆ์ด์Šคํ™€๋” ํ† ํฐ(<cat-toy>)์ด ์ž˜ ๋™์ž‘ํ•˜๋Š”์ง€๋ฅผ ํ™•์ธํ•˜๋Š” ํŒŒ์ดํ”„๋ผ์ธ์„ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

prompt = "A <cat-toy> backpack"

image = pipe(prompt, num_inference_steps=50).images[0]
image.save("cat-backpack.png")

TextualInversionLoaderMixin.load_textual_inversion์€ Diffusers ํ˜•์‹์œผ๋กœ ์ €์žฅ๋œ ํ…์ŠคํŠธ ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ๋ฅผ ๋กœ๋“œํ•  ์ˆ˜ ์žˆ์„ ๋ฟ๋งŒ ์•„๋‹ˆ๋ผ, Automatic1111 ํ˜•์‹์œผ๋กœ ์ €์žฅ๋œ ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ๋„ ๋กœ๋“œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋ ‡๊ฒŒ ํ•˜๋ ค๋ฉด, ๋จผ์ € civitAI์—์„œ ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ๋ฅผ ๋‹ค์šด๋กœ๋“œํ•œ ๋‹ค์Œ ๋กœ์ปฌ์—์„œ ๋ถˆ๋Ÿฌ์™€์•ผ ํ•ฉ๋‹ˆ๋‹ค.

pipe.load_textual_inversion("./charturnerv2.pt")

ํ˜„์žฌ Flax์— ๋Œ€ํ•œ load_textual_inversion ํ•จ์ˆ˜๋Š” ์—†์Šต๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ํ•™์Šต ํ›„ textual-inversion ์ž„๋ฒ ๋”ฉ ๋ฒกํ„ฐ๊ฐ€ ๋ชจ๋ธ์˜ ์ผ๋ถ€๋กœ์„œ ์ €์žฅ๋˜์—ˆ๋Š”์ง€๋ฅผ ํ™•์ธํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ ๋‹ค์Œ์€ ๋‹ค๋ฅธ Flax ๋ชจ๋ธ๊ณผ ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ ์‹คํ–‰ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline

model_path = "path-to-your-trained-model"
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16)

prompt = "A <cat-toy> backpack"
prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 50

num_samples = jax.device_count()
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
image.save("cat-backpack.png")

์ž‘๋™ ๋ฐฉ์‹

Diagram from the paper showing overview Architecture overview from the Textual Inversion blog post.

์ผ๋ฐ˜์ ์œผ๋กœ ํ…์ŠคํŠธ ํ”„๋กฌํ”„ํŠธ๋Š” ๋ชจ๋ธ์— ์ „๋‹ฌ๋˜๊ธฐ ์ „์— ์ž„๋ฒ ๋”ฉ์œผ๋กœ ํ† ํฐํ™”๋ฉ๋‹ˆ๋‹ค. textual-inversion์€ ๋น„์Šทํ•œ ์ž‘์—…์„ ์ˆ˜ํ–‰ํ•˜์ง€๋งŒ, ์œ„ ๋‹ค์ด์–ด๊ทธ๋žจ์˜ ํŠน์ˆ˜ ํ† ํฐ S*๋กœ๋ถ€ํ„ฐ ์ƒˆ๋กœ์šด ํ† ํฐ ์ž„๋ฒ ๋”ฉ v*๋ฅผ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค. ๋ชจ๋ธ์˜ ์•„์›ƒํ’‹์€ ๋””ํ“จ์ „ ๋ชจ๋ธ์„ ์กฐ์ •ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋˜๋ฉฐ, ๋””ํ“จ์ „ ๋ชจ๋ธ์ด ๋‹จ ๋ช‡ ๊ฐœ์˜ ์˜ˆ์ œ ์ด๋ฏธ์ง€์—์„œ ์‹ ์†ํ•˜๊ณ  ์ƒˆ๋กœ์šด ์ฝ˜์…‰ํŠธ๋ฅผ ์ดํ•ดํ•˜๋Š” ๋ฐ ๋„์›€์„ ์ค๋‹ˆ๋‹ค.

์ด๋ฅผ ์œ„ํ•ด textual-inversion์€ ์ œ๋„ˆ๋ ˆ์ดํ„ฐ ๋ชจ๋ธ๊ณผ ํ•™์Šต์šฉ ์ด๋ฏธ์ง€์˜ ๋…ธ์ด์ฆˆ ๋ฒ„์ „์„ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค. ์ œ๋„ˆ๋ ˆ์ดํ„ฐ๋Š” ๋…ธ์ด์ฆˆ๊ฐ€ ์ ์€ ๋ฒ„์ „์˜ ์ด๋ฏธ์ง€๋ฅผ ์˜ˆ์ธกํ•˜๋ ค๊ณ  ์‹œ๋„ํ•˜๋ฉฐ ํ† ํฐ ์ž„๋ฒ ๋”ฉ v*์€ ์ œ๋„ˆ๋ ˆ์ดํ„ฐ์˜ ์„ฑ๋Šฅ์— ๋”ฐ๋ผ ์ตœ์ ํ™”๋ฉ๋‹ˆ๋‹ค. ํ† ํฐ ์ž„๋ฒ ๋”ฉ์ด ์ƒˆ๋กœ์šด ์ฝ˜์…‰ํŠธ๋ฅผ ์„ฑ๊ณต์ ์œผ๋กœ ํฌ์ฐฉํ•˜๋ฉด ๋””ํ“จ์ „ ๋ชจ๋ธ์— ๋” ์œ ์šฉํ•œ ์ •๋ณด๋ฅผ ์ œ๊ณตํ•˜๊ณ  ๋…ธ์ด์ฆˆ๊ฐ€ ์ ์€ ๋” ์„ ๋ช…ํ•œ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๋ฐ ๋„์›€์ด ๋ฉ๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ์ตœ์ ํ™” ํ”„๋กœ์„ธ์Šค๋Š” ์ผ๋ฐ˜์ ์œผ๋กœ ๋‹ค์–‘ํ•œ ํ”„๋กฌํ”„ํŠธ์™€ ์ด๋ฏธ์ง€์— ์ˆ˜์ฒœ ๋ฒˆ์— ๋…ธ์ถœ๋จ์œผ๋กœ์จ ์ด๋ฃจ์–ด์ง‘๋‹ˆ๋‹ค.