Diffusion ๋ชจ๋ธ ํ๊ฐํ๊ธฐ[[evaluating-diffusion-models]]
Stable Diffusion์ ๊ฐ์ ์์ฑ ๋ชจ๋ธ์ ํ๊ฐ๋ ์ฃผ๊ด์ ์ธ ์ฑ๊ฒฉ์ ๊ฐ์ง๊ณ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ค๋ฌด์์ ์ฐ๊ตฌ์๋ก์ ์ฐ๋ฆฌ๋ ์ข ์ข ๋ค์ํ ๊ฐ๋ฅ์ฑ ์ค์์ ์ ์คํ ์ ํ์ ํด์ผ ํฉ๋๋ค. ๊ทธ๋์ ๋ค์ํ ์์ฑ ๋ชจ๋ธ (GAN, Diffusion ๋ฑ)์ ์ฌ์ฉํ ๋ ์ด๋ป๊ฒ ์ ํํด์ผ ํ ๊น์?
์ ์ฑ์ ์ธ ํ๊ฐ๋ ๋ชจ๋ธ์ ์ด๋ฏธ์ง ํ์ง์ ๋ํ ์ฃผ๊ด์ ์ธ ํ๊ฐ์ด๋ฏ๋ก ์ค๋ฅ๊ฐ ๋ฐ์ํ ์ ์๊ณ ๊ฒฐ์ ์ ์๋ชป๋ ์ํฅ์ ๋ฏธ์น ์ ์์ต๋๋ค. ๋ฐ๋ฉด, ์ ๋์ ์ธ ํ๊ฐ๋ ์ด๋ฏธ์ง ํ์ง๊ณผ ์ง์ ์ ์ธ ์๊ด๊ด๊ณ๋ฅผ ๊ฐ์ง ์์ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ ์ผ๋ฐ์ ์ผ๋ก ์ ์ฑ์ ํ๊ฐ์ ์ ๋์ ํ๊ฐ๋ฅผ ๋ชจ๋ ๊ณ ๋ คํ๋ ๊ฒ์ด ๋ ๊ฐ๋ ฅํ ์ ํธ๋ฅผ ์ ๊ณตํ์ฌ ๋ชจ๋ธ ์ ํ์ ๋์์ด ๋ฉ๋๋ค.
์ด ๋ฌธ์์์๋ Diffusion ๋ชจ๋ธ์ ํ๊ฐํ๊ธฐ ์ํ ์ ์ฑ์ ๋ฐ ์ ๋์ ๋ฐฉ๋ฒ์ ๋ํด ์์ธํ ์ค๋ช
ํฉ๋๋ค. ์ ๋์ ๋ฐฉ๋ฒ์ ๋ํด์๋ ํนํ diffusers
์ ํจ๊ป ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ ์ด์ ์ ๋ง์ถ์์ต๋๋ค.
์ด ๋ฌธ์์์ ๋ณด์ฌ์ง ๋ฐฉ๋ฒ๋ค์ ๊ธฐ๋ฐ ์์ฑ ๋ชจ๋ธ์ ๊ณ ์ ์ํค๊ณ ๋ค์ํ ๋ ธ์ด์ฆ ์ค์ผ์ค๋ฌ๋ฅผ ํ๊ฐํ๋ ๋ฐ์๋ ์ฌ์ฉํ ์ ์์ต๋๋ค.
์๋๋ฆฌ์ค[[scenarios]]
๋ค์๊ณผ ๊ฐ์ ํ์ดํ๋ผ์ธ์ ์ฌ์ฉํ์ฌ Diffusion ๋ชจ๋ธ์ ๋ค๋ฃน๋๋ค:
- ํ
์คํธ๋ก ์๋ด๋ ์ด๋ฏธ์ง ์์ฑ (์:
StableDiffusionPipeline
). - ์
๋ ฅ ์ด๋ฏธ์ง์ ์ถ๊ฐ๋ก ์กฐ๊ฑด์ ๊ฑด ํ
์คํธ๋ก ์๋ด๋ ์ด๋ฏธ์ง ์์ฑ (์:
StableDiffusionImg2ImgPipeline
๋ฐStableDiffusionInstructPix2PixPipeline
). - ํด๋์ค ์กฐ๊ฑดํ๋ ์ด๋ฏธ์ง ์์ฑ ๋ชจ๋ธ (์:
DiTPipeline
).
์ ์ฑ์ ํ๊ฐ[[qualitative-evaluation]]
์ ์ฑ์ ํ๊ฐ๋ ์ผ๋ฐ์ ์ผ๋ก ์์ฑ๋ ์ด๋ฏธ์ง์ ์ธ๊ฐ ํ๊ฐ๋ฅผ ํฌํจํฉ๋๋ค. ํ์ง์ ๊ตฌ์ฑ์ฑ, ์ด๋ฏธ์ง-ํ ์คํธ ์ผ์น, ๊ณต๊ฐ ๊ด๊ณ ๋ฑ๊ณผ ๊ฐ์ ์ธก๋ฉด์์ ์ธก์ ๋ฉ๋๋ค. ์ผ๋ฐ์ ์ธ ํ๋กฌํํธ๋ ์ฃผ๊ด์ ์ธ ์งํ์ ๋ํ ์ผ์ ํ ๊ธฐ์ค์ ์ ๊ณตํฉ๋๋ค. DrawBench์ PartiPrompts๋ ์ ์ฑ์ ์ธ ๋ฒค์น๋งํน์ ์ฌ์ฉ๋๋ ํ๋กฌํํธ ๋ฐ์ดํฐ์ ์ ๋๋ค. DrawBench์ PartiPrompts๋ ๊ฐ๊ฐ Imagen๊ณผ Parti์์ ์๊ฐ๋์์ต๋๋ค.
Parti ๊ณต์ ์น์ฌ์ดํธ์์ ๋ค์๊ณผ ๊ฐ์ด ์ค๋ช ํ๊ณ ์์ต๋๋ค:
PartiPrompts (P2)๋ ์ด ์์ ์ ์ผ๋ถ๋ก ๊ณต๊ฐ๋๋ ์์ด๋ก ๋ 1600๊ฐ ์ด์์ ๋ค์ํ ํ๋กฌํํธ ์ธํธ์ ๋๋ค. P2๋ ๋ค์ํ ๋ฒ์ฃผ์ ๋์ ์ธก๋ฉด์์ ๋ชจ๋ธ์ ๋ฅ๋ ฅ์ ์ธก์ ํ๋ ๋ฐ ์ฌ์ฉํ ์ ์์ต๋๋ค.
PartiPrompts๋ ๋ค์๊ณผ ๊ฐ์ ์ด์ ๊ฐ์ง๊ณ ์์ต๋๋ค:
- ํ๋กฌํํธ (Prompt)
- ํ๋กฌํํธ์ ์นดํ ๊ณ ๋ฆฌ (์: "Abstract", "World Knowledge" ๋ฑ)
- ๋์ด๋๋ฅผ ๋ฐ์ํ ์ฑ๋ฆฐ์ง (์: "Basic", "Complex", "Writing & Symbols" ๋ฑ)
์ด๋ฌํ ๋ฒค์น๋งํฌ๋ ์๋ก ๋ค๋ฅธ ์ด๋ฏธ์ง ์์ฑ ๋ชจ๋ธ์ ์ธ๊ฐ ํ๊ฐ๋ก ๋น๊ตํ ์ ์๋๋ก ํฉ๋๋ค.
์ด๋ฅผ ์ํด ๐งจ Diffusers ํ์ Open Parti Prompts๋ฅผ ๊ตฌ์ถํ์ต๋๋ค. ์ด๋ Parti Prompts๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ ์ปค๋ฎค๋ํฐ ๊ธฐ๋ฐ์ ์ง์ ๋ฒค์น๋งํฌ๋ก, ์ต์ฒจ๋จ ์คํ ์์ค ํ์ฐ ๋ชจ๋ธ์ ๋น๊ตํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค:
- Open Parti Prompts ๊ฒ์: 10๊ฐ์ parti prompt์ ๋ํด 4๊ฐ์ ์์ฑ๋ ์ด๋ฏธ์ง๊ฐ ์ ์๋๋ฉฐ, ์ฌ์ฉ์๋ ํ๋กฌํํธ์ ๊ฐ์ฅ ์ ํฉํ ์ด๋ฏธ์ง๋ฅผ ์ ํํฉ๋๋ค.
- Open Parti Prompts ๋ฆฌ๋๋ณด๋: ํ์ฌ ์ต๊ณ ์ ์คํ ์์ค diffusion ๋ชจ๋ธ๋ค์ ์๋ก ๋น๊ตํ๋ ๋ฆฌ๋๋ณด๋์ ๋๋ค.
์ด๋ฏธ์ง๋ฅผ ์๋์ผ๋ก ๋น๊ตํ๋ ค๋ฉด, diffusers
๋ฅผ ์ฌ์ฉํ์ฌ ๋ช๊ฐ์ง PartiPrompts๋ฅผ ์ด๋ป๊ฒ ํ์ฉํ ์ ์๋์ง ์์๋ด
์๋ค.
๋ค์์ ๋ช ๊ฐ์ง ๋ค๋ฅธ ๋์ ์์ ์ํ๋งํ ํ๋กฌํํธ๋ฅผ ๋ณด์ฌ์ค๋๋ค: Basic, Complex, Linguistic Structures, Imagination, Writing & Symbols. ์ฌ๊ธฐ์๋ PartiPrompts๋ฅผ ๋ฐ์ดํฐ์ ์ผ๋ก ์ฌ์ฉํฉ๋๋ค.
from datasets import load_dataset
# prompts = load_dataset("nateraw/parti-prompts", split="train")
# prompts = prompts.shuffle()
# sample_prompts = [prompts[i]["Prompt"] for i in range(5)]
# Fixing these sample prompts in the interest of reproducibility.
sample_prompts = [
"a corgi",
"a hot air balloon with a yin-yang symbol, with the moon visible in the daytime sky",
"a car with no windows",
"a cube made of porcupine",
'The saying "BE EXCELLENT TO EACH OTHER" written on a red brick wall with a graffiti image of a green alien wearing a tuxedo. A yellow fire hydrant is on a sidewalk in the foreground.',
]
์ด์ ์ด๋ฐ ํ๋กฌํํธ๋ฅผ ์ฌ์ฉํ์ฌ Stable Diffusion (v1-4 checkpoint)๋ฅผ ์ฌ์ฉํ ์ด๋ฏธ์ง ์์ฑ์ ํ ์ ์์ต๋๋ค :
import torch
seed = 0
generator = torch.manual_seed(seed)
images = sd_pipeline(sample_prompts, num_images_per_prompt=1, generator=generator).images
num_images_per_prompt
๋ฅผ ์ค์ ํ์ฌ ๋์ผํ ํ๋กฌํํธ์ ๋ํด ๋ค๋ฅธ ์ด๋ฏธ์ง๋ฅผ ๋น๊ตํ ์๋ ์์ต๋๋ค. ๋ค๋ฅธ ์ฒดํฌํฌ์ธํธ(v1-5)๋ก ๋์ผํ ํ์ดํ๋ผ์ธ์ ์คํํ๋ฉด ๋ค์๊ณผ ๊ฐ์ ๊ฒฐ๊ณผ๊ฐ ๋์ต๋๋ค:
๋ค์ํ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ๋ชจ๋ ํ๋กฌํํธ์์ ์์ฑ๋ ์ฌ๋ฌ ์ด๋ฏธ์ง๋ค์ด ์์ฑ๋๋ฉด (ํ๊ฐ ๊ณผ์ ์์) ์ด๋ฌํ ๊ฒฐ๊ณผ๋ฌผ๋ค์ ์ฌ๋ ํ๊ฐ์๋ค์๊ฒ ์ ์๋ฅผ ๋งค๊ธฐ๊ธฐ ์ํด ์ ์๋ฉ๋๋ค. DrawBench์ PartiPrompts ๋ฒค์น๋งํฌ์ ๋ํ ์์ธํ ๋ด์ฉ์ ๊ฐ๊ฐ์ ๋ ผ๋ฌธ์ ์ฐธ์กฐํ์ญ์์ค.
๋ชจ๋ธ์ด ํ๋ จ ์ค์ผ ๋ ์ถ๋ก ์ํ์ ์ดํด๋ณด๋ ๊ฒ์ ํ๋ จ ์งํ ์ํฉ์ ์ธก์ ํ๋ ๋ฐ ์ ์ฉํฉ๋๋ค. ํ๋ จ ์คํฌ๋ฆฝํธ์์๋ TensorBoard์ Weights & Biases์ ๋ํ ์ถ๊ฐ ์ง์๊ณผ ํจ๊ป ์ด ์ ํธ๋ฆฌํฐ๋ฅผ ์ง์ํฉ๋๋ค.
์ ๋์ ํ๊ฐ[[quantitative-evaluation]]
์ด ์น์ ์์๋ ์ธ ๊ฐ์ง ๋ค๋ฅธ ํ์ฐ ํ์ดํ๋ผ์ธ์ ํ๊ฐํ๋ ๋ฐฉ๋ฒ์ ์๋ดํฉ๋๋ค:
- CLIP ์ ์
- CLIP ๋ฐฉํฅ์ฑ ์ ์ฌ๋
- FID
ํ ์คํธ ์๋ด ์ด๋ฏธ์ง ์์ฑ[[text-guided-image-generation]]
CLIP ์ ์๋ ์ด๋ฏธ์ง-์บก์ ์์ ํธํ์ฑ์ ์ธก์ ํฉ๋๋ค. ๋์ CLIP ์ ์๋ ๋์ ํธํ์ฑ๐ผ์ ๋ํ๋ ๋๋ค. CLIP ์ ์๋ ์ด๋ฏธ์ง์ ์บก์ ์ฌ์ด์ ์๋ฏธ์ ์ ์ฌ์ฑ์ผ๋ก ์๊ฐํ ์๋ ์์ต๋๋ค. CLIP ์ ์๋ ์ธ๊ฐ ํ๋จ๊ณผ ๋์ ์๊ด๊ด๊ณ๋ฅผ ๊ฐ์ง๊ณ ์์ต๋๋ค.
[StableDiffusionPipeline
]์ ์ผ๋จ ๋ก๋ํด๋ด
์๋ค:
from diffusers import StableDiffusionPipeline
import torch
model_ckpt = "CompVis/stable-diffusion-v1-4"
sd_pipeline = StableDiffusionPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16).to("cuda")
์ฌ๋ฌ ๊ฐ์ ํ๋กฌํํธ๋ฅผ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง๋ฅผ ์์ฑํฉ๋๋ค:
prompts = [
"a photo of an astronaut riding a horse on mars",
"A high tech solarpunk utopia in the Amazon rainforest",
"A pikachu fine dining with a view to the Eiffel Tower",
"A mecha robot in a favela in expressionist style",
"an insect robot preparing a delicious meal",
"A small cabin on top of a snowy mountain in the style of Disney, artstation",
]
images = sd_pipeline(prompts, num_images_per_prompt=1, output_type="np").images
print(images.shape)
# (6, 512, 512, 3)
๊ทธ๋ฌ๊ณ ๋์ CLIP ์ ์๋ฅผ ๊ณ์ฐํฉ๋๋ค.
from torchmetrics.functional.multimodal import clip_score
from functools import partial
clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")
def calculate_clip_score(images, prompts):
images_int = (images * 255).astype("uint8")
clip_score = clip_score_fn(torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts).detach()
return round(float(clip_score), 4)
sd_clip_score = calculate_clip_score(images, prompts)
print(f"CLIP score: {sd_clip_score}")
# CLIP score: 35.7038
์์ ์์ ์์๋ ๊ฐ ํ๋กฌํํธ ๋น ํ๋์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ์ต๋๋ค. ๋ง์ฝ ํ๋กฌํํธ ๋น ์ฌ๋ฌ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ค๋ฉด, ํ๋กฌํํธ ๋น ์์ฑ๋ ์ด๋ฏธ์ง์ ํ๊ท ์ ์๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค.
์ด์ [StableDiffusionPipeline
]๊ณผ ํธํ๋๋ ๋ ๊ฐ์ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋น๊ตํ๋ ค๋ฉด, ํ์ดํ๋ผ์ธ์ ํธ์ถํ ๋ generator๋ฅผ ์ ๋ฌํด์ผ ํฉ๋๋ค. ๋จผ์ , ๊ณ ์ ๋ ์๋๋ก v1-4 Stable Diffusion ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง๋ฅผ ์์ฑํฉ๋๋ค:
seed = 0
generator = torch.manual_seed(seed)
images = sd_pipeline(prompts, num_images_per_prompt=1, generator=generator, output_type="np").images
๊ทธ๋ฐ ๋ค์ v1-5 checkpoint๋ฅผ ๋ก๋ํ์ฌ ์ด๋ฏธ์ง๋ฅผ ์์ฑํฉ๋๋ค:
model_ckpt_1_5 = "runwayml/stable-diffusion-v1-5"
sd_pipeline_1_5 = StableDiffusionPipeline.from_pretrained(model_ckpt_1_5, torch_dtype=weight_dtype).to(device)
images_1_5 = sd_pipeline_1_5(prompts, num_images_per_prompt=1, generator=generator, output_type="np").images
๊ทธ๋ฆฌ๊ณ ๋ง์ง๋ง์ผ๋ก CLIP ์ ์๋ฅผ ๋น๊ตํฉ๋๋ค:
sd_clip_score_1_4 = calculate_clip_score(images, prompts)
print(f"CLIP Score with v-1-4: {sd_clip_score_1_4}")
# CLIP Score with v-1-4: 34.9102
sd_clip_score_1_5 = calculate_clip_score(images_1_5, prompts)
print(f"CLIP Score with v-1-5: {sd_clip_score_1_5}")
# CLIP Score with v-1-5: 36.2137
v1-5 ์ฒดํฌํฌ์ธํธ๊ฐ ์ด์ ๋ฒ์ ๋ณด๋ค ๋ ๋์ ์ฑ๋ฅ์ ๋ณด์ด๋ ๊ฒ ๊ฐ์ต๋๋ค. ๊ทธ๋ฌ๋ CLIP ์ ์๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํด ์ฌ์ฉํ ํ๋กฌํํธ์ ์๊ฐ ์๋นํ ์ ์ต๋๋ค. ๋ณด๋ค ์ค์ฉ์ ์ธ ํ๊ฐ๋ฅผ ์ํด์๋ ์ด ์๋ฅผ ํจ์ฌ ๋๊ฒ ์ค์ ํ๊ณ , ํ๋กฌํํธ๋ฅผ ๋ค์ํ๊ฒ ์ฌ์ฉํด์ผ ํฉ๋๋ค.
์ด ์ ์์๋ ๋ช ๊ฐ์ง ์ ํ ์ฌํญ์ด ์์ต๋๋ค. ํ๋ จ ๋ฐ์ดํฐ์
์ ์บก์
์ ์น์์ ํฌ๋กค๋ง๋์ด ์ด๋ฏธ์ง์ ๊ด๋ จ๋ alt
๋ฐ ์ ์ฌํ ํ๊ทธ์์ ์ถ์ถ๋์์ต๋๋ค. ์ด๋ค์ ์ธ๊ฐ์ด ์ด๋ฏธ์ง๋ฅผ ์ค๋ช
ํ๋ ๋ฐ ์ฌ์ฉํ ์ ์๋ ๊ฒ๊ณผ ์ผ์นํ์ง ์์ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ ์ฌ๊ธฐ์๋ ๋ช ๊ฐ์ง ํ๋กฌํํธ๋ฅผ "์์ง๋์ด๋ง"ํด์ผ ํ์ต๋๋ค.
์ด๋ฏธ์ง ์กฐ๊ฑดํ๋ ํ ์คํธ-์ด๋ฏธ์ง ์์ฑ[[image-conditioned-text-to-image-generation]]
์ด ๊ฒฝ์ฐ, ์์ฑ ํ์ดํ๋ผ์ธ์ ์
๋ ฅ ์ด๋ฏธ์ง์ ํ
์คํธ ํ๋กฌํํธ๋ก ์กฐ๊ฑดํํฉ๋๋ค. [StableDiffusionInstructPix2PixPipeline
]์ ์๋ก ๋ค์ด๋ณด๊ฒ ์ต๋๋ค. ์ด๋ ํธ์ง ์ง์๋ฌธ์ ์
๋ ฅ ํ๋กฌํํธ๋ก ์ฌ์ฉํ๊ณ ํธ์งํ ์
๋ ฅ ์ด๋ฏธ์ง๋ฅผ ์ฌ์ฉํฉ๋๋ค.
๋ค์์ ํ๋์ ์์์ ๋๋ค:
๋ชจ๋ธ์ ํ๊ฐํ๋ ํ ๊ฐ์ง ์ ๋ต์ ๋ ์ด๋ฏธ์ง ์บก์ ๊ฐ์ ๋ณ๊ฒฝ๊ณผ(CLIP-Guided Domain Adaptation of Image Generators์์ ๋ณด์ฌ์ค๋๋ค) ํจ๊ป ๋ ์ด๋ฏธ์ง ์ฌ์ด์ ๋ณ๊ฒฝ์ ์ผ๊ด์ฑ์ ์ธก์ ํ๋ ๊ฒ์ ๋๋ค (CLIP ๊ณต๊ฐ์์). ์ด๋ฅผ "CLIP ๋ฐฉํฅ์ฑ ์ ์ฌ์ฑ"์ด๋ผ๊ณ ํฉ๋๋ค.
- ์บก์ 1์ ํธ์งํ ์ด๋ฏธ์ง (์ด๋ฏธ์ง 1)์ ํด๋นํฉ๋๋ค.
- ์บก์ 2๋ ํธ์ง๋ ์ด๋ฏธ์ง (์ด๋ฏธ์ง 2)์ ํด๋นํฉ๋๋ค. ํธ์ง ์ง์๋ฅผ ๋ฐ์ํด์ผ ํฉ๋๋ค.
๋ค์์ ๊ทธ๋ฆผ์ผ๋ก ๋ ๊ฐ์์ ๋๋ค:
์ฐ๋ฆฌ๋ ์ด ์ธก์ ํญ๋ชฉ์ ๊ตฌํํ๊ธฐ ์ํด ๋ฏธ๋ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ์ค๋นํ์ต๋๋ค. ๋จผ์ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๋ก๋ํด ๋ณด๊ฒ ์ต๋๋ค.
from datasets import load_dataset
dataset = load_dataset("sayakpaul/instructpix2pix-demo", split="train")
dataset.features
{'input': Value(dtype='string', id=None),
'edit': Value(dtype='string', id=None),
'output': Value(dtype='string', id=None),
'image': Image(decode=True, id=None)}
์ฌ๊ธฐ์๋ ๋ค์๊ณผ ๊ฐ์ ํญ๋ชฉ์ด ์์ต๋๋ค:
input
์image
์ ํด๋นํ๋ ์บก์ ์ ๋๋ค.edit
์ ํธ์ง ์ง์์ฌํญ์ ๋ํ๋ ๋๋ค.output
์edit
์ง์์ฌํญ์ ๋ฐ์ํ ์์ ๋ ์บก์ ์ ๋๋ค.
์ํ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
idx = 0
print(f"Original caption: {dataset[idx]['input']}")
print(f"Edit instruction: {dataset[idx]['edit']}")
print(f"Modified caption: {dataset[idx]['output']}")
Original caption: 2. FAROE ISLANDS: An archipelago of 18 mountainous isles in the North Atlantic Ocean between Norway and Iceland, the Faroe Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills'
Edit instruction: make the isles all white marble
Modified caption: 2. WHITE MARBLE ISLANDS: An archipelago of 18 mountainous white marble isles in the North Atlantic Ocean between Norway and Iceland, the White Marble Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills'
๋ค์์ ์ด๋ฏธ์ง์ ๋๋ค:
dataset[idx]["image"]
๋จผ์ ํธ์ง ์ง์์ฌํญ์ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ ์ธํธ์ ์ด๋ฏธ์ง๋ฅผ ํธ์งํ๊ณ ๋ฐฉํฅ ์ ์ฌ๋๋ฅผ ๊ณ์ฐํฉ๋๋ค.
[StableDiffusionInstructPix2PixPipeline
]๋ฅผ ๋จผ์ ๋ก๋ํฉ๋๋ค:
from diffusers import StableDiffusionInstructPix2PixPipeline
instruct_pix2pix_pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
"timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to(device)
์ด์ ํธ์ง์ ์ํํฉ๋๋ค:
import numpy as np
def edit_image(input_image, instruction):
image = instruct_pix2pix_pipeline(
instruction,
image=input_image,
output_type="np",
generator=generator,
).images[0]
return image
input_images = []
original_captions = []
modified_captions = []
edited_images = []
for idx in range(len(dataset)):
input_image = dataset[idx]["image"]
edit_instruction = dataset[idx]["edit"]
edited_image = edit_image(input_image, edit_instruction)
input_images.append(np.array(input_image))
original_captions.append(dataset[idx]["input"])
modified_captions.append(dataset[idx]["output"])
edited_images.append(edited_image)
๋ฐฉํฅ ์ ์ฌ๋๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํด์๋ ๋จผ์ CLIP์ ์ด๋ฏธ์ง์ ํ ์คํธ ์ธ์ฝ๋๋ฅผ ๋ก๋ํฉ๋๋ค:
from transformers import (
CLIPTokenizer,
CLIPTextModelWithProjection,
CLIPVisionModelWithProjection,
CLIPImageProcessor,
)
clip_id = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(clip_id)
text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to(device)
image_processor = CLIPImageProcessor.from_pretrained(clip_id)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to(device)
์ฃผ๋ชฉํ ์ ์ ํน์ ํ CLIP ์ฒดํฌํฌ์ธํธ์ธ openai/clip-vit-large-patch14
๋ฅผ ์ฌ์ฉํ๊ณ ์๋ค๋ ๊ฒ์
๋๋ค. ์ด๋ Stable Diffusion ์ฌ์ ํ๋ จ์ด ์ด CLIP ๋ณํ์ฒด์ ํจ๊ป ์ํ๋์๊ธฐ ๋๋ฌธ์
๋๋ค. ์์ธํ ๋ด์ฉ์ ๋ฌธ์๋ฅผ ์ฐธ์กฐํ์ธ์.
๋ค์์ผ๋ก, ๋ฐฉํฅ์ฑ ์ ์ฌ๋๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํด PyTorch์ nn.Module
์ ์ค๋นํฉ๋๋ค:
import torch.nn as nn
import torch.nn.functional as F
class DirectionalSimilarity(nn.Module):
def __init__(self, tokenizer, text_encoder, image_processor, image_encoder):
super().__init__()
self.tokenizer = tokenizer
self.text_encoder = text_encoder
self.image_processor = image_processor
self.image_encoder = image_encoder
def preprocess_image(self, image):
image = self.image_processor(image, return_tensors="pt")["pixel_values"]
return {"pixel_values": image.to(device)}
def tokenize_text(self, text):
inputs = self.tokenizer(
text,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
)
return {"input_ids": inputs.input_ids.to(device)}
def encode_image(self, image):
preprocessed_image = self.preprocess_image(image)
image_features = self.image_encoder(**preprocessed_image).image_embeds
image_features = image_features / image_features.norm(dim=1, keepdim=True)
return image_features
def encode_text(self, text):
tokenized_text = self.tokenize_text(text)
text_features = self.text_encoder(**tokenized_text).text_embeds
text_features = text_features / text_features.norm(dim=1, keepdim=True)
return text_features
def compute_directional_similarity(self, img_feat_one, img_feat_two, text_feat_one, text_feat_two):
sim_direction = F.cosine_similarity(img_feat_two - img_feat_one, text_feat_two - text_feat_one)
return sim_direction
def forward(self, image_one, image_two, caption_one, caption_two):
img_feat_one = self.encode_image(image_one)
img_feat_two = self.encode_image(image_two)
text_feat_one = self.encode_text(caption_one)
text_feat_two = self.encode_text(caption_two)
directional_similarity = self.compute_directional_similarity(
img_feat_one, img_feat_two, text_feat_one, text_feat_two
)
return directional_similarity
์ด์ DirectionalSimilarity
๋ฅผ ์ฌ์ฉํด ๋ณด๊ฒ ์ต๋๋ค.
dir_similarity = DirectionalSimilarity(tokenizer, text_encoder, image_processor, image_encoder)
scores = []
for i in range(len(input_images)):
original_image = input_images[i]
original_caption = original_captions[i]
edited_image = edited_images[i]
modified_caption = modified_captions[i]
similarity_score = dir_similarity(original_image, edited_image, original_caption, modified_caption)
scores.append(float(similarity_score.detach().cpu()))
print(f"CLIP directional similarity: {np.mean(scores)}")
# CLIP directional similarity: 0.0797976553440094
CLIP ์ ์์ ๋ง์ฐฌ๊ฐ์ง๋ก, CLIP ๋ฐฉํฅ ์ ์ฌ์ฑ์ด ๋์์๋ก ์ข์ต๋๋ค.
StableDiffusionInstructPix2PixPipeline
์ image_guidance_scale
๊ณผ guidance_scale
์ด๋ผ๋ ๋ ๊ฐ์ง ์ธ์๋ฅผ ๋
ธ์ถ์ํต๋๋ค. ์ด ๋ ์ธ์๋ฅผ ์กฐ์ ํ์ฌ ์ต์ข
ํธ์ง๋ ์ด๋ฏธ์ง์ ํ์ง์ ์ ์ดํ ์ ์์ต๋๋ค. ์ด ๋ ์ธ์์ ์ํฅ์ ์คํํด๋ณด๊ณ ๋ฐฉํฅ ์ ์ฌ์ฑ์ ๋ฏธ์น๋ ์ํฅ์ ํ์ธํด๋ณด๊ธฐ๋ฅผ ๊ถ์ฅํฉ๋๋ค.
์ด๋ฌํ ๋ฉํธ๋ฆญ์ ๊ฐ๋
์ ํ์ฅํ์ฌ ์๋ณธ ์ด๋ฏธ์ง์ ํธ์ง๋ ๋ฒ์ ์ ์ ์ฌ์ฑ์ ์ธก์ ํ ์ ์์ต๋๋ค. ์ด๋ฅผ ์ํด F.cosine_similarity(img_feat_two, img_feat_one)
์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ด๋ฌํ ์ข
๋ฅ์ ํธ์ง์์๋ ์ด๋ฏธ์ง์ ์ฃผ์ ์๋ฏธ๊ฐ ์ต๋ํ ๋ณด์กด๋์ด์ผ ํฉ๋๋ค. ์ฆ, ๋์ ์ ์ฌ์ฑ ์ ์๋ฅผ ์ป์ด์ผ ํฉ๋๋ค.
StableDiffusionPix2PixZeroPipeline
์ ๊ฐ์ ์ ์ฌํ ํ์ดํ๋ผ์ธ์๋ ์ด๋ฌํ ๋ฉํธ๋ฆญ์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
CLIP ์ ์์ CLIP ๋ฐฉํฅ ์ ์ฌ์ฑ ๋ชจ๋ CLIP ๋ชจ๋ธ์ ์์กดํ๊ธฐ ๋๋ฌธ์ ํ๊ฐ๊ฐ ํธํฅ๋ ์ ์์ต๋๋ค
IS, FID (๋์ค์ ์ค๋ช ํ ์์ ), ๋๋ KID์ ๊ฐ์ ๋ฉํธ๋ฆญ์ ํ์ฅํ๋ ๊ฒ์ ์ด๋ ค์ธ ์ ์์ต๋๋ค. ํ๊ฐ ์ค์ธ ๋ชจ๋ธ์ด ๋๊ท๋ชจ ์ด๋ฏธ์ง ์บก์ ๋ ๋ฐ์ดํฐ์ (์: LAION-5B ๋ฐ์ดํฐ์ )์์ ์ฌ์ ํ๋ จ๋์์ ๋ ์ด๋ ๋ฌธ์ ๊ฐ ๋ ์ ์์ต๋๋ค. ์๋ํ๋ฉด ์ด๋ฌํ ๋ฉํธ๋ฆญ์ ๊ธฐ๋ฐ์๋ ์ค๊ฐ ์ด๋ฏธ์ง ํน์ง์ ์ถ์ถํ๊ธฐ ์ํด ImageNet-1k ๋ฐ์ดํฐ์ ์์ ์ฌ์ ํ๋ จ๋ InceptionNet์ด ์ฌ์ฉ๋๊ธฐ ๋๋ฌธ์ ๋๋ค. Stable Diffusion์ ์ฌ์ ํ๋ จ ๋ฐ์ดํฐ์ ์ InceptionNet์ ์ฌ์ ํ๋ จ ๋ฐ์ดํฐ์ ๊ณผ ๊ฒน์น๋ ๋ถ๋ถ์ด ์ ํ์ ์ผ ์ ์์ผ๋ฏ๋ก ๋ฐ๋ผ์ ์ฌ๊ธฐ์๋ ์ข์ ํ๋ณด๊ฐ ์๋๋๋ค.
์์ ๋ฉํธ๋ฆญ์ ์ฌ์ฉํ๋ฉด ํด๋์ค ์กฐ๊ฑด์ด ์๋ ๋ชจ๋ธ์ ํ๊ฐํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, DiT. ์ด๋ ImageNet-1k ํด๋์ค์ ์กฐ๊ฑด์ ๊ฑธ๊ณ ์ฌ์ ํ๋ จ๋์์ต๋๋ค.
ํด๋์ค ์กฐ๊ฑดํ ์ด๋ฏธ์ง ์์ฑ[[class-conditioned-image-generation]]
ํด๋์ค ์กฐ๊ฑดํ ์์ฑ ๋ชจ๋ธ์ ์ผ๋ฐ์ ์ผ๋ก ImageNet-1k์ ๊ฐ์ ํด๋์ค ๋ ์ด๋ธ์ด ์ง์ ๋ ๋ฐ์ดํฐ์
์์ ์ฌ์ ํ๋ จ๋ฉ๋๋ค. ์ด๋ฌํ ๋ชจ๋ธ์ ํ๊ฐํ๋ ์ธ๊ธฐ์๋ ์งํ์๋ Frรฉchet Inception Distance (FID), Kernel Inception Distance (KID) ๋ฐ Inception Score (IS)๊ฐ ์์ต๋๋ค. ์ด ๋ฌธ์์์๋ FID (Heusel et al.)์ ์ด์ ์ ๋ง์ถ๊ณ ์์ต๋๋ค. DiTPipeline
์ ์ฌ์ฉํ์ฌ FID๋ฅผ ๊ณ์ฐํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค. ์ด๋ ๋ด๋ถ์ ์ผ๋ก DiT ๋ชจ๋ธ์ ์ฌ์ฉํฉ๋๋ค.
FID๋ ๋ ๊ฐ์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์ ์ด ์ผ๋ง๋ ์ ์ฌํ์ง๋ฅผ ์ธก์ ํ๋ ๊ฒ์ ๋ชฉํ๋ก ํฉ๋๋ค. ์ด ์๋ฃ์ ๋ฐ๋ฅด๋ฉด:
Frรฉchet Inception Distance๋ ๋ ๊ฐ์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์ ๊ฐ์ ์ ์ฌ์ฑ์ ์ธก์ ํ๋ ์งํ์ ๋๋ค. ์๊ฐ์ ํ์ง์ ๋ํ ์ธ๊ฐ ํ๋จ๊ณผ ์ ์๊ด๋๋ ๊ฒ์ผ๋ก ๋ํ๋ฌ์ผ๋ฉฐ, ์ฃผ๋ก ์์ฑ์ ์ ๋ ์ ๊ฒฝ๋ง์ ์ํ ํ์ง์ ํ๊ฐํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. FID๋ Inception ๋คํธ์ํฌ์ ํน์ง ํํ์ ๋ง๊ฒ ์ ํฉํ ๋ ๊ฐ์ ๊ฐ์ฐ์์ ์ฌ์ด์ Frรฉchet ๊ฑฐ๋ฆฌ๋ฅผ ๊ณ์ฐํ์ฌ ๊ตฌํฉ๋๋ค.
์ด ๋ ๊ฐ์ ๋ฐ์ดํฐ์ ์ ์ค์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์ ๊ณผ ๊ฐ์ง ์ด๋ฏธ์ง ๋ฐ์ดํฐ์ (์ฐ๋ฆฌ์ ๊ฒฝ์ฐ ์์ฑ๋ ์ด๋ฏธ์ง)์ ๋๋ค. FID๋ ์ผ๋ฐ์ ์ผ๋ก ๋ ๊ฐ์ ํฐ ๋ฐ์ดํฐ์ ์ผ๋ก ๊ณ์ฐ๋ฉ๋๋ค. ๊ทธ๋ฌ๋ ์ด ๋ฌธ์์์๋ ๋ ๊ฐ์ ๋ฏธ๋ ๋ฐ์ดํฐ์ ์ผ๋ก ์์ ํ ๊ฒ์ ๋๋ค.
๋จผ์ ImageNet-1k ํ๋ จ ์ธํธ์์ ๋ช ๊ฐ์ ์ด๋ฏธ์ง๋ฅผ ๋ค์ด๋ก๋ํด ๋ด ์๋ค:
from zipfile import ZipFile
import requests
def download(url, local_filepath):
r = requests.get(url)
with open(local_filepath, "wb") as f:
f.write(r.content)
return local_filepath
dummy_dataset_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/sample-imagenet-images.zip"
local_filepath = download(dummy_dataset_url, dummy_dataset_url.split("/")[-1])
with ZipFile(local_filepath, "r") as zipper:
zipper.extractall(".")
from PIL import Image
import os
dataset_path = "sample-imagenet-images"
image_paths = sorted([os.path.join(dataset_path, x) for x in os.listdir(dataset_path)])
real_images = [np.array(Image.open(path).convert("RGB")) for path in image_paths]
๋ค์์ ImageNet-1k classes์ ์ด๋ฏธ์ง 10๊ฐ์ ๋๋ค : "cassette_player", "chain_saw" (x2), "church", "gas_pump" (x3), "parachute" (x2), ๊ทธ๋ฆฌ๊ณ "tench".
Real images.
์ด์ ์ด๋ฏธ์ง๊ฐ ๋ก๋๋์์ผ๋ฏ๋ก ์ด๋ฏธ์ง์ ๊ฐ๋ฒผ์ด ์ ์ฒ๋ฆฌ๋ฅผ ์ ์ฉํ์ฌ FID ๊ณ์ฐ์ ์ฌ์ฉํด ๋ณด๊ฒ ์ต๋๋ค.
from torchvision.transforms import functional as F
def preprocess_image(image):
image = torch.tensor(image).unsqueeze(0)
image = image.permute(0, 3, 1, 2) / 255.0
return F.center_crop(image, (256, 256))
real_images = torch.cat([preprocess_image(image) for image in real_images])
print(real_images.shape)
# torch.Size([10, 3, 256, 256])
์ด์ ์์์ ์ธ๊ธํ ํด๋์ค์ ๋ฐ๋ผ ์กฐ๊ฑดํ ๋ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๊ธฐ ์ํด DiTPipeline
๋ฅผ ๋ก๋ํฉ๋๋ค.
from diffusers import DiTPipeline, DPMSolverMultistepScheduler
dit_pipeline = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
dit_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(dit_pipeline.scheduler.config)
dit_pipeline = dit_pipeline.to("cuda")
words = [
"cassette player",
"chainsaw",
"chainsaw",
"church",
"gas pump",
"gas pump",
"gas pump",
"parachute",
"parachute",
"tench",
]
class_ids = dit_pipeline.get_label_ids(words)
output = dit_pipeline(class_labels=class_ids, generator=generator, output_type="np")
fake_images = output.images
fake_images = torch.tensor(fake_images)
fake_images = fake_images.permute(0, 3, 1, 2)
print(fake_images.shape)
# torch.Size([10, 3, 256, 256])
์ด์ torchmetrics
๋ฅผ ์ฌ์ฉํ์ฌ FID๋ฅผ ๊ณ์ฐํ ์ ์์ต๋๋ค.
from torchmetrics.image.fid import FrechetInceptionDistance
fid = FrechetInceptionDistance(normalize=True)
fid.update(real_images, real=True)
fid.update(fake_images, real=False)
print(f"FID: {float(fid.compute())}")
# FID: 177.7147216796875
FID๋ ๋ฎ์์๋ก ์ข์ต๋๋ค. ์ฌ๋ฌ ๊ฐ์ง ์์๊ฐ FID์ ์ํฅ์ ์ค ์ ์์ต๋๋ค:
- ์ด๋ฏธ์ง์ ์ (์ค์ ์ด๋ฏธ์ง์ ๊ฐ์ง ์ด๋ฏธ์ง ๋ชจ๋)
- diffusion ๊ณผ์ ์์ ๋ฐ์ํ๋ ๋ฌด์์์ฑ
- diffusion ๊ณผ์ ์์์ ์ถ๋ก ๋จ๊ณ ์
- diffusion ๊ณผ์ ์์ ์ฌ์ฉ๋๋ ์ค์ผ์ค๋ฌ
๋ง์ง๋ง ๋ ๊ฐ์ง ์์์ ๋ํด์๋, ๋ค๋ฅธ ์๋์ ์ถ๋ก ๋จ๊ณ์์ ํ๊ฐ๋ฅผ ์คํํ๊ณ ํ๊ท ๊ฒฐ๊ณผ๋ฅผ ๋ณด๊ณ ํ๋ ๊ฒ์ ์ข์ ์ค์ฒ ๋ฐฉ๋ฒ์ ๋๋ค
FID ๊ฒฐ๊ณผ๋ ๋ง์ ์์์ ์์กดํ๊ธฐ ๋๋ฌธ์ ์ทจ์ฝํ ์ ์์ต๋๋ค:
- ๊ณ์ฐ ์ค ์ฌ์ฉ๋๋ ํน์ Inception ๋ชจ๋ธ.
- ๊ณ์ฐ์ ๊ตฌํ ์ ํ๋.
- ์ด๋ฏธ์ง ํ์ (PNG ๋๋ JPG์์ ์์ํ๋ ๊ฒฝ์ฐ๊ฐ ๋ค๋ฆ ๋๋ค).
์ด๋ฌํ ์ฌํญ์ ์ผ๋์ ๋๋ฉด, FID๋ ์ ์ฌํ ์คํ์ ๋น๊ตํ ๋ ๊ฐ์ฅ ์ ์ฉํ์ง๋ง, ์ ์๊ฐ FID ์ธก์ ์ฝ๋๋ฅผ ์ฃผ์ ๊น๊ฒ ๊ณต๊ฐํ์ง ์๋ ํ ๋ ผ๋ฌธ ๊ฒฐ๊ณผ๋ฅผ ์ฌํํ๊ธฐ๋ ์ด๋ ต์ต๋๋ค.
์ด๋ฌํ ์ฌํญ์ KID ๋ฐ IS์ ๊ฐ์ ๋ค๋ฅธ ๊ด๋ จ ๋ฉํธ๋ฆญ์๋ ์ ์ฉ๋ฉ๋๋ค.
๋ง์ง๋ง ๋จ๊ณ๋ก, fake_images
๋ฅผ ์๊ฐ์ ์ผ๋ก ๊ฒ์ฌํด ๋ด
์๋ค.
Fake images.