|
<!--Copyright 2024 The HuggingFace Team. All rights reserved. |
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with |
|
the License. You may obtain a copy of the License at |
|
|
|
http://www.apache.org/licenses/LICENSE-2.0 |
|
|
|
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on |
|
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the |
|
specific language governing permissions and limitations under the License. |
|
--> |
|
|
|
# Diffusion ๋ชจ๋ธ ํ๊ฐํ๊ธฐ[[evaluating-diffusion-models]] |
|
|
|
<a target="_blank" href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/evaluation.ipynb"> |
|
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/> |
|
</a> |
|
|
|
[Stable Diffusion](https://huggingface.co/docs/diffusers/stable_diffusion)์ ๊ฐ์ ์์ฑ ๋ชจ๋ธ์ ํ๊ฐ๋ ์ฃผ๊ด์ ์ธ ์ฑ๊ฒฉ์ ๊ฐ์ง๊ณ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ค๋ฌด์์ ์ฐ๊ตฌ์๋ก์ ์ฐ๋ฆฌ๋ ์ข
์ข
๋ค์ํ ๊ฐ๋ฅ์ฑ ์ค์์ ์ ์คํ ์ ํ์ ํด์ผ ํฉ๋๋ค. ๊ทธ๋์ ๋ค์ํ ์์ฑ ๋ชจ๋ธ (GAN, Diffusion ๋ฑ)์ ์ฌ์ฉํ ๋ ์ด๋ป๊ฒ ์ ํํด์ผ ํ ๊น์? |
|
|
|
์ ์ฑ์ ์ธ ํ๊ฐ๋ ๋ชจ๋ธ์ ์ด๋ฏธ์ง ํ์ง์ ๋ํ ์ฃผ๊ด์ ์ธ ํ๊ฐ์ด๋ฏ๋ก ์ค๋ฅ๊ฐ ๋ฐ์ํ ์ ์๊ณ ๊ฒฐ์ ์ ์๋ชป๋ ์ํฅ์ ๋ฏธ์น ์ ์์ต๋๋ค. ๋ฐ๋ฉด, ์ ๋์ ์ธ ํ๊ฐ๋ ์ด๋ฏธ์ง ํ์ง๊ณผ ์ง์ ์ ์ธ ์๊ด๊ด๊ณ๋ฅผ ๊ฐ์ง ์์ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ ์ผ๋ฐ์ ์ผ๋ก ์ ์ฑ์ ํ๊ฐ์ ์ ๋์ ํ๊ฐ๋ฅผ ๋ชจ๋ ๊ณ ๋ คํ๋ ๊ฒ์ด ๋ ๊ฐ๋ ฅํ ์ ํธ๋ฅผ ์ ๊ณตํ์ฌ ๋ชจ๋ธ ์ ํ์ ๋์์ด ๋ฉ๋๋ค. |
|
|
|
์ด ๋ฌธ์์์๋ Diffusion ๋ชจ๋ธ์ ํ๊ฐํ๊ธฐ ์ํ ์ ์ฑ์ ๋ฐ ์ ๋์ ๋ฐฉ๋ฒ์ ๋ํด ์์ธํ ์ค๋ช
ํฉ๋๋ค. ์ ๋์ ๋ฐฉ๋ฒ์ ๋ํด์๋ ํนํ `diffusers`์ ํจ๊ป ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ ์ด์ ์ ๋ง์ถ์์ต๋๋ค. |
|
|
|
์ด ๋ฌธ์์์ ๋ณด์ฌ์ง ๋ฐฉ๋ฒ๋ค์ ๊ธฐ๋ฐ ์์ฑ ๋ชจ๋ธ์ ๊ณ ์ ์ํค๊ณ ๋ค์ํ [๋
ธ์ด์ฆ ์ค์ผ์ค๋ฌ](https://huggingface.co/docs/diffusers/main/en/api/schedulers/overview)๋ฅผ ํ๊ฐํ๋ ๋ฐ์๋ ์ฌ์ฉํ ์ ์์ต๋๋ค. |
|
|
|
## ์๋๋ฆฌ์ค[[scenarios]] |
|
๋ค์๊ณผ ๊ฐ์ ํ์ดํ๋ผ์ธ์ ์ฌ์ฉํ์ฌ Diffusion ๋ชจ๋ธ์ ๋ค๋ฃน๋๋ค: |
|
|
|
- ํ
์คํธ๋ก ์๋ด๋ ์ด๋ฏธ์ง ์์ฑ (์: [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/text2img)). |
|
- ์
๋ ฅ ์ด๋ฏธ์ง์ ์ถ๊ฐ๋ก ์กฐ๊ฑด์ ๊ฑด ํ
์คํธ๋ก ์๋ด๋ ์ด๋ฏธ์ง ์์ฑ (์: [`StableDiffusionImg2ImgPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/img2img) ๋ฐ [`StableDiffusionInstructPix2PixPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix)). |
|
- ํด๋์ค ์กฐ๊ฑดํ๋ ์ด๋ฏธ์ง ์์ฑ ๋ชจ๋ธ (์: [`DiTPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/dit)). |
|
|
|
## ์ ์ฑ์ ํ๊ฐ[[qualitative-evaluation]] |
|
|
|
์ ์ฑ์ ํ๊ฐ๋ ์ผ๋ฐ์ ์ผ๋ก ์์ฑ๋ ์ด๋ฏธ์ง์ ์ธ๊ฐ ํ๊ฐ๋ฅผ ํฌํจํฉ๋๋ค. ํ์ง์ ๊ตฌ์ฑ์ฑ, ์ด๋ฏธ์ง-ํ
์คํธ ์ผ์น, ๊ณต๊ฐ ๊ด๊ณ ๋ฑ๊ณผ ๊ฐ์ ์ธก๋ฉด์์ ์ธก์ ๋ฉ๋๋ค. ์ผ๋ฐ์ ์ธ ํ๋กฌํํธ๋ ์ฃผ๊ด์ ์ธ ์งํ์ ๋ํ ์ผ์ ํ ๊ธฐ์ค์ ์ ๊ณตํฉ๋๋ค. |
|
DrawBench์ PartiPrompts๋ ์ ์ฑ์ ์ธ ๋ฒค์น๋งํน์ ์ฌ์ฉ๋๋ ํ๋กฌํํธ ๋ฐ์ดํฐ์
์
๋๋ค. DrawBench์ PartiPrompts๋ ๊ฐ๊ฐ [Imagen](https://imagen.research.google/)๊ณผ [Parti](https://parti.research.google/)์์ ์๊ฐ๋์์ต๋๋ค. |
|
|
|
[Parti ๊ณต์ ์น์ฌ์ดํธ](https://parti.research.google/)์์ ๋ค์๊ณผ ๊ฐ์ด ์ค๋ช
ํ๊ณ ์์ต๋๋ค: |
|
|
|
> PartiPrompts (P2)๋ ์ด ์์
์ ์ผ๋ถ๋ก ๊ณต๊ฐ๋๋ ์์ด๋ก ๋ 1600๊ฐ ์ด์์ ๋ค์ํ ํ๋กฌํํธ ์ธํธ์
๋๋ค. P2๋ ๋ค์ํ ๋ฒ์ฃผ์ ๋์ ์ธก๋ฉด์์ ๋ชจ๋ธ์ ๋ฅ๋ ฅ์ ์ธก์ ํ๋ ๋ฐ ์ฌ์ฉํ ์ ์์ต๋๋ค. |
|
|
|
![parti-prompts](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts.png) |
|
|
|
PartiPrompts๋ ๋ค์๊ณผ ๊ฐ์ ์ด์ ๊ฐ์ง๊ณ ์์ต๋๋ค: |
|
|
|
- ํ๋กฌํํธ (Prompt) |
|
- ํ๋กฌํํธ์ ์นดํ
๊ณ ๋ฆฌ (์: "Abstract", "World Knowledge" ๋ฑ) |
|
- ๋์ด๋๋ฅผ ๋ฐ์ํ ์ฑ๋ฆฐ์ง (์: "Basic", "Complex", "Writing & Symbols" ๋ฑ) |
|
|
|
์ด๋ฌํ ๋ฒค์น๋งํฌ๋ ์๋ก ๋ค๋ฅธ ์ด๋ฏธ์ง ์์ฑ ๋ชจ๋ธ์ ์ธ๊ฐ ํ๊ฐ๋ก ๋น๊ตํ ์ ์๋๋ก ํฉ๋๋ค. |
|
|
|
์ด๋ฅผ ์ํด ๐งจ Diffusers ํ์ **Open Parti Prompts**๋ฅผ ๊ตฌ์ถํ์ต๋๋ค. ์ด๋ Parti Prompts๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ ์ปค๋ฎค๋ํฐ ๊ธฐ๋ฐ์ ์ง์ ๋ฒค์น๋งํฌ๋ก, ์ต์ฒจ๋จ ์คํ ์์ค ํ์ฐ ๋ชจ๋ธ์ ๋น๊ตํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค: |
|
- [Open Parti Prompts ๊ฒ์](https://huggingface.co/spaces/OpenGenAI/open-parti-prompts): 10๊ฐ์ parti prompt์ ๋ํด 4๊ฐ์ ์์ฑ๋ ์ด๋ฏธ์ง๊ฐ ์ ์๋๋ฉฐ, ์ฌ์ฉ์๋ ํ๋กฌํํธ์ ๊ฐ์ฅ ์ ํฉํ ์ด๋ฏธ์ง๋ฅผ ์ ํํฉ๋๋ค. |
|
- [Open Parti Prompts ๋ฆฌ๋๋ณด๋](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard): ํ์ฌ ์ต๊ณ ์ ์คํ ์์ค diffusion ๋ชจ๋ธ๋ค์ ์๋ก ๋น๊ตํ๋ ๋ฆฌ๋๋ณด๋์
๋๋ค. |
|
|
|
์ด๋ฏธ์ง๋ฅผ ์๋์ผ๋ก ๋น๊ตํ๋ ค๋ฉด, `diffusers`๋ฅผ ์ฌ์ฉํ์ฌ ๋ช๊ฐ์ง PartiPrompts๋ฅผ ์ด๋ป๊ฒ ํ์ฉํ ์ ์๋์ง ์์๋ด
์๋ค. |
|
|
|
๋ค์์ ๋ช ๊ฐ์ง ๋ค๋ฅธ ๋์ ์์ ์ํ๋งํ ํ๋กฌํํธ๋ฅผ ๋ณด์ฌ์ค๋๋ค: Basic, Complex, Linguistic Structures, Imagination, Writing & Symbols. ์ฌ๊ธฐ์๋ PartiPrompts๋ฅผ [๋ฐ์ดํฐ์
](https://huggingface.co/datasets/nateraw/parti-prompts)์ผ๋ก ์ฌ์ฉํฉ๋๋ค. |
|
|
|
```python |
|
from datasets import load_dataset |
|
|
|
# prompts = load_dataset("nateraw/parti-prompts", split="train") |
|
# prompts = prompts.shuffle() |
|
# sample_prompts = [prompts[i]["Prompt"] for i in range(5)] |
|
|
|
# Fixing these sample prompts in the interest of reproducibility. |
|
sample_prompts = [ |
|
"a corgi", |
|
"a hot air balloon with a yin-yang symbol, with the moon visible in the daytime sky", |
|
"a car with no windows", |
|
"a cube made of porcupine", |
|
'The saying "BE EXCELLENT TO EACH OTHER" written on a red brick wall with a graffiti image of a green alien wearing a tuxedo. A yellow fire hydrant is on a sidewalk in the foreground.', |
|
] |
|
``` |
|
์ด์ ์ด๋ฐ ํ๋กฌํํธ๋ฅผ ์ฌ์ฉํ์ฌ Stable Diffusion ([v1-4 checkpoint](https://huggingface.co/CompVis/stable-diffusion-v1-4))๋ฅผ ์ฌ์ฉํ ์ด๋ฏธ์ง ์์ฑ์ ํ ์ ์์ต๋๋ค : |
|
|
|
```python |
|
import torch |
|
|
|
seed = 0 |
|
generator = torch.manual_seed(seed) |
|
|
|
images = sd_pipeline(sample_prompts, num_images_per_prompt=1, generator=generator).images |
|
``` |
|
|
|
![parti-prompts-14](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts-14.png) |
|
|
|
|
|
`num_images_per_prompt`๋ฅผ ์ค์ ํ์ฌ ๋์ผํ ํ๋กฌํํธ์ ๋ํด ๋ค๋ฅธ ์ด๋ฏธ์ง๋ฅผ ๋น๊ตํ ์๋ ์์ต๋๋ค. ๋ค๋ฅธ ์ฒดํฌํฌ์ธํธ([v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5))๋ก ๋์ผํ ํ์ดํ๋ผ์ธ์ ์คํํ๋ฉด ๋ค์๊ณผ ๊ฐ์ ๊ฒฐ๊ณผ๊ฐ ๋์ต๋๋ค: |
|
|
|
![parti-prompts-15](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/parti-prompts-15.png) |
|
|
|
|
|
๋ค์ํ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ๋ชจ๋ ํ๋กฌํํธ์์ ์์ฑ๋ ์ฌ๋ฌ ์ด๋ฏธ์ง๋ค์ด ์์ฑ๋๋ฉด (ํ๊ฐ ๊ณผ์ ์์) ์ด๋ฌํ ๊ฒฐ๊ณผ๋ฌผ๋ค์ ์ฌ๋ ํ๊ฐ์๋ค์๊ฒ ์ ์๋ฅผ ๋งค๊ธฐ๊ธฐ ์ํด ์ ์๋ฉ๋๋ค. DrawBench์ PartiPrompts ๋ฒค์น๋งํฌ์ ๋ํ ์์ธํ ๋ด์ฉ์ ๊ฐ๊ฐ์ ๋
ผ๋ฌธ์ ์ฐธ์กฐํ์ญ์์ค. |
|
|
|
<Tip> |
|
|
|
๋ชจ๋ธ์ด ํ๋ จ ์ค์ผ ๋ ์ถ๋ก ์ํ์ ์ดํด๋ณด๋ ๊ฒ์ ํ๋ จ ์งํ ์ํฉ์ ์ธก์ ํ๋ ๋ฐ ์ ์ฉํฉ๋๋ค. [ํ๋ จ ์คํฌ๋ฆฝํธ](https://github.com/huggingface/diffusers/tree/main/examples/)์์๋ TensorBoard์ Weights & Biases์ ๋ํ ์ถ๊ฐ ์ง์๊ณผ ํจ๊ป ์ด ์ ํธ๋ฆฌํฐ๋ฅผ ์ง์ํฉ๋๋ค. |
|
|
|
</Tip> |
|
|
|
## ์ ๋์ ํ๊ฐ[[quantitative-evaluation]] |
|
|
|
์ด ์น์
์์๋ ์ธ ๊ฐ์ง ๋ค๋ฅธ ํ์ฐ ํ์ดํ๋ผ์ธ์ ํ๊ฐํ๋ ๋ฐฉ๋ฒ์ ์๋ดํฉ๋๋ค: |
|
|
|
- CLIP ์ ์ |
|
- CLIP ๋ฐฉํฅ์ฑ ์ ์ฌ๋ |
|
- FID |
|
|
|
### ํ
์คํธ ์๋ด ์ด๋ฏธ์ง ์์ฑ[[text-guided-image-generation]] |
|
|
|
[CLIP ์ ์](https://arxiv.org/abs/2104.08718)๋ ์ด๋ฏธ์ง-์บก์
์์ ํธํ์ฑ์ ์ธก์ ํฉ๋๋ค. ๋์ CLIP ์ ์๋ ๋์ ํธํ์ฑ๐ผ์ ๋ํ๋
๋๋ค. CLIP ์ ์๋ ์ด๋ฏธ์ง์ ์บก์
์ฌ์ด์ ์๋ฏธ์ ์ ์ฌ์ฑ์ผ๋ก ์๊ฐํ ์๋ ์์ต๋๋ค. CLIP ์ ์๋ ์ธ๊ฐ ํ๋จ๊ณผ ๋์ ์๊ด๊ด๊ณ๋ฅผ ๊ฐ์ง๊ณ ์์ต๋๋ค. |
|
|
|
[`StableDiffusionPipeline`]์ ์ผ๋จ ๋ก๋ํด๋ด
์๋ค: |
|
|
|
```python |
|
from diffusers import StableDiffusionPipeline |
|
import torch |
|
|
|
model_ckpt = "CompVis/stable-diffusion-v1-4" |
|
sd_pipeline = StableDiffusionPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16).to("cuda") |
|
``` |
|
|
|
์ฌ๋ฌ ๊ฐ์ ํ๋กฌํํธ๋ฅผ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง๋ฅผ ์์ฑํฉ๋๋ค: |
|
|
|
```python |
|
prompts = [ |
|
"a photo of an astronaut riding a horse on mars", |
|
"A high tech solarpunk utopia in the Amazon rainforest", |
|
"A pikachu fine dining with a view to the Eiffel Tower", |
|
"A mecha robot in a favela in expressionist style", |
|
"an insect robot preparing a delicious meal", |
|
"A small cabin on top of a snowy mountain in the style of Disney, artstation", |
|
] |
|
|
|
images = sd_pipeline(prompts, num_images_per_prompt=1, output_type="np").images |
|
|
|
print(images.shape) |
|
# (6, 512, 512, 3) |
|
``` |
|
|
|
๊ทธ๋ฌ๊ณ ๋์ CLIP ์ ์๋ฅผ ๊ณ์ฐํฉ๋๋ค. |
|
|
|
```python |
|
from torchmetrics.functional.multimodal import clip_score |
|
from functools import partial |
|
|
|
clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16") |
|
|
|
def calculate_clip_score(images, prompts): |
|
images_int = (images * 255).astype("uint8") |
|
clip_score = clip_score_fn(torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts).detach() |
|
return round(float(clip_score), 4) |
|
|
|
sd_clip_score = calculate_clip_score(images, prompts) |
|
print(f"CLIP score: {sd_clip_score}") |
|
# CLIP score: 35.7038 |
|
``` |
|
|
|
์์ ์์ ์์๋ ๊ฐ ํ๋กฌํํธ ๋น ํ๋์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ์ต๋๋ค. ๋ง์ฝ ํ๋กฌํํธ ๋น ์ฌ๋ฌ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ค๋ฉด, ํ๋กฌํํธ ๋น ์์ฑ๋ ์ด๋ฏธ์ง์ ํ๊ท ์ ์๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค. |
|
|
|
์ด์ [`StableDiffusionPipeline`]๊ณผ ํธํ๋๋ ๋ ๊ฐ์ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋น๊ตํ๋ ค๋ฉด, ํ์ดํ๋ผ์ธ์ ํธ์ถํ ๋ generator๋ฅผ ์ ๋ฌํด์ผ ํฉ๋๋ค. ๋จผ์ , ๊ณ ์ ๋ ์๋๋ก [v1-4 Stable Diffusion ์ฒดํฌํฌ์ธํธ](https://huggingface.co/CompVis/stable-diffusion-v1-4)๋ฅผ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง๋ฅผ ์์ฑํฉ๋๋ค: |
|
|
|
```python |
|
seed = 0 |
|
generator = torch.manual_seed(seed) |
|
|
|
images = sd_pipeline(prompts, num_images_per_prompt=1, generator=generator, output_type="np").images |
|
``` |
|
|
|
๊ทธ๋ฐ ๋ค์ [v1-5 checkpoint](https://huggingface.co/runwayml/stable-diffusion-v1-5)๋ฅผ ๋ก๋ํ์ฌ ์ด๋ฏธ์ง๋ฅผ ์์ฑํฉ๋๋ค: |
|
|
|
```python |
|
model_ckpt_1_5 = "runwayml/stable-diffusion-v1-5" |
|
sd_pipeline_1_5 = StableDiffusionPipeline.from_pretrained(model_ckpt_1_5, torch_dtype=weight_dtype).to(device) |
|
|
|
images_1_5 = sd_pipeline_1_5(prompts, num_images_per_prompt=1, generator=generator, output_type="np").images |
|
``` |
|
|
|
๊ทธ๋ฆฌ๊ณ ๋ง์ง๋ง์ผ๋ก CLIP ์ ์๋ฅผ ๋น๊ตํฉ๋๋ค: |
|
|
|
```python |
|
sd_clip_score_1_4 = calculate_clip_score(images, prompts) |
|
print(f"CLIP Score with v-1-4: {sd_clip_score_1_4}") |
|
# CLIP Score with v-1-4: 34.9102 |
|
|
|
sd_clip_score_1_5 = calculate_clip_score(images_1_5, prompts) |
|
print(f"CLIP Score with v-1-5: {sd_clip_score_1_5}") |
|
# CLIP Score with v-1-5: 36.2137 |
|
``` |
|
|
|
[v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) ์ฒดํฌํฌ์ธํธ๊ฐ ์ด์ ๋ฒ์ ๋ณด๋ค ๋ ๋์ ์ฑ๋ฅ์ ๋ณด์ด๋ ๊ฒ ๊ฐ์ต๋๋ค. ๊ทธ๋ฌ๋ CLIP ์ ์๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํด ์ฌ์ฉํ ํ๋กฌํํธ์ ์๊ฐ ์๋นํ ์ ์ต๋๋ค. ๋ณด๋ค ์ค์ฉ์ ์ธ ํ๊ฐ๋ฅผ ์ํด์๋ ์ด ์๋ฅผ ํจ์ฌ ๋๊ฒ ์ค์ ํ๊ณ , ํ๋กฌํํธ๋ฅผ ๋ค์ํ๊ฒ ์ฌ์ฉํด์ผ ํฉ๋๋ค. |
|
|
|
<Tip warning={true}> |
|
|
|
์ด ์ ์์๋ ๋ช ๊ฐ์ง ์ ํ ์ฌํญ์ด ์์ต๋๋ค. ํ๋ จ ๋ฐ์ดํฐ์
์ ์บก์
์ ์น์์ ํฌ๋กค๋ง๋์ด ์ด๋ฏธ์ง์ ๊ด๋ จ๋ `alt` ๋ฐ ์ ์ฌํ ํ๊ทธ์์ ์ถ์ถ๋์์ต๋๋ค. ์ด๋ค์ ์ธ๊ฐ์ด ์ด๋ฏธ์ง๋ฅผ ์ค๋ช
ํ๋ ๋ฐ ์ฌ์ฉํ ์ ์๋ ๊ฒ๊ณผ ์ผ์นํ์ง ์์ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ ์ฌ๊ธฐ์๋ ๋ช ๊ฐ์ง ํ๋กฌํํธ๋ฅผ "์์ง๋์ด๋ง"ํด์ผ ํ์ต๋๋ค. |
|
|
|
</Tip> |
|
|
|
### ์ด๋ฏธ์ง ์กฐ๊ฑดํ๋ ํ
์คํธ-์ด๋ฏธ์ง ์์ฑ[[image-conditioned-text-to-image-generation]] |
|
|
|
์ด ๊ฒฝ์ฐ, ์์ฑ ํ์ดํ๋ผ์ธ์ ์
๋ ฅ ์ด๋ฏธ์ง์ ํ
์คํธ ํ๋กฌํํธ๋ก ์กฐ๊ฑดํํฉ๋๋ค. [`StableDiffusionInstructPix2PixPipeline`]์ ์๋ก ๋ค์ด๋ณด๊ฒ ์ต๋๋ค. ์ด๋ ํธ์ง ์ง์๋ฌธ์ ์
๋ ฅ ํ๋กฌํํธ๋ก ์ฌ์ฉํ๊ณ ํธ์งํ ์
๋ ฅ ์ด๋ฏธ์ง๋ฅผ ์ฌ์ฉํฉ๋๋ค. |
|
|
|
๋ค์์ ํ๋์ ์์์
๋๋ค: |
|
|
|
![edit-instruction](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-instruction.png) |
|
|
|
๋ชจ๋ธ์ ํ๊ฐํ๋ ํ ๊ฐ์ง ์ ๋ต์ ๋ ์ด๋ฏธ์ง ์บก์
๊ฐ์ ๋ณ๊ฒฝ๊ณผ([CLIP-Guided Domain Adaptation of Image Generators](https://arxiv.org/abs/2108.00946)์์ ๋ณด์ฌ์ค๋๋ค) ํจ๊ป ๋ ์ด๋ฏธ์ง ์ฌ์ด์ ๋ณ๊ฒฝ์ ์ผ๊ด์ฑ์ ์ธก์ ํ๋ ๊ฒ์
๋๋ค ([CLIP](https://huggingface.co/docs/transformers/model_doc/clip) ๊ณต๊ฐ์์). ์ด๋ฅผ "**CLIP ๋ฐฉํฅ์ฑ ์ ์ฌ์ฑ**"์ด๋ผ๊ณ ํฉ๋๋ค. |
|
|
|
- ์บก์
1์ ํธ์งํ ์ด๋ฏธ์ง (์ด๋ฏธ์ง 1)์ ํด๋นํฉ๋๋ค. |
|
- ์บก์
2๋ ํธ์ง๋ ์ด๋ฏธ์ง (์ด๋ฏธ์ง 2)์ ํด๋นํฉ๋๋ค. ํธ์ง ์ง์๋ฅผ ๋ฐ์ํด์ผ ํฉ๋๋ค. |
|
|
|
๋ค์์ ๊ทธ๋ฆผ์ผ๋ก ๋ ๊ฐ์์
๋๋ค: |
|
|
|
![edit-consistency](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-consistency.png) |
|
|
|
์ฐ๋ฆฌ๋ ์ด ์ธก์ ํญ๋ชฉ์ ๊ตฌํํ๊ธฐ ์ํด ๋ฏธ๋ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ์ค๋นํ์ต๋๋ค. ๋จผ์ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๋ก๋ํด ๋ณด๊ฒ ์ต๋๋ค. |
|
|
|
```python |
|
from datasets import load_dataset |
|
|
|
dataset = load_dataset("sayakpaul/instructpix2pix-demo", split="train") |
|
dataset.features |
|
``` |
|
|
|
```bash |
|
{'input': Value(dtype='string', id=None), |
|
'edit': Value(dtype='string', id=None), |
|
'output': Value(dtype='string', id=None), |
|
'image': Image(decode=True, id=None)} |
|
``` |
|
|
|
์ฌ๊ธฐ์๋ ๋ค์๊ณผ ๊ฐ์ ํญ๋ชฉ์ด ์์ต๋๋ค: |
|
|
|
- `input`์ `image`์ ํด๋นํ๋ ์บก์
์
๋๋ค. |
|
- `edit`์ ํธ์ง ์ง์์ฌํญ์ ๋ํ๋
๋๋ค. |
|
- `output`์ `edit` ์ง์์ฌํญ์ ๋ฐ์ํ ์์ ๋ ์บก์
์
๋๋ค. |
|
|
|
์ํ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค. |
|
|
|
```python |
|
idx = 0 |
|
print(f"Original caption: {dataset[idx]['input']}") |
|
print(f"Edit instruction: {dataset[idx]['edit']}") |
|
print(f"Modified caption: {dataset[idx]['output']}") |
|
``` |
|
|
|
```bash |
|
Original caption: 2. FAROE ISLANDS: An archipelago of 18 mountainous isles in the North Atlantic Ocean between Norway and Iceland, the Faroe Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills' |
|
Edit instruction: make the isles all white marble |
|
Modified caption: 2. WHITE MARBLE ISLANDS: An archipelago of 18 mountainous white marble isles in the North Atlantic Ocean between Norway and Iceland, the White Marble Islands has 'everything you could hope for', according to Big 7 Travel. It boasts 'crystal clear waterfalls, rocky cliffs that seem to jut out of nowhere and velvety green hills' |
|
``` |
|
|
|
๋ค์์ ์ด๋ฏธ์ง์
๋๋ค: |
|
|
|
```python |
|
dataset[idx]["image"] |
|
``` |
|
|
|
![edit-dataset](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/edit-dataset.png) |
|
|
|
๋จผ์ ํธ์ง ์ง์์ฌํญ์ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ ์ธํธ์ ์ด๋ฏธ์ง๋ฅผ ํธ์งํ๊ณ ๋ฐฉํฅ ์ ์ฌ๋๋ฅผ ๊ณ์ฐํฉ๋๋ค. |
|
|
|
[`StableDiffusionInstructPix2PixPipeline`]๋ฅผ ๋จผ์ ๋ก๋ํฉ๋๋ค: |
|
|
|
```python |
|
from diffusers import StableDiffusionInstructPix2PixPipeline |
|
|
|
instruct_pix2pix_pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( |
|
"timbrooks/instruct-pix2pix", torch_dtype=torch.float16 |
|
).to(device) |
|
``` |
|
|
|
์ด์ ํธ์ง์ ์ํํฉ๋๋ค: |
|
|
|
```python |
|
import numpy as np |
|
|
|
|
|
def edit_image(input_image, instruction): |
|
image = instruct_pix2pix_pipeline( |
|
instruction, |
|
image=input_image, |
|
output_type="np", |
|
generator=generator, |
|
).images[0] |
|
return image |
|
|
|
input_images = [] |
|
original_captions = [] |
|
modified_captions = [] |
|
edited_images = [] |
|
|
|
for idx in range(len(dataset)): |
|
input_image = dataset[idx]["image"] |
|
edit_instruction = dataset[idx]["edit"] |
|
edited_image = edit_image(input_image, edit_instruction) |
|
|
|
input_images.append(np.array(input_image)) |
|
original_captions.append(dataset[idx]["input"]) |
|
modified_captions.append(dataset[idx]["output"]) |
|
edited_images.append(edited_image) |
|
``` |
|
๋ฐฉํฅ ์ ์ฌ๋๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํด์๋ ๋จผ์ CLIP์ ์ด๋ฏธ์ง์ ํ
์คํธ ์ธ์ฝ๋๋ฅผ ๋ก๋ํฉ๋๋ค: |
|
|
|
```python |
|
from transformers import ( |
|
CLIPTokenizer, |
|
CLIPTextModelWithProjection, |
|
CLIPVisionModelWithProjection, |
|
CLIPImageProcessor, |
|
) |
|
|
|
clip_id = "openai/clip-vit-large-patch14" |
|
tokenizer = CLIPTokenizer.from_pretrained(clip_id) |
|
text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to(device) |
|
image_processor = CLIPImageProcessor.from_pretrained(clip_id) |
|
image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to(device) |
|
``` |
|
|
|
์ฃผ๋ชฉํ ์ ์ ํน์ ํ CLIP ์ฒดํฌํฌ์ธํธ์ธ `openai/clip-vit-large-patch14`๋ฅผ ์ฌ์ฉํ๊ณ ์๋ค๋ ๊ฒ์
๋๋ค. ์ด๋ Stable Diffusion ์ฌ์ ํ๋ จ์ด ์ด CLIP ๋ณํ์ฒด์ ํจ๊ป ์ํ๋์๊ธฐ ๋๋ฌธ์
๋๋ค. ์์ธํ ๋ด์ฉ์ [๋ฌธ์](https://huggingface.co/docs/transformers/model_doc/clip)๋ฅผ ์ฐธ์กฐํ์ธ์. |
|
|
|
๋ค์์ผ๋ก, ๋ฐฉํฅ์ฑ ์ ์ฌ๋๋ฅผ ๊ณ์ฐํ๊ธฐ ์ํด PyTorch์ `nn.Module`์ ์ค๋นํฉ๋๋ค: |
|
|
|
```python |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class DirectionalSimilarity(nn.Module): |
|
def __init__(self, tokenizer, text_encoder, image_processor, image_encoder): |
|
super().__init__() |
|
self.tokenizer = tokenizer |
|
self.text_encoder = text_encoder |
|
self.image_processor = image_processor |
|
self.image_encoder = image_encoder |
|
|
|
def preprocess_image(self, image): |
|
image = self.image_processor(image, return_tensors="pt")["pixel_values"] |
|
return {"pixel_values": image.to(device)} |
|
|
|
def tokenize_text(self, text): |
|
inputs = self.tokenizer( |
|
text, |
|
max_length=self.tokenizer.model_max_length, |
|
padding="max_length", |
|
truncation=True, |
|
return_tensors="pt", |
|
) |
|
return {"input_ids": inputs.input_ids.to(device)} |
|
|
|
def encode_image(self, image): |
|
preprocessed_image = self.preprocess_image(image) |
|
image_features = self.image_encoder(**preprocessed_image).image_embeds |
|
image_features = image_features / image_features.norm(dim=1, keepdim=True) |
|
return image_features |
|
|
|
def encode_text(self, text): |
|
tokenized_text = self.tokenize_text(text) |
|
text_features = self.text_encoder(**tokenized_text).text_embeds |
|
text_features = text_features / text_features.norm(dim=1, keepdim=True) |
|
return text_features |
|
|
|
def compute_directional_similarity(self, img_feat_one, img_feat_two, text_feat_one, text_feat_two): |
|
sim_direction = F.cosine_similarity(img_feat_two - img_feat_one, text_feat_two - text_feat_one) |
|
return sim_direction |
|
|
|
def forward(self, image_one, image_two, caption_one, caption_two): |
|
img_feat_one = self.encode_image(image_one) |
|
img_feat_two = self.encode_image(image_two) |
|
text_feat_one = self.encode_text(caption_one) |
|
text_feat_two = self.encode_text(caption_two) |
|
directional_similarity = self.compute_directional_similarity( |
|
img_feat_one, img_feat_two, text_feat_one, text_feat_two |
|
) |
|
return directional_similarity |
|
``` |
|
|
|
์ด์ ย `DirectionalSimilarity`๋ฅผ ์ฌ์ฉํด ๋ณด๊ฒ ์ต๋๋ค. |
|
|
|
```python |
|
dir_similarity = DirectionalSimilarity(tokenizer, text_encoder, image_processor, image_encoder) |
|
scores = [] |
|
|
|
for i in range(len(input_images)): |
|
original_image = input_images[i] |
|
original_caption = original_captions[i] |
|
edited_image = edited_images[i] |
|
modified_caption = modified_captions[i] |
|
|
|
similarity_score = dir_similarity(original_image, edited_image, original_caption, modified_caption) |
|
scores.append(float(similarity_score.detach().cpu())) |
|
|
|
print(f"CLIP directional similarity: {np.mean(scores)}") |
|
# CLIP directional similarity: 0.0797976553440094 |
|
``` |
|
|
|
CLIP ์ ์์ ๋ง์ฐฌ๊ฐ์ง๋ก, CLIP ๋ฐฉํฅ ์ ์ฌ์ฑ์ด ๋์์๋ก ์ข์ต๋๋ค. |
|
|
|
`StableDiffusionInstructPix2PixPipeline`์ `image_guidance_scale`๊ณผ `guidance_scale`์ด๋ผ๋ ๋ ๊ฐ์ง ์ธ์๋ฅผ ๋
ธ์ถ์ํต๋๋ค. ์ด ๋ ์ธ์๋ฅผ ์กฐ์ ํ์ฌ ์ต์ข
ํธ์ง๋ ์ด๋ฏธ์ง์ ํ์ง์ ์ ์ดํ ์ ์์ต๋๋ค. ์ด ๋ ์ธ์์ ์ํฅ์ ์คํํด๋ณด๊ณ ๋ฐฉํฅ ์ ์ฌ์ฑ์ ๋ฏธ์น๋ ์ํฅ์ ํ์ธํด๋ณด๊ธฐ๋ฅผ ๊ถ์ฅํฉ๋๋ค. |
|
|
|
์ด๋ฌํ ๋ฉํธ๋ฆญ์ ๊ฐ๋
์ ํ์ฅํ์ฌ ์๋ณธ ์ด๋ฏธ์ง์ ํธ์ง๋ ๋ฒ์ ์ ์ ์ฌ์ฑ์ ์ธก์ ํ ์ ์์ต๋๋ค. ์ด๋ฅผ ์ํด `F.cosine_similarity(img_feat_two, img_feat_one)`์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ด๋ฌํ ์ข
๋ฅ์ ํธ์ง์์๋ ์ด๋ฏธ์ง์ ์ฃผ์ ์๋ฏธ๊ฐ ์ต๋ํ ๋ณด์กด๋์ด์ผ ํฉ๋๋ค. ์ฆ, ๋์ ์ ์ฌ์ฑ ์ ์๋ฅผ ์ป์ด์ผ ํฉ๋๋ค. |
|
|
|
[`StableDiffusionPix2PixZeroPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix_zero#diffusers.StableDiffusionPix2PixZeroPipeline)์ ๊ฐ์ ์ ์ฌํ ํ์ดํ๋ผ์ธ์๋ ์ด๋ฌํ ๋ฉํธ๋ฆญ์ ์ฌ์ฉํ ์ ์์ต๋๋ค. |
|
|
|
<Tip> |
|
|
|
CLIP ์ ์์ CLIP ๋ฐฉํฅ ์ ์ฌ์ฑ ๋ชจ๋ CLIP ๋ชจ๋ธ์ ์์กดํ๊ธฐ ๋๋ฌธ์ ํ๊ฐ๊ฐ ํธํฅ๋ ์ ์์ต๋๋ค |
|
|
|
</Tip> |
|
|
|
***IS, FID (๋์ค์ ์ค๋ช
ํ ์์ ), ๋๋ KID์ ๊ฐ์ ๋ฉํธ๋ฆญ์ ํ์ฅํ๋ ๊ฒ์ ์ด๋ ค์ธ ์ ์์ต๋๋ค***. ํ๊ฐ ์ค์ธ ๋ชจ๋ธ์ด ๋๊ท๋ชจ ์ด๋ฏธ์ง ์บก์
๋ ๋ฐ์ดํฐ์
(์: [LAION-5B ๋ฐ์ดํฐ์
](https://laion.ai/blog/laion-5b/))์์ ์ฌ์ ํ๋ จ๋์์ ๋ ์ด๋ ๋ฌธ์ ๊ฐ ๋ ์ ์์ต๋๋ค. ์๋ํ๋ฉด ์ด๋ฌํ ๋ฉํธ๋ฆญ์ ๊ธฐ๋ฐ์๋ ์ค๊ฐ ์ด๋ฏธ์ง ํน์ง์ ์ถ์ถํ๊ธฐ ์ํด ImageNet-1k ๋ฐ์ดํฐ์
์์ ์ฌ์ ํ๋ จ๋ InceptionNet์ด ์ฌ์ฉ๋๊ธฐ ๋๋ฌธ์
๋๋ค. Stable Diffusion์ ์ฌ์ ํ๋ จ ๋ฐ์ดํฐ์
์ InceptionNet์ ์ฌ์ ํ๋ จ ๋ฐ์ดํฐ์
๊ณผ ๊ฒน์น๋ ๋ถ๋ถ์ด ์ ํ์ ์ผ ์ ์์ผ๋ฏ๋ก ๋ฐ๋ผ์ ์ฌ๊ธฐ์๋ ์ข์ ํ๋ณด๊ฐ ์๋๋๋ค. |
|
|
|
***์์ ๋ฉํธ๋ฆญ์ ์ฌ์ฉํ๋ฉด ํด๋์ค ์กฐ๊ฑด์ด ์๋ ๋ชจ๋ธ์ ํ๊ฐํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, [DiT](https://huggingface.co/docs/diffusers/main/en/api/pipelines/dit). ์ด๋ ImageNet-1k ํด๋์ค์ ์กฐ๊ฑด์ ๊ฑธ๊ณ ์ฌ์ ํ๋ จ๋์์ต๋๋ค.*** |
|
|
|
### ํด๋์ค ์กฐ๊ฑดํ ์ด๋ฏธ์ง ์์ฑ[[class-conditioned-image-generation]] |
|
|
|
ํด๋์ค ์กฐ๊ฑดํ ์์ฑ ๋ชจ๋ธ์ ์ผ๋ฐ์ ์ผ๋ก [ImageNet-1k](https://huggingface.co/datasets/imagenet-1k)์ ๊ฐ์ ํด๋์ค ๋ ์ด๋ธ์ด ์ง์ ๋ ๋ฐ์ดํฐ์
์์ ์ฌ์ ํ๋ จ๋ฉ๋๋ค. ์ด๋ฌํ ๋ชจ๋ธ์ ํ๊ฐํ๋ ์ธ๊ธฐ์๋ ์งํ์๋ Frรฉchet Inception Distance (FID), Kernel Inception Distance (KID) ๋ฐ Inception Score (IS)๊ฐ ์์ต๋๋ค. ์ด ๋ฌธ์์์๋ FID ([Heusel et al.](https://arxiv.org/abs/1706.08500))์ ์ด์ ์ ๋ง์ถ๊ณ ์์ต๋๋ค. [`DiTPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/dit)์ ์ฌ์ฉํ์ฌ FID๋ฅผ ๊ณ์ฐํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค. ์ด๋ ๋ด๋ถ์ ์ผ๋ก [DiT ๋ชจ๋ธ](https://arxiv.org/abs/2212.09748)์ ์ฌ์ฉํฉ๋๋ค. |
|
|
|
FID๋ ๋ ๊ฐ์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์
์ด ์ผ๋ง๋ ์ ์ฌํ์ง๋ฅผ ์ธก์ ํ๋ ๊ฒ์ ๋ชฉํ๋ก ํฉ๋๋ค. [์ด ์๋ฃ](https://mmgeneration.readthedocs.io/en/latest/quick_run.html#fid)์ ๋ฐ๋ฅด๋ฉด: |
|
|
|
> Frรฉchet Inception Distance๋ ๋ ๊ฐ์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์
๊ฐ์ ์ ์ฌ์ฑ์ ์ธก์ ํ๋ ์งํ์
๋๋ค. ์๊ฐ์ ํ์ง์ ๋ํ ์ธ๊ฐ ํ๋จ๊ณผ ์ ์๊ด๋๋ ๊ฒ์ผ๋ก ๋ํ๋ฌ์ผ๋ฉฐ, ์ฃผ๋ก ์์ฑ์ ์ ๋ ์ ๊ฒฝ๋ง์ ์ํ ํ์ง์ ํ๊ฐํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค. FID๋ Inception ๋คํธ์ํฌ์ ํน์ง ํํ์ ๋ง๊ฒ ์ ํฉํ ๋ ๊ฐ์ ๊ฐ์ฐ์์ ์ฌ์ด์ Frรฉchet ๊ฑฐ๋ฆฌ๋ฅผ ๊ณ์ฐํ์ฌ ๊ตฌํฉ๋๋ค. |
|
|
|
์ด ๋ ๊ฐ์ ๋ฐ์ดํฐ์
์ ์ค์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ์
๊ณผ ๊ฐ์ง ์ด๋ฏธ์ง ๋ฐ์ดํฐ์
(์ฐ๋ฆฌ์ ๊ฒฝ์ฐ ์์ฑ๋ ์ด๋ฏธ์ง)์
๋๋ค. FID๋ ์ผ๋ฐ์ ์ผ๋ก ๋ ๊ฐ์ ํฐ ๋ฐ์ดํฐ์
์ผ๋ก ๊ณ์ฐ๋ฉ๋๋ค. ๊ทธ๋ฌ๋ ์ด ๋ฌธ์์์๋ ๋ ๊ฐ์ ๋ฏธ๋ ๋ฐ์ดํฐ์
์ผ๋ก ์์
ํ ๊ฒ์
๋๋ค. |
|
|
|
๋จผ์ ImageNet-1k ํ๋ จ ์ธํธ์์ ๋ช ๊ฐ์ ์ด๋ฏธ์ง๋ฅผ ๋ค์ด๋ก๋ํด ๋ด
์๋ค: |
|
|
|
```python |
|
from zipfile import ZipFile |
|
import requests |
|
|
|
|
|
def download(url, local_filepath): |
|
r = requests.get(url) |
|
with open(local_filepath, "wb") as f: |
|
f.write(r.content) |
|
return local_filepath |
|
|
|
dummy_dataset_url = "https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/sample-imagenet-images.zip" |
|
local_filepath = download(dummy_dataset_url, dummy_dataset_url.split("/")[-1]) |
|
|
|
with ZipFile(local_filepath, "r") as zipper: |
|
zipper.extractall(".") |
|
``` |
|
|
|
```python |
|
from PIL import Image |
|
import os |
|
|
|
dataset_path = "sample-imagenet-images" |
|
image_paths = sorted([os.path.join(dataset_path, x) for x in os.listdir(dataset_path)]) |
|
|
|
real_images = [np.array(Image.open(path).convert("RGB")) for path in image_paths] |
|
``` |
|
|
|
๋ค์์ ImageNet-1k classes์ ์ด๋ฏธ์ง 10๊ฐ์
๋๋ค : "cassette_player", "chain_saw" (x2), "church", "gas_pump" (x3), "parachute" (x2), ๊ทธ๋ฆฌ๊ณ "tench". |
|
|
|
<p align="center"> |
|
<img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/real-images.png" alt="real-images"><br> |
|
<em>Real images.</em> |
|
</p> |
|
|
|
์ด์ ์ด๋ฏธ์ง๊ฐ ๋ก๋๋์์ผ๋ฏ๋ก ์ด๋ฏธ์ง์ ๊ฐ๋ฒผ์ด ์ ์ฒ๋ฆฌ๋ฅผ ์ ์ฉํ์ฌ FID ๊ณ์ฐ์ ์ฌ์ฉํด ๋ณด๊ฒ ์ต๋๋ค. |
|
|
|
```python |
|
from torchvision.transforms import functional as F |
|
|
|
|
|
def preprocess_image(image): |
|
image = torch.tensor(image).unsqueeze(0) |
|
image = image.permute(0, 3, 1, 2) / 255.0 |
|
return F.center_crop(image, (256, 256)) |
|
|
|
real_images = torch.cat([preprocess_image(image) for image in real_images]) |
|
print(real_images.shape) |
|
# torch.Size([10, 3, 256, 256]) |
|
``` |
|
|
|
์ด์ ์์์ ์ธ๊ธํ ํด๋์ค์ ๋ฐ๋ผ ์กฐ๊ฑดํ ๋ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๊ธฐ ์ํด [`DiTPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/dit)๋ฅผ ๋ก๋ํฉ๋๋ค. |
|
|
|
```python |
|
from diffusers import DiTPipeline, DPMSolverMultistepScheduler |
|
|
|
dit_pipeline = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16) |
|
dit_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(dit_pipeline.scheduler.config) |
|
dit_pipeline = dit_pipeline.to("cuda") |
|
|
|
words = [ |
|
"cassette player", |
|
"chainsaw", |
|
"chainsaw", |
|
"church", |
|
"gas pump", |
|
"gas pump", |
|
"gas pump", |
|
"parachute", |
|
"parachute", |
|
"tench", |
|
] |
|
|
|
class_ids = dit_pipeline.get_label_ids(words) |
|
output = dit_pipeline(class_labels=class_ids, generator=generator, output_type="np") |
|
|
|
fake_images = output.images |
|
fake_images = torch.tensor(fake_images) |
|
fake_images = fake_images.permute(0, 3, 1, 2) |
|
print(fake_images.shape) |
|
# torch.Size([10, 3, 256, 256]) |
|
``` |
|
|
|
์ด์ [`torchmetrics`](https://torchmetrics.readthedocs.io/)๋ฅผ ์ฌ์ฉํ์ฌ FID๋ฅผ ๊ณ์ฐํ ์ ์์ต๋๋ค. |
|
|
|
```python |
|
from torchmetrics.image.fid import FrechetInceptionDistance |
|
|
|
fid = FrechetInceptionDistance(normalize=True) |
|
fid.update(real_images, real=True) |
|
fid.update(fake_images, real=False) |
|
|
|
print(f"FID: {float(fid.compute())}") |
|
# FID: 177.7147216796875 |
|
``` |
|
|
|
FID๋ ๋ฎ์์๋ก ์ข์ต๋๋ค. ์ฌ๋ฌ ๊ฐ์ง ์์๊ฐ FID์ ์ํฅ์ ์ค ์ ์์ต๋๋ค: |
|
|
|
- ์ด๋ฏธ์ง์ ์ (์ค์ ์ด๋ฏธ์ง์ ๊ฐ์ง ์ด๋ฏธ์ง ๋ชจ๋) |
|
- diffusion ๊ณผ์ ์์ ๋ฐ์ํ๋ ๋ฌด์์์ฑ |
|
- diffusion ๊ณผ์ ์์์ ์ถ๋ก ๋จ๊ณ ์ |
|
- diffusion ๊ณผ์ ์์ ์ฌ์ฉ๋๋ ์ค์ผ์ค๋ฌ |
|
|
|
๋ง์ง๋ง ๋ ๊ฐ์ง ์์์ ๋ํด์๋, ๋ค๋ฅธ ์๋์ ์ถ๋ก ๋จ๊ณ์์ ํ๊ฐ๋ฅผ ์คํํ๊ณ ํ๊ท ๊ฒฐ๊ณผ๋ฅผ ๋ณด๊ณ ํ๋ ๊ฒ์ ์ข์ ์ค์ฒ ๋ฐฉ๋ฒ์
๋๋ค |
|
|
|
<Tip warning={true}> |
|
|
|
FID ๊ฒฐ๊ณผ๋ ๋ง์ ์์์ ์์กดํ๊ธฐ ๋๋ฌธ์ ์ทจ์ฝํ ์ ์์ต๋๋ค: |
|
|
|
* ๊ณ์ฐ ์ค ์ฌ์ฉ๋๋ ํน์ Inception ๋ชจ๋ธ. |
|
* ๊ณ์ฐ์ ๊ตฌํ ์ ํ๋. |
|
* ์ด๋ฏธ์ง ํ์ (PNG ๋๋ JPG์์ ์์ํ๋ ๊ฒฝ์ฐ๊ฐ ๋ค๋ฆ
๋๋ค). |
|
|
|
์ด๋ฌํ ์ฌํญ์ ์ผ๋์ ๋๋ฉด, FID๋ ์ ์ฌํ ์คํ์ ๋น๊ตํ ๋ ๊ฐ์ฅ ์ ์ฉํ์ง๋ง, ์ ์๊ฐ FID ์ธก์ ์ฝ๋๋ฅผ ์ฃผ์ ๊น๊ฒ ๊ณต๊ฐํ์ง ์๋ ํ ๋
ผ๋ฌธ ๊ฒฐ๊ณผ๋ฅผ ์ฌํํ๊ธฐ๋ ์ด๋ ต์ต๋๋ค. |
|
|
|
์ด๋ฌํ ์ฌํญ์ KID ๋ฐ IS์ ๊ฐ์ ๋ค๋ฅธ ๊ด๋ จ ๋ฉํธ๋ฆญ์๋ ์ ์ฉ๋ฉ๋๋ค. |
|
|
|
</Tip> |
|
|
|
๋ง์ง๋ง ๋จ๊ณ๋ก, `fake_images`๋ฅผ ์๊ฐ์ ์ผ๋ก ๊ฒ์ฌํด ๋ด
์๋ค. |
|
|
|
<p align="center"> |
|
<img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/evaluation_diffusion_models/fake-images.png" alt="fake-images"><br> |
|
<em>Fake images.</em> |
|
</p> |