
Diffusion ๋ชจ๋ธ ํ‰๊ฐ€ํ•˜๊ธฐ[[evaluating-diffusion-models]]


Evaluation of generative models like Stable Diffusion is subjective in nature. But as practitioners and researchers, we often have to make careful choices among many different possibilities. So, when working with different generative models (GANs, Diffusion, etc.), how do we choose one over the other?

์ •์„ฑ์ ์ธ ํ‰๊ฐ€๋Š” ๋ชจ๋ธ์˜ ์ด๋ฏธ์ง€ ํ’ˆ์งˆ์— ๋Œ€ํ•œ ์ฃผ๊ด€์ ์ธ ํ‰๊ฐ€์ด๋ฏ€๋กœ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ๊ณ  ๊ฒฐ์ •์— ์ž˜๋ชป๋œ ์˜ํ–ฅ์„ ๋ฏธ์น  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋ฐ˜๋ฉด, ์ •๋Ÿ‰์ ์ธ ํ‰๊ฐ€๋Š” ์ด๋ฏธ์ง€ ํ’ˆ์งˆ๊ณผ ์ง์ ‘์ ์ธ ์ƒ๊ด€๊ด€๊ณ„๋ฅผ ๊ฐ–์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ์ผ๋ฐ˜์ ์œผ๋กœ ์ •์„ฑ์  ํ‰๊ฐ€์™€ ์ •๋Ÿ‰์  ํ‰๊ฐ€๋ฅผ ๋ชจ๋‘ ๊ณ ๋ คํ•˜๋Š” ๊ฒƒ์ด ๋” ๊ฐ•๋ ฅํ•œ ์‹ ํ˜ธ๋ฅผ ์ œ๊ณตํ•˜์—ฌ ๋ชจ๋ธ ์„ ํƒ์— ๋„์›€์ด ๋ฉ๋‹ˆ๋‹ค.

์ด ๋ฌธ์„œ์—์„œ๋Š” Diffusion ๋ชจ๋ธ์„ ํ‰๊ฐ€ํ•˜๊ธฐ ์œ„ํ•œ ์ •์„ฑ์  ๋ฐ ์ •๋Ÿ‰์  ๋ฐฉ๋ฒ•์— ๋Œ€ํ•ด ์ƒ์„ธํžˆ ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค. ์ •๋Ÿ‰์  ๋ฐฉ๋ฒ•์— ๋Œ€ํ•ด์„œ๋Š” ํŠนํžˆ diffusers์™€ ํ•จ๊ป˜ ๊ตฌํ˜„ํ•˜๋Š” ๋ฐฉ๋ฒ•์— ์ดˆ์ ์„ ๋งž์ถ”์—ˆ์Šต๋‹ˆ๋‹ค.

์ด ๋ฌธ์„œ์—์„œ ๋ณด์—ฌ์ง„ ๋ฐฉ๋ฒ•๋“ค์€ ๊ธฐ๋ฐ˜ ์ƒ์„ฑ ๋ชจ๋ธ์„ ๊ณ ์ •์‹œํ‚ค๊ณ  ๋‹ค์–‘ํ•œ ๋…ธ์ด์ฆˆ ์Šค์ผ€์ค„๋Ÿฌ๋ฅผ ํ‰๊ฐ€ํ•˜๋Š” ๋ฐ์—๋„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

Scenarios[[scenarios]]

๋‹ค์Œ๊ณผ ๊ฐ™์€ ํŒŒ์ดํ”„๋ผ์ธ์„ ์‚ฌ์šฉํ•˜์—ฌ Diffusion ๋ชจ๋ธ์„ ๋‹ค๋ฃน๋‹ˆ๋‹ค:

  • Text-guided image generation (such as StableDiffusionPipeline).
  • Text-guided image generation, additionally conditioned on an input image (such as StableDiffusionImg2ImgPipeline and StableDiffusionInstructPix2PixPipeline).
  • Class-conditioned image generation models (such as DiTPipeline).

์ •์„ฑ์  ํ‰๊ฐ€[[qualitative-evaluation]]

์ •์„ฑ์  ํ‰๊ฐ€๋Š” ์ผ๋ฐ˜์ ์œผ๋กœ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€์˜ ์ธ๊ฐ„ ํ‰๊ฐ€๋ฅผ ํฌํ•จํ•ฉ๋‹ˆ๋‹ค. ํ’ˆ์งˆ์€ ๊ตฌ์„ฑ์„ฑ, ์ด๋ฏธ์ง€-ํ…์ŠคํŠธ ์ผ์น˜, ๊ณต๊ฐ„ ๊ด€๊ณ„ ๋“ฑ๊ณผ ๊ฐ™์€ ์ธก๋ฉด์—์„œ ์ธก์ •๋ฉ๋‹ˆ๋‹ค. ์ผ๋ฐ˜์ ์ธ ํ”„๋กฌํ”„ํŠธ๋Š” ์ฃผ๊ด€์ ์ธ ์ง€ํ‘œ์— ๋Œ€ํ•œ ์ผ์ •ํ•œ ๊ธฐ์ค€์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. DrawBench์™€ PartiPrompts๋Š” ์ •์„ฑ์ ์ธ ๋ฒค์น˜๋งˆํ‚น์— ์‚ฌ์šฉ๋˜๋Š” ํ”„๋กฌํ”„ํŠธ ๋ฐ์ดํ„ฐ์…‹์ž…๋‹ˆ๋‹ค. DrawBench์™€ PartiPrompts๋Š” ๊ฐ๊ฐ Imagen๊ณผ Parti์—์„œ ์†Œ๊ฐœ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.

From the official Parti website:

PartiPrompts (P2) is a rich set of over 1600 prompts in English that we release as part of this work. P2 can be used to measure model capabilities across various categories and challenge aspects.

parti-prompts

PartiPrompts has the following columns:

  • Prompt
  • Category of the prompt (such as "Abstract", "World Knowledge", etc.)
  • Challenge reflecting the difficulty (such as "Basic", "Complex", "Writing & Symbols", etc.)

These benchmarks allow for side-by-side human evaluation of different image generation models.

For this, the 🧨 Diffusers team has built Open Parti Prompts, a community-driven qualitative benchmark based on Parti Prompts, used to compare state-of-the-art open-source diffusion models:

  • Open Parti Prompts Game: For 10 parti prompts, 4 generated images are shown, and the user selects the image that best fits the prompt.
  • Open Parti Prompts Leaderboard: A leaderboard comparing the currently best open-source diffusion models to each other.

To compare images manually, let's see how we can use diffusers on a couple of PartiPrompts.

๋‹ค์Œ์€ ๋ช‡ ๊ฐ€์ง€ ๋‹ค๋ฅธ ๋„์ „์—์„œ ์ƒ˜ํ”Œ๋งํ•œ ํ”„๋กฌํ”„ํŠธ๋ฅผ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค: Basic, Complex, Linguistic Structures, Imagination, Writing & Symbols. ์—ฌ๊ธฐ์„œ๋Š” PartiPrompts๋ฅผ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.

from datasets import load_dataset

# prompts = load_dataset("nateraw/parti-prompts", split="train")
# prompts = prompts.shuffle()
# sample_prompts = [prompts[i]["Prompt"] for i in range(5)]

# Fixing these sample prompts in the interest of reproducibility.
sample_prompts = [
    "a corgi",
    "a hot air balloon with a yin-yang symbol, with the moon visible in the daytime sky",
    "a car with no windows",
    "a cube made of porcupine",
    'The saying "BE EXCELLENT TO EACH OTHER" written on a red brick wall with a graffiti image of a green alien wearing a tuxedo. A yellow fire hydrant is on a sidewalk in the foreground.',
]

์ด์ œ ์ด๋Ÿฐ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ Stable Diffusion (v1-4 checkpoint)๋ฅผ ์‚ฌ์šฉํ•œ ์ด๋ฏธ์ง€ ์ƒ์„ฑ์„ ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค :

import torch

seed = 0
generator = torch.manual_seed(seed)

images = sd_pipeline(sample_prompts, num_images_per_prompt=1, generator=generator).images

parti-prompts-14

You can also set num_images_per_prompt to compare several images for the same prompt (see the sketch after the image below). Running the same pipeline with a different checkpoint (v1-5) yields:

parti-prompts-15
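For example, a one-line sketch that requests four candidate images per prompt from the same pipeline, reusing the sample_prompts and generator from above:

images = sd_pipeline(sample_prompts, num_images_per_prompt=4, generator=generator).images
# len(images) == len(sample_prompts) * 4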

Once several images are generated from all the prompts using the multiple models (under evaluation), these results are presented to human evaluators for scoring. For more details on the DrawBench and PartiPrompts benchmarks, refer to their respective papers.

๋ชจ๋ธ์ด ํ›ˆ๋ จ ์ค‘์ผ ๋•Œ ์ถ”๋ก  ์ƒ˜ํ”Œ์„ ์‚ดํŽด๋ณด๋Š” ๊ฒƒ์€ ํ›ˆ๋ จ ์ง„ํ–‰ ์ƒํ™ฉ์„ ์ธก์ •ํ•˜๋Š” ๋ฐ ์œ ์šฉํ•ฉ๋‹ˆ๋‹ค. ํ›ˆ๋ จ ์Šคํฌ๋ฆฝํŠธ์—์„œ๋Š” TensorBoard์™€ Weights & Biases์— ๋Œ€ํ•œ ์ถ”๊ฐ€ ์ง€์›๊ณผ ํ•จ๊ป˜ ์ด ์œ ํ‹ธ๋ฆฌํ‹ฐ๋ฅผ ์ง€์›ํ•ฉ๋‹ˆ๋‹ค.

Quantitative Evaluation[[quantitative-evaluation]]

In this section, we walk you through how to evaluate three different diffusion pipelines using:

  • CLIP score
  • CLIP directional similarity
  • FID

Text-guided image generation[[text-guided-image-generation]]

CLIP ์ ์ˆ˜๋Š” ์ด๋ฏธ์ง€-์บก์…˜ ์Œ์˜ ํ˜ธํ™˜์„ฑ์„ ์ธก์ •ํ•ฉ๋‹ˆ๋‹ค. ๋†’์€ CLIP ์ ์ˆ˜๋Š” ๋†’์€ ํ˜ธํ™˜์„ฑ๐Ÿ”ผ์„ ๋‚˜ํƒ€๋ƒ…๋‹ˆ๋‹ค. CLIP ์ ์ˆ˜๋Š” ์ด๋ฏธ์ง€์™€ ์บก์…˜ ์‚ฌ์ด์˜ ์˜๋ฏธ์  ์œ ์‚ฌ์„ฑ์œผ๋กœ ์ƒ๊ฐํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค. CLIP ์ ์ˆ˜๋Š” ์ธ๊ฐ„ ํŒ๋‹จ๊ณผ ๋†’์€ ์ƒ๊ด€๊ด€๊ณ„๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

[StableDiffusionPipeline]์„ ์ผ๋‹จ ๋กœ๋“œํ•ด๋ด…์‹œ๋‹ค:

from diffusers import StableDiffusionPipeline
import torch

model_ckpt = "CompVis/stable-diffusion-v1-4"
device = "cuda"  # later code blocks reuse this device handle
sd_pipeline = StableDiffusionPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16).to(device)

Generate some images with multiple prompts:

prompts = [
    "a photo of an astronaut riding a horse on mars",
    "A high tech solarpunk utopia in the Amazon rainforest",
    "A pikachu fine dining with a view to the Eiffel Tower",
    "A mecha robot in a favela in expressionist style",
    "an insect robot preparing a delicious meal",
    "A small cabin on top of a snowy mountain in the style of Disney, artstation",
]

images = sd_pipeline(prompts, num_images_per_prompt=1, output_type="np").images

print(images.shape)
# (6, 512, 512, 3)

And then, we calculate the CLIP score:

from torchmetrics.functional.multimodal import clip_score
from functools import partial

clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")

def calculate_clip_score(images, prompts):
    images_int = (images * 255).astype("uint8")
    clip_score = clip_score_fn(torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts).detach()
    return round(float(clip_score), 4)

sd_clip_score = calculate_clip_score(images, prompts)
print(f"CLIP score: {sd_clip_score}")
# CLIP score: 35.7038

์œ„์˜ ์˜ˆ์ œ์—์„œ๋Š” ๊ฐ ํ”„๋กฌํ”„ํŠธ ๋‹น ํ•˜๋‚˜์˜ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ–ˆ์Šต๋‹ˆ๋‹ค. ๋งŒ์•ฝ ํ”„๋กฌํ”„ํŠธ ๋‹น ์—ฌ๋Ÿฌ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•œ๋‹ค๋ฉด, ํ”„๋กฌํ”„ํŠธ ๋‹น ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€์˜ ํ‰๊ท  ์ ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

Now, to compare two checkpoints compatible with the [StableDiffusionPipeline], we should pass a generator while calling the pipeline. First, we generate images with a fixed seed using the v1-4 Stable Diffusion checkpoint:

seed = 0
generator = torch.manual_seed(seed)

images = sd_pipeline(prompts, num_images_per_prompt=1, generator=generator, output_type="np").images

Then we load the v1-5 checkpoint to generate images:

model_ckpt_1_5 = "runwayml/stable-diffusion-v1-5"
sd_pipeline_1_5 = StableDiffusionPipeline.from_pretrained(model_ckpt_1_5, torch_dtype=torch.float16).to(device)

images_1_5 = sd_pipeline_1_5(prompts, num_images_per_prompt=1, generator=generator, output_type="np").images

And finally, we compare their CLIP scores:

sd_clip_score_1_4 = calculate_clip_score(images, prompts)
print(f"CLIP Score with v-1-4: {sd_clip_score_1_4}")
# CLIP Score with v-1-4: 34.9102

sd_clip_score_1_5 = calculate_clip_score(images_1_5, prompts)
print(f"CLIP Score with v-1-5: {sd_clip_score_1_5}")
# CLIP Score with v-1-5: 36.2137

It seems like the v1-5 checkpoint performs better than its predecessor. Note, however, that the number of prompts we used to compute the CLIP scores is quite low. For a more practical evaluation, this number should be much higher and the prompts more diverse.

์ด ์ ์ˆ˜์—๋Š” ๋ช‡ ๊ฐ€์ง€ ์ œํ•œ ์‚ฌํ•ญ์ด ์žˆ์Šต๋‹ˆ๋‹ค. ํ›ˆ๋ จ ๋ฐ์ดํ„ฐ์…‹์˜ ์บก์…˜์€ ์›น์—์„œ ํฌ๋กค๋ง๋˜์–ด ์ด๋ฏธ์ง€์™€ ๊ด€๋ จ๋œ alt ๋ฐ ์œ ์‚ฌํ•œ ํƒœ๊ทธ์—์„œ ์ถ”์ถœ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. ์ด๋“ค์€ ์ธ๊ฐ„์ด ์ด๋ฏธ์ง€๋ฅผ ์„ค๋ช…ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒƒ๊ณผ ์ผ์น˜ํ•˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ์—ฌ๊ธฐ์„œ๋Š” ๋ช‡ ๊ฐ€์ง€ ํ”„๋กฌํ”„ํŠธ๋ฅผ "์—”์ง€๋‹ˆ์–ด๋ง"ํ•ด์•ผ ํ–ˆ์Šต๋‹ˆ๋‹ค.

Image conditioned text-to-image generation[[image-conditioned-text-to-image-generation]]

In this case, we condition the generation pipeline with an input image as well as a text prompt. Let's take [StableDiffusionInstructPix2PixPipeline] as an example. It takes an edit instruction as the input prompt, along with the input image to be edited.

๋‹ค์Œ์€ ํ•˜๋‚˜์˜ ์˜ˆ์‹œ์ž…๋‹ˆ๋‹ค:

edit-instruction

๋ชจ๋ธ์„ ํ‰๊ฐ€ํ•˜๋Š” ํ•œ ๊ฐ€์ง€ ์ „๋žต์€ ๋‘ ์ด๋ฏธ์ง€ ์บก์…˜ ๊ฐ„์˜ ๋ณ€๊ฒฝ๊ณผ(CLIP-Guided Domain Adaptation of Image Generators์—์„œ ๋ณด์—ฌ์ค๋‹ˆ๋‹ค) ํ•จ๊ป˜ ๋‘ ์ด๋ฏธ์ง€ ์‚ฌ์ด์˜ ๋ณ€๊ฒฝ์˜ ์ผ๊ด€์„ฑ์„ ์ธก์ •ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค (CLIP ๊ณต๊ฐ„์—์„œ). ์ด๋ฅผ "CLIP ๋ฐฉํ–ฅ์„ฑ ์œ ์‚ฌ์„ฑ"์ด๋ผ๊ณ  ํ•ฉ๋‹ˆ๋‹ค.

  • Caption 1 corresponds to the input image (image 1) that is to be edited.
  • Caption 2 corresponds to the edited image (image 2). It should reflect the edit instruction.

๋‹ค์Œ์€ ๊ทธ๋ฆผ์œผ๋กœ ๋œ ๊ฐœ์š”์ž…๋‹ˆ๋‹ค:

edit-consistency
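In CLIP space, the metric computed below is the cosine similarity between the edit direction of the images and the edit direction of the captions:

$$\text{CLIP}_{\text{dir}} = \cos\big(E_{\text{img}_2} - E_{\text{img}_1},\ E_{\text{cap}_2} - E_{\text{cap}_1}\big)$$

where each $E$ is a normalized CLIP embedding of the corresponding image or caption.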

We have prepared a mini dataset to implement this metric. Let's first load the dataset.

from datasets import load_dataset

dataset = load_dataset("sayakpaul/instructpix2pix-demo", split="train")
dataset.features
{'input': Value(dtype='string', id=None),
 'edit': Value(dtype='string', id=None),
 'output': Value(dtype='string', id=None),
 'image': Image(decode=True, id=None)}

It has the following entries:

  • input is a caption corresponding to the image.
  • edit denotes the edit instruction.
  • output denotes the modified caption reflecting the edit instruction.

Let's take a look at a sample.

idx = 0
print(f"Original caption: {dataset[idx]['input']}")
print(f"Edit instruction: {dataset[idx]['edit']}")
print(f"Modified caption: {dataset[idx]['output']}")
Original caption: 2. FAROE ISLANDS: An archipelago of 18 mountainous isles in the North Atlantic Ocean between Norway and Iceland, the Faroe Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills'
Edit instruction: make the isles all white marble
Modified caption: 2. WHITE MARBLE ISLANDS: An archipelago of 18 mountainous white marble isles in the North Atlantic Ocean between Norway and Iceland, the White Marble Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills'

๋‹ค์Œ์€ ์ด๋ฏธ์ง€์ž…๋‹ˆ๋‹ค:

dataset[idx]["image"]

edit-dataset

We will first edit the images of the dataset using the edit instructions and then compute the directional similarity.

Let's first load the [StableDiffusionInstructPix2PixPipeline]:

from diffusers import StableDiffusionInstructPix2PixPipeline

instruct_pix2pix_pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to(device)

Now, we perform the edits:

import numpy as np


def edit_image(input_image, instruction):
    image = instruct_pix2pix_pipeline(
        instruction,
        image=input_image,
        output_type="np",
        generator=generator,
    ).images[0]
    return image

input_images = []
original_captions = []
modified_captions = []
edited_images = []

for idx in range(len(dataset)):
    input_image = dataset[idx]["image"]
    edit_instruction = dataset[idx]["edit"]
    edited_image = edit_image(input_image, edit_instruction)

    input_images.append(np.array(input_image))
    original_captions.append(dataset[idx]["input"])
    modified_captions.append(dataset[idx]["output"])
    edited_images.append(edited_image)

To measure the directional similarity, we first load CLIP's image and text encoders:

from transformers import (
    CLIPTokenizer,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
    CLIPImageProcessor,
)

clip_id = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(clip_id)
text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to(device)
image_processor = CLIPImageProcessor.from_pretrained(clip_id)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to(device)

Notice that we are using a particular CLIP checkpoint, openai/clip-vit-large-patch14. This is because the Stable Diffusion pre-training was performed with this CLIP variant. For more details, refer to the documentation.

๋‹ค์Œ์œผ๋กœ, ๋ฐฉํ–ฅ์„ฑ ์œ ์‚ฌ๋„๋ฅผ ๊ณ„์‚ฐํ•˜๊ธฐ ์œ„ํ•ด PyTorch์˜ nn.Module์„ ์ค€๋น„ํ•ฉ๋‹ˆ๋‹ค:

import torch.nn as nn
import torch.nn.functional as F


class DirectionalSimilarity(nn.Module):
    def __init__(self, tokenizer, text_encoder, image_processor, image_encoder):
        super().__init__()
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.image_processor = image_processor
        self.image_encoder = image_encoder

    def preprocess_image(self, image):
        image = self.image_processor(image, return_tensors="pt")["pixel_values"]
        return {"pixel_values": image.to(device)}

    def tokenize_text(self, text):
        inputs = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {"input_ids": inputs.input_ids.to(device)}

    def encode_image(self, image):
        preprocessed_image = self.preprocess_image(image)
        image_features = self.image_encoder(**preprocessed_image).image_embeds
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        return image_features

    def encode_text(self, text):
        tokenized_text = self.tokenize_text(text)
        text_features = self.text_encoder(**tokenized_text).text_embeds
        text_features = text_features / text_features.norm(dim=1, keepdim=True)
        return text_features

    def compute_directional_similarity(self, img_feat_one, img_feat_two, text_feat_one, text_feat_two):
        sim_direction = F.cosine_similarity(img_feat_two - img_feat_one, text_feat_two - text_feat_one)
        return sim_direction

    def forward(self, image_one, image_two, caption_one, caption_two):
        img_feat_one = self.encode_image(image_one)
        img_feat_two = self.encode_image(image_two)
        text_feat_one = self.encode_text(caption_one)
        text_feat_two = self.encode_text(caption_two)
        directional_similarity = self.compute_directional_similarity(
            img_feat_one, img_feat_two, text_feat_one, text_feat_two
        )
        return directional_similarity

Let's now put DirectionalSimilarity to use.

dir_similarity = DirectionalSimilarity(tokenizer, text_encoder, image_processor, image_encoder)
scores = []

for i in range(len(input_images)):
    original_image = input_images[i]
    original_caption = original_captions[i]
    edited_image = edited_images[i]
    modified_caption = modified_captions[i]

    similarity_score = dir_similarity(original_image, edited_image, original_caption, modified_caption)
    scores.append(float(similarity_score.detach().cpu()))

print(f"CLIP directional similarity: {np.mean(scores)}")
# CLIP directional similarity: 0.0797976553440094

CLIP ์ ์ˆ˜์™€ ๋งˆ์ฐฌ๊ฐ€์ง€๋กœ, CLIP ๋ฐฉํ–ฅ ์œ ์‚ฌ์„ฑ์ด ๋†’์„์ˆ˜๋ก ์ข‹์Šต๋‹ˆ๋‹ค.

StableDiffusionInstructPix2PixPipeline exposes two arguments, image_guidance_scale and guidance_scale, that let you control the quality of the final edited image. We encourage you to experiment with these two arguments and see their impact on the directional similarity.
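For illustration, here is a sketch of such a sweep on the first dataset sample; the value grids are arbitrary choices, and the seed is fixed to isolate the effect of the two scales:

for image_guidance_scale in [1.0, 1.5, 2.0]:
    for guidance_scale in [5.0, 7.5, 10.0]:
        edited = instruct_pix2pix_pipeline(
            dataset[0]["edit"],
            image=dataset[0]["image"],
            image_guidance_scale=image_guidance_scale,
            guidance_scale=guidance_scale,
            generator=torch.manual_seed(0),
            output_type="np",
        ).images[0]
        score = dir_similarity(
            np.array(dataset[0]["image"]), edited, dataset[0]["input"], dataset[0]["output"]
        )
        print(image_guidance_scale, guidance_scale, float(score.detach().cpu()))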

We can extend the idea of this metric to measure how similar the original image and its edited version are. To do that, we can simply compute F.cosine_similarity(img_feat_two, img_feat_one). For these kinds of edits, the primary semantics of the image should be preserved as much as possible, i.e., we should obtain a high similarity score.
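A minimal sketch of this, reusing the encoders wrapped in dir_similarity over the image pairs collected above:

image_preservation_scores = []
for original, edited in zip(input_images, edited_images):
    img_feat_one = dir_similarity.encode_image(original)
    img_feat_two = dir_similarity.encode_image(edited)
    # Cosine similarity between the two normalized CLIP image embeddings.
    sim = F.cosine_similarity(img_feat_two, img_feat_one)
    image_preservation_scores.append(float(sim.detach().cpu()))

print(f"CLIP image-image similarity: {np.mean(image_preservation_scores)}")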

We can use these metrics for similar pipelines, such as StableDiffusionPix2PixZeroPipeline.

CLIP ์ ์ˆ˜์™€ CLIP ๋ฐฉํ–ฅ ์œ ์‚ฌ์„ฑ ๋ชจ๋‘ CLIP ๋ชจ๋ธ์— ์˜์กดํ•˜๊ธฐ ๋•Œ๋ฌธ์— ํ‰๊ฐ€๊ฐ€ ํŽธํ–ฅ๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค

Extending metrics like IS, FID (discussed later), or KID can be difficult when the model under evaluation was pre-trained on a large image-captioning dataset (such as the LAION-5B dataset). This is because these metrics rely on an InceptionNet pre-trained on the ImageNet-1k dataset to extract intermediate image features. Stable Diffusion's pre-training dataset may have limited overlap with InceptionNet's pre-training dataset, so InceptionNet is not a good candidate for feature extraction here.

์œ„์˜ ๋ฉ”ํŠธ๋ฆญ์„ ์‚ฌ์šฉํ•˜๋ฉด ํด๋ž˜์Šค ์กฐ๊ฑด์ด ์žˆ๋Š” ๋ชจ๋ธ์„ ํ‰๊ฐ€ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด, DiT. ์ด๋Š” ImageNet-1k ํด๋ž˜์Šค์— ์กฐ๊ฑด์„ ๊ฑธ๊ณ  ์‚ฌ์ „ ํ›ˆ๋ จ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.

ํด๋ž˜์Šค ์กฐ๊ฑดํ™” ์ด๋ฏธ์ง€ ์ƒ์„ฑ[[class-conditioned-image-generation]]

Class-conditioned generative models are usually pre-trained on a class-labeled dataset such as ImageNet-1k. Popular metrics for evaluating these models include Fréchet Inception Distance (FID), Kernel Inception Distance (KID), and Inception Score (IS). In this document, we focus on FID (Heusel et al.) and show how to compute it with the DiTPipeline, which uses the DiT model under the hood.

FID aims to measure how similar two datasets of images are. As per this resource:

Fréchet Inception Distance is a measure of similarity between two datasets of images. It was shown to correlate well with the human judgement of visual quality and is most often used to evaluate the quality of samples of Generative Adversarial Networks. FID is calculated by computing the Fréchet distance between two Gaussians fitted to feature representations of the Inception network.

์ด ๋‘ ๊ฐœ์˜ ๋ฐ์ดํ„ฐ์…‹์€ ์‹ค์ œ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ์…‹๊ณผ ๊ฐ€์งœ ์ด๋ฏธ์ง€ ๋ฐ์ดํ„ฐ์…‹(์šฐ๋ฆฌ์˜ ๊ฒฝ์šฐ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€)์ž…๋‹ˆ๋‹ค. FID๋Š” ์ผ๋ฐ˜์ ์œผ๋กœ ๋‘ ๊ฐœ์˜ ํฐ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ๊ณ„์‚ฐ๋ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์ด ๋ฌธ์„œ์—์„œ๋Š” ๋‘ ๊ฐœ์˜ ๋ฏธ๋‹ˆ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ์ž‘์—…ํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค.

Let's first download a few images from the ImageNet-1k training set:

from zipfile import ZipFile
import requests


def download(url, local_filepath):
    r = requests.get(url)
    with open(local_filepath, "wb") as f:
        f.write(r.content)
    return local_filepath

dummy_dataset_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/sample-imagenet-images.zip"
local_filepath = download(dummy_dataset_url, dummy_dataset_url.split("/")[-1])

with ZipFile(local_filepath, "r") as zipper:
    zipper.extractall(".")

from PIL import Image
import os

dataset_path = "sample-imagenet-images"
image_paths = sorted([os.path.join(dataset_path, x) for x in os.listdir(dataset_path)])

real_images = [np.array(Image.open(path).convert("RGB")) for path in image_paths]

๋‹ค์Œ์€ ImageNet-1k classes์˜ ์ด๋ฏธ์ง€ 10๊ฐœ์ž…๋‹ˆ๋‹ค : "cassette_player", "chain_saw" (x2), "church", "gas_pump" (x3), "parachute" (x2), ๊ทธ๋ฆฌ๊ณ  "tench".

real-images
Real images.

์ด์ œ ์ด๋ฏธ์ง€๊ฐ€ ๋กœ๋“œ๋˜์—ˆ์œผ๋ฏ€๋กœ ์ด๋ฏธ์ง€์— ๊ฐ€๋ฒผ์šด ์ „์ฒ˜๋ฆฌ๋ฅผ ์ ์šฉํ•˜์—ฌ FID ๊ณ„์‚ฐ์— ์‚ฌ์šฉํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.

from torchvision.transforms import functional as F


def preprocess_image(image):
    image = torch.tensor(image).unsqueeze(0)
    image = image.permute(0, 3, 1, 2) / 255.0
    return F.center_crop(image, (256, 256))

real_images = torch.cat([preprocess_image(image) for image in real_images])
print(real_images.shape)
# torch.Size([10, 3, 256, 256])

We now load the DiTPipeline to generate images conditioned on the above-mentioned classes.

from diffusers import DiTPipeline, DPMSolverMultistepScheduler

dit_pipeline = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
dit_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(dit_pipeline.scheduler.config)
dit_pipeline = dit_pipeline.to("cuda")

words = [
    "cassette player",
    "chainsaw",
    "chainsaw",
    "church",
    "gas pump",
    "gas pump",
    "gas pump",
    "parachute",
    "parachute",
    "tench",
]

class_ids = dit_pipeline.get_label_ids(words)
output = dit_pipeline(class_labels=class_ids, generator=generator, output_type="np")

fake_images = output.images
fake_images = torch.tensor(fake_images)
fake_images = fake_images.permute(0, 3, 1, 2)
print(fake_images.shape)
# torch.Size([10, 3, 256, 256])

Now, we can compute the FID using torchmetrics.

from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance(normalize=True)
fid.update(real_images, real=True)
fid.update(fake_images, real=False)

print(f"FID: {float(fid.compute())}")
# FID: 177.7147216796875

The lower the FID, the better. Several things can influence FID here:

  • Number of images (both real and fake)
  • Randomness induced in the diffusion process
  • Number of inference steps in the diffusion process
  • The scheduler being used in the diffusion process

For the last two points, it is therefore a good practice to run the evaluation across different seeds and inference steps, and then report an average result.
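A minimal sketch of that practice, averaging FID over a few seeds while reusing the class_ids and real_images from above (the seed list and step count are arbitrary choices):

fid_scores = []
for seed in [0, 1, 2]:
    generator = torch.manual_seed(seed)
    output = dit_pipeline(
        class_labels=class_ids,
        generator=generator,
        num_inference_steps=25,  # also worth varying, as noted above
        output_type="np",
    )
    fake = torch.tensor(output.images).permute(0, 3, 1, 2)

    fid = FrechetInceptionDistance(normalize=True)
    fid.update(real_images, real=True)
    fid.update(fake, real=False)
    fid_scores.append(float(fid.compute()))

print(f"Mean FID over seeds: {np.mean(fid_scores)}")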

FID results tend to be fragile, as they depend on a lot of factors:

  • The specific Inception model used during computation.
  • The implementation accuracy of the computation.
  • The image format (not the same if we start from PNGs vs JPGs).

์ด๋Ÿฌํ•œ ์‚ฌํ•ญ์„ ์—ผ๋‘์— ๋‘๋ฉด, FID๋Š” ์œ ์‚ฌํ•œ ์‹คํ–‰์„ ๋น„๊ตํ•  ๋•Œ ๊ฐ€์žฅ ์œ ์šฉํ•˜์ง€๋งŒ, ์ €์ž๊ฐ€ FID ์ธก์ • ์ฝ”๋“œ๋ฅผ ์ฃผ์˜ ๊นŠ๊ฒŒ ๊ณต๊ฐœํ•˜์ง€ ์•Š๋Š” ํ•œ ๋…ผ๋ฌธ ๊ฒฐ๊ณผ๋ฅผ ์žฌํ˜„ํ•˜๊ธฐ๋Š” ์–ด๋ ต์Šต๋‹ˆ๋‹ค.

์ด๋Ÿฌํ•œ ์‚ฌํ•ญ์€ KID ๋ฐ IS์™€ ๊ฐ™์€ ๋‹ค๋ฅธ ๊ด€๋ จ ๋ฉ”ํŠธ๋ฆญ์—๋„ ์ ์šฉ๋ฉ๋‹ˆ๋‹ค.

As a final step, let's visually inspect the fake_images.

fake-images
Fake images.