ํ์ดํ๋ผ์ธ, ๋ชจ๋ธ, ์ค์ผ์ค๋ฌ ๋ถ๋ฌ์ค๊ธฐ
๊ธฐ๋ณธ์ ์ผ๋ก diffusion ๋ชจ๋ธ์ ๋ค์ํ ์ปดํฌ๋ํธ๋ค(๋ชจ๋ธ, ํ ํฌ๋์ด์ , ์ค์ผ์ค๋ฌ) ๊ฐ์ ๋ณต์กํ ์ํธ์์ฉ์ ๊ธฐ๋ฐ์ผ๋ก ๋์ํฉ๋๋ค. ๋ํจ์ ์ค(Diffusers)๋ ์ด๋ฌํ diffusion ๋ชจ๋ธ์ ๋ณด๋ค ์ฝ๊ณ ๊ฐํธํ API๋ก ์ ๊ณตํ๋ ๊ฒ์ ๋ชฉํ๋ก ์ค๊ณ๋์์ต๋๋ค. [DiffusionPipeline
]์ diffusion ๋ชจ๋ธ์ด ๊ฐ๋ ๋ณต์ก์ฑ์ ํ๋์ ํ์ดํ๋ผ์ธ API๋ก ํตํฉํ๊ณ , ๋์์ ์ด๋ฅผ ๊ตฌ์ฑํ๋ ๊ฐ๊ฐ์ ์ปดํฌ๋ํธ๋ค์ ํ์คํฌ์ ๋ง์ถฐ ์ ์ฐํ๊ฒ ์ปค์คํฐ๋ง์ด์งํ ์ ์๋๋ก ์ง์ํ๊ณ ์์ต๋๋ค.
diffusion ๋ชจ๋ธ์ ํ๋ จ๊ณผ ์ถ๋ก ์ ํ์ํ ๋ชจ๋ ๊ฒ์ [DiffusionPipeline.from_pretrained
] ๋ฉ์๋๋ฅผ ํตํด ์ ๊ทผํ ์ ์์ต๋๋ค. (์ด ๋ง์ ์๋ฏธ๋ ๋ค์ ๋จ๋ฝ์์ ๋ณด๋ค ์์ธํ๊ฒ ๋ค๋ค๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.)
์ด ๋ฌธ์์์๋ ์ค๋ช ํ ๋ด์ฉ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
ํ๋ธ๋ฅผ ํตํด ํน์ ๋ก์ปฌ๋ก ํ์ดํ๋ผ์ธ์ ๋ถ๋ฌ์ค๋ ๋ฒ
ํ์ดํ๋ผ์ธ์ ๋ค๋ฅธ ์ปดํฌ๋ํธ๋ค์ ์ ์ฉํ๋ ๋ฒ
์ค๋ฆฌ์ง๋ ์ฒดํฌํฌ์ธํธ๊ฐ ์๋ variant๋ฅผ ๋ถ๋ฌ์ค๋ ๋ฒ (variant๋ ๊ธฐ๋ณธ์ผ๋ก ์ค์ ๋
fp32
๊ฐ ์๋ ๋ค๋ฅธ ๋ถ๋ ์์์ ํ์ (์:fp16
)์ ์ฌ์ฉํ๊ฑฐ๋ Non-EMA ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ๋ ์ฒดํฌํฌ์ธํธ๋ค์ ์๋ฏธํฉ๋๋ค.)๋ชจ๋ธ๊ณผ ์ค์ผ์ค๋ฌ๋ฅผ ๋ถ๋ฌ์ค๋ ๋ฒ
Diffusion ํ์ดํ๋ผ์ธ
๐ก [DiffusionPipeline
] ํด๋์ค๊ฐ ๋์ํ๋ ๋ฐฉ์์ ๋ณด๋ค ์์ธํ ๋ด์ฉ์ด ๊ถ๊ธํ๋ค๋ฉด, DiffusionPipeline explained ์น์
์ ํ์ธํด๋ณด์ธ์.
[DiffusionPipeline
] ํด๋์ค๋ diffusion ๋ชจ๋ธ์ ํ๋ธ๋ก๋ถํฐ ๋ถ๋ฌ์ค๋ ๊ฐ์ฅ ์ฌํํ๋ฉด์ ๋ณดํธ์ ์ธ ๋ฐฉ์์
๋๋ค. [DiffusionPipeline.from_pretrained
] ๋ฉ์๋๋ ์ ํฉํ ํ์ดํ๋ผ์ธ ํด๋์ค๋ฅผ ์๋์ผ๋ก ํ์งํ๊ณ , ํ์ํ ๊ตฌ์ฑ์์(configuration)์ ๊ฐ์ค์น(weight) ํ์ผ๋ค์ ๋ค์ด๋ก๋ํ๊ณ ์บ์ฑํ ๋ค์, ํด๋น ํ์ดํ๋ผ์ธ ์ธ์คํด์ค๋ฅผ ๋ฐํํฉ๋๋ค.
from diffusers import DiffusionPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
pipe = DiffusionPipeline.from_pretrained(repo_id)
๋ฌผ๋ก [DiffusionPipeline
] ํด๋์ค๋ฅผ ์ฌ์ฉํ์ง ์๊ณ , ๋ช
์์ ์ผ๋ก ์ง์ ํด๋น ํ์ดํ๋ผ์ธ ํด๋์ค๋ฅผ ๋ถ๋ฌ์ค๋ ๊ฒ๋ ๊ฐ๋ฅํฉ๋๋ค. ์๋ ์์ ์ฝ๋๋ ์ ์์์ ๋์ผํ ์ธ์คํด์ค๋ฅผ ๋ฐํํฉ๋๋ค.
from diffusers import StableDiffusionPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(repo_id)
CompVis/stable-diffusion-v1-4์ด๋ runwayml/stable-diffusion-v1-5 ๊ฐ์ ์ฒดํฌํฌ์ธํธ๋ค์ ๊ฒฝ์ฐ, ํ๋ ์ด์์ ๋ค์ํ ํ์คํฌ์ ํ์ฉ๋ ์ ์์ต๋๋ค. (์๋ฅผ ๋ค์ด ์์ ๋ ์ฒดํฌํฌ์ธํธ์ ๊ฒฝ์ฐ, text-to-image์ image-to-image์ ๋ชจ๋ ํ์ฉ๋ ์ ์์ต๋๋ค.) ๋ง์ฝ ์ด๋ฌํ ์ฒดํฌํฌ์ธํธ๋ค์ ๊ธฐ๋ณธ ์ค์ ํ์คํฌ๊ฐ ์๋ ๋ค๋ฅธ ํ์คํฌ์ ํ์ฉํ๊ณ ์ ํ๋ค๋ฉด, ํด๋น ํ์คํฌ์ ๋์๋๋ ํ์ดํ๋ผ์ธ(task-specific pipeline)์ ์ฌ์ฉํด์ผ ํฉ๋๋ค.
from diffusers import StableDiffusionImg2ImgPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(repo_id)
๋ก์ปฌ ํ์ดํ๋ผ์ธ
ํ์ดํ๋ผ์ธ์ ๋ก์ปฌ๋ก ๋ถ๋ฌ์ค๊ณ ์ ํ๋ค๋ฉด, git-lfs
๋ฅผ ์ฌ์ฉํ์ฌ ์ง์ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ก์ปฌ ๋์คํฌ์ ๋ค์ด๋ก๋ ๋ฐ์์ผ ํฉ๋๋ค. ์๋์ ๋ช
๋ น์ด๋ฅผ ์คํํ๋ฉด ./stable-diffusion-v1-5
๋ ์ด๋ฆ์ผ๋ก ํด๋๊ฐ ๋ก์ปฌ๋์คํฌ์ ์์ฑ๋ฉ๋๋ค.
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5
๊ทธ๋ฐ ๋ค์ ํด๋น ๋ก์ปฌ ๊ฒฝ๋ก๋ฅผ [~DiffusionPipeline.from_pretrained
] ๋ฉ์๋์ ์ ๋ฌํฉ๋๋ค.
from diffusers import DiffusionPipeline
repo_id = "./stable-diffusion-v1-5"
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id)
์์ ์์์ฝ๋์ฒ๋ผ ๋ง์ฝ repo_id
๊ฐ ๋ก์ปฌ ํจ์ค(local path)๋ผ๋ฉด, [~DiffusionPipeline.from_pretrained
] ๋ฉ์๋๋ ์ด๋ฅผ ์๋์ผ๋ก ๊ฐ์งํ์ฌ ํ๋ธ์์ ํ์ผ์ ๋ค์ด๋ก๋ํ์ง ์์ต๋๋ค. ๋ง์ฝ ๋ก์ปฌ ๋์คํฌ์ ์ ์ฅ๋ ํ์ดํ๋ผ์ธ ์ฒดํฌํฌ์ธํธ๊ฐ ์ต์ ๋ฒ์ ์ด ์๋ ๊ฒฝ์ฐ์๋, ์ต์ ๋ฒ์ ์ ๋ค์ด๋ก๋ํ์ง ์๊ณ ๊ธฐ์กด ๋ก์ปฌ ๋์คํฌ์ ์ ์ฅ๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฌ์ฉํ๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.
ํ์ดํ๋ผ์ธ ๋ด๋ถ์ ์ปดํฌ๋ํธ ๊ต์ฒดํ๊ธฐ
ํ์ดํ๋ผ์ธ ๋ด๋ถ์ ์ปดํฌ๋ํธ๋ค์ ํธํ ๊ฐ๋ฅํ ๋ค๋ฅธ ์ปดํฌ๋ํธ๋ก ๊ต์ฒด๋ ์ ์์ต๋๋ค. ์ด์ ๊ฐ์ ์ปดํฌ๋ํธ ๊ต์ฒด๊ฐ ์ค์ํ ์ด์ ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- ์ด๋ค ์ค์ผ์ค๋ฌ๋ฅผ ์ฌ์ฉํ ๊ฒ์ธ๊ฐ๋ ์์ฑ์๋์ ์์ฑํ์ง ๊ฐ์ ํธ๋ ์ด๋์คํ๋ฅผ ์ ์ํ๋ ์ค์ํ ์์์ ๋๋ค.
- diffusion ๋ชจ๋ธ ๋ด๋ถ์ ์ปดํฌ๋ํธ๋ค์ ์ผ๋ฐ์ ์ผ๋ก ๊ฐ๊ฐ ๋ ๋ฆฝ์ ์ผ๋ก ํ๋ จ๋๊ธฐ ๋๋ฌธ์, ๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๋ ์ปดํฌ๋ํธ๊ฐ ์๋ค๋ฉด ๊ทธ๊ฑธ๋ก ๊ต์ฒดํ๋ ์์ผ๋ก ์ฑ๋ฅ์ ํฅ์์ํฌ ์ ์์ต๋๋ค.
- ํ์ธ ํ๋ ๋จ๊ณ์์๋ ์ผ๋ฐ์ ์ผ๋ก UNet ํน์ ํ ์คํธ ์ธ์ฝ๋์ ๊ฐ์ ์ผ๋ถ ์ปดํฌ๋ํธ๋ค๋ง ํ๋ จํ๊ฒ ๋ฉ๋๋ค.
์ด๋ค ์ค์ผ์ค๋ฌ๋ค์ด ํธํ๊ฐ๋ฅํ์ง๋ compatibles
์์ฑ์ ํตํด ํ์ธํ ์ ์์ต๋๋ค.
from diffusers import DiffusionPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id)
stable_diffusion.scheduler.compatibles
์ด๋ฒ์๋ [SchedulerMixin.from_pretrained
] ๋ฉ์๋๋ฅผ ์ฌ์ฉํด์, ๊ธฐ์กด ๊ธฐ๋ณธ ์ค์ผ์ค๋ฌ์๋ [PNDMScheduler
]๋ฅผ ๋ณด๋ค ์ฐ์ํ ์ฑ๋ฅ์ [EulerDiscreteScheduler
]๋ก ๋ฐ๊ฟ๋ด
์๋ค. ์ค์ผ์ค๋ฌ๋ฅผ ๋ก๋ํ ๋๋ subfolder
์ธ์๋ฅผ ํตํด, ํด๋น ํ์ดํ๋ผ์ธ์ ๋ฆฌํฌ์งํ ๋ฆฌ์์ ์ค์ผ์ค๋ฌ์ ๊ดํ ํ์ํด๋๋ฅผ ๋ช
์ํด์ฃผ์ด์ผ ํฉ๋๋ค.
๊ทธ ๋ค์ ์๋กญ๊ฒ ์์ฑํ [EulerDiscreteScheduler
] ์ธ์คํด์ค๋ฅผ [DiffusionPipeline
]์ scheduler
์ธ์์ ์ ๋ฌํฉ๋๋ค.
from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler
repo_id = "runwayml/stable-diffusion-v1-5"
scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler)
์ธ์ดํํฐ ์ฒด์ปค
์คํ
์ด๋ธ diffusion๊ณผ ๊ฐ์ diffusion ๋ชจ๋ธ๋ค์ ์ ํดํ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ ์๋ ์์ต๋๋ค. ์ด๋ฅผ ์๋ฐฉํ๊ธฐ ์ํด ๋ํจ์ ์ค๋ ์์ฑ๋ ์ด๋ฏธ์ง์ ์ ํด์ฑ์ ํ๋จํ๋ ์ธ์ดํํฐ ์ฒด์ปค(safety checker) ๊ธฐ๋ฅ์ ์ง์ํ๊ณ ์์ต๋๋ค. ๋ง์ฝ ์ธ์ดํํฐ ์ฒด์ปค์ ์ฌ์ฉ์ ์ํ์ง ์๋๋ค๋ฉด, safety_checker
์ธ์์ None
์ ์ ๋ฌํด์ฃผ์๋ฉด ๋ฉ๋๋ค.
from diffusers import DiffusionPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, safety_checker=None)
์ปดํฌ๋ํธ ์ฌ์ฌ์ฉ
๋ณต์์ ํ์ดํ๋ผ์ธ์ ๋์ผํ ๋ชจ๋ธ์ด ๋ฐ๋ณต์ ์ผ๋ก ์ฌ์ฉํ๋ค๋ฉด, ๊ตณ์ด ํด๋น ๋ชจ๋ธ์ ๋์ผํ ๊ฐ์ค์น๋ฅผ ์ค๋ณต์ผ๋ก RAM์ ๋ถ๋ฌ์ฌ ํ์๋ ์์ ๊ฒ์
๋๋ค. [~DiffusionPipeline.components
] ์์ฑ์ ํตํด ํ์ดํ๋ผ์ธ ๋ด๋ถ์ ์ปดํฌ๋ํธ๋ค์ ์ฐธ์กฐํ ์ ์๋๋ฐ, ์ด๋ฒ ๋จ๋ฝ์์๋ ์ด๋ฅผ ํตํด ๋์ผํ ๋ชจ๋ธ ๊ฐ์ค์น๋ฅผ RAM์ ์ค๋ณต์ผ๋ก ๋ถ๋ฌ์ค๋ ๊ฒ์ ๋ฐฉ์งํ๋ ๋ฒ์ ๋ํด ์์๋ณด๊ฒ ์ต๋๋ค.
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
model_id = "runwayml/stable-diffusion-v1-5"
stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id)
components = stable_diffusion_txt2img.components
๊ทธ ๋ค์ ์ ์์ ์ฝ๋์์ ์ ์ธํ components
๋ณ์๋ฅผ ๋ค๋ฅธ ํ์ดํ๋ผ์ธ์ ์ ๋ฌํจ์ผ๋ก์จ, ๋ชจ๋ธ์ ๊ฐ์ค์น๋ฅผ ์ค๋ณต์ผ๋ก RAM์ ๋ก๋ฉํ์ง ์๊ณ , ๋์ผํ ์ปดํฌ๋ํธ๋ฅผ ์ฌ์ฌ์ฉํ ์ ์์ต๋๋ค.
stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(**components)
๋ฌผ๋ก ๊ฐ๊ฐ์ ์ปดํฌ๋ํธ๋ค์ ๋ฐ๋ก ๋ฐ๋ก ํ์ดํ๋ผ์ธ์ ์ ๋ฌํ ์๋ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด stable_diffusion_txt2img
ํ์ดํ๋ผ์ธ ์์ ์ปดํฌ๋ํธ๋ค ๊ฐ์ด๋ฐ์ ์ธ์ดํํฐ ์ฒด์ปค(safety_checker
)์ ํผ์ณ ์ต์คํธ๋ํฐ(feature_extractor
)๋ฅผ ์ ์ธํ ์ปดํฌ๋ํธ๋ค๋ง stable_diffusion_img2img
ํ์ดํ๋ผ์ธ์์ ์ฌ์ฌ์ฉํ๋ ๋ฐฉ์ ์ญ์ ๊ฐ๋ฅํฉ๋๋ค.
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
model_id = "runwayml/stable-diffusion-v1-5"
stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id)
stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(
vae=stable_diffusion_txt2img.vae,
text_encoder=stable_diffusion_txt2img.text_encoder,
tokenizer=stable_diffusion_txt2img.tokenizer,
unet=stable_diffusion_txt2img.unet,
scheduler=stable_diffusion_txt2img.scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
Checkpoint variants
Variant๋ ์ผ๋ฐ์ ์ผ๋ก ๋ค์๊ณผ ๊ฐ์ ์ฒดํฌํฌ์ธํธ๋ค์ ์๋ฏธํฉ๋๋ค.
torch.float16
๊ณผ ๊ฐ์ด ์ ๋ฐ๋๋ ๋ ๋ฎ์ง๋ง, ์ฉ๋ ์ญ์ ๋ ์์ ๋ถ๋์์์ ํ์ ์ ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ๋ ์ฒดํฌํฌ์ธํธ. (๋ค๋ง ์ด์ ๊ฐ์ variant์ ๊ฒฝ์ฐ, ์ถ๊ฐ์ ์ธ ํ๋ จ๊ณผ CPUํ๊ฒฝ์์์ ๊ตฌ๋์ด ๋ถ๊ฐ๋ฅํฉ๋๋ค.)- Non-EMA ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ๋ ์ฒดํฌํฌ์ธํธ. (Non-EMA ๊ฐ์ค์น์ ๊ฒฝ์ฐ, ํ์ธ ํ๋ ๋จ๊ณ์์ ์ฌ์ฉํ๋ ๊ฒ์ด ๊ถ์ฅ๋๋๋ฐ, ์ถ๋ก ๋จ๊ณ์์ ์ฌ์ฉํ์ง ์๋ ๊ฒ์ด ๊ถ์ฅ๋ฉ๋๋ค.)
๐ก ๋ชจ๋ธ ๊ตฌ์กฐ๋ ๋์ผํ์ง๋ง ์๋ก ๋ค๋ฅธ ํ์ต ํ๊ฒฝ์์ ์๋ก ๋ค๋ฅธ ๋ฐ์ดํฐ์
์ผ๋ก ํ์ต๋ ์ฒดํฌํฌ์ธํธ๋ค์ด ์์ ๊ฒฝ์ฐ, ํด๋น ์ฒดํฌํฌ์ธํธ๋ค์ variant ๋จ๊ณ๊ฐ ์๋ ๋ฆฌํฌ์งํ ๋ฆฌ ๋จ๊ณ์์ ๋ถ๋ฆฌ๋์ด ๊ด๋ฆฌ๋์ด์ผ ํฉ๋๋ค. (์ฆ, ํด๋น ์ฒดํฌํฌ์ธํธ๋ค์ ์๋ก ๋ค๋ฅธ ๋ฆฌํฌ์งํ ๋ฆฌ์์ ๋ฐ๋ก ๊ด๋ฆฌ๋์ด์ผ ํฉ๋๋ค. ์์: [stable-diffusion-v1-4
], [stable-diffusion-v1-5
]).
checkpoint type | weight name | argument for loading weights |
---|---|---|
original | diffusion_pytorch_model.bin | |
floating point | diffusion_pytorch_model.fp16.bin | variant , torch_dtype |
non-EMA | diffusion_pytorch_model.non_ema.bin | variant |
variant๋ฅผ ๋ก๋ํ ๋ 2๊ฐ์ ์ค์ํ argument๊ฐ ์์ต๋๋ค.
torch_dtype
์ ๋ถ๋ฌ์ฌ ์ฒดํฌํฌ์ธํธ์ ๋ถ๋์์์ ์ ์ ์ํฉ๋๋ค. ์๋ฅผ ๋ค์ดtorch_dtype=torch.float16
์ ๋ช ์ํจ์ผ๋ก์จ ๊ฐ์ค์น์ ๋ถ๋์์์ ํ์ ์fl16
์ผ๋ก ๋ณํํ ์ ์์ต๋๋ค. (๋ง์ฝ ๋ฐ๋ก ์ค์ ํ์ง ์์ ๊ฒฝ์ฐ, ๊ธฐ๋ณธ๊ฐ์ผ๋กfp32
ํ์ ์ ๊ฐ์ค์น๊ฐ ๋ก๋ฉ๋ฉ๋๋ค.) ๋ํvariant
์ธ์๋ฅผ ๋ช ์ํ์ง ์์ ์ฑ๋ก ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ถ๋ฌ์จ ๋ค์, ํด๋น ์ฒดํฌํฌ์ธํธ๋ฅผtorch_dtype=torch.float16
์ธ์๋ฅผ ํตํดfp16
ํ์ ์ผ๋ก ๋ณํํ๋ ๊ฒ ์ญ์ ๊ฐ๋ฅํฉ๋๋ค. ์ด ๊ฒฝ์ฐ ๊ธฐ๋ณธ์ผ๋ก ์ค์ ๋fp32
๊ฐ์ค์น๊ฐ ๋จผ์ ๋ค์ด๋ก๋๋๊ณ , ํด๋น ๊ฐ์ค์น๋ค์ ๋ถ๋ฌ์จ ๋ค์fp16
ํ์ ์ผ๋ก ๋ณํํ๊ฒ ๋ฉ๋๋ค.variant
์ธ์๋ ๋ฆฌํฌ์งํ ๋ฆฌ์์ ์ด๋ค variant๋ฅผ ๋ถ๋ฌ์ฌ ๊ฒ์ธ๊ฐ๋ฅผ ์ ์ํฉ๋๋ค. ๊ฐ๋ นdiffusers/stable-diffusion-variants
๋ฆฌํฌ์งํ ๋ฆฌ๋ก๋ถํฐnon_ema
์ฒดํฌํฌ์ธํธ๋ฅผ ๋ถ๋ฌ์ค๊ณ ์ ํ๋ค๋ฉด,variant="non_ema"
์ธ์๋ฅผ ์ ๋ฌํด์ผ ํฉ๋๋ค.
from diffusers import DiffusionPipeline
# load fp16 variant
stable_diffusion = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
)
# load non_ema variant
stable_diffusion = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", variant="non_ema")
๋ค๋ฅธ ๋ถ๋์์์ ํ์
์ ๊ฐ์ค์น ํน์ non-EMA ๊ฐ์ค์น๋ฅผ ์ฌ์ฉํ๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํ๊ธฐ ์ํด์๋, [DiffusionPipeline.save_pretrained
] ๋ฉ์๋๋ฅผ ์ฌ์ฉํด์ผ ํ๋ฉฐ, ์ด ๋ variant
์ธ์๋ฅผ ๋ช
์ํด์ค์ผ ํฉ๋๋ค. ์๋์ ์ฒดํฌํฌ์ธํธ์ ๋์ผํ ํด๋์ variant๋ฅผ ์ ์ฅํด์ผ ํ๋ฉฐ, ์ด๋ ๊ฒ ํ๋ฉด ๋์ผํ ํด๋์์ ์ค๋ฆฌ์ง๋ ์ฒดํฌํฌ์ธํธ๊ณผ variant๋ฅผ ๋ชจ๋ ๋ถ๋ฌ์ฌ ์ ์์ต๋๋ค.
from diffusers import DiffusionPipeline
# save as fp16 variant
stable_diffusion.save_pretrained("runwayml/stable-diffusion-v1-5", variant="fp16")
# save as non-ema variant
stable_diffusion.save_pretrained("runwayml/stable-diffusion-v1-5", variant="non_ema")
๋ง์ฝ variant๋ฅผ ๊ธฐ์กด ํด๋์ ์ ์ฅํ์ง ์์ ๊ฒฝ์ฐ, variant
์ธ์๋ฅผ ๋ฐ๋์ ๋ช
์ํด์ผ ํฉ๋๋ค. ๊ทธ๋ ๊ฒ ํ์ง ์์ ๊ฒฝ์ฐ ์๋์ ์ค๋ฆฌ์ง๋ ์ฒดํฌํฌ์ธํธ๋ฅผ ์ฐพ์ ์ ์๊ฒ ๋๊ธฐ ๋๋ฌธ์ ์๋ฌ๊ฐ ๋ฐ์ํฉ๋๋ค.
# ๐ this won't work
stable_diffusion = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5", torch_dtype=torch.float16)
# ๐ this works
stable_diffusion = DiffusionPipeline.from_pretrained(
"./stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
)
๋ชจ๋ธ ๋ถ๋ฌ์ค๊ธฐ
๋ชจ๋ธ๋ค์ [ModelMixin.from_pretrained
] ๋ฉ์๋๋ฅผ ํตํด ๋ถ๋ฌ์ฌ ์ ์์ต๋๋ค. ํด๋น ๋ฉ์๋๋ ์ต์ ๋ฒ์ ์ ๋ชจ๋ธ ๊ฐ์ค์น ํ์ผ๊ณผ ์ค์ ํ์ผ(configurations)์ ๋ค์ด๋ก๋ํ๊ณ ์บ์ฑํฉ๋๋ค. ๋ง์ฝ ์ด๋ฌํ ํ์ผ๋ค์ด ์ต์ ๋ฒ์ ์ผ๋ก ๋ก์ปฌ ์บ์์ ์ ์ฅ๋์ด ์๋ค๋ฉด, [ModelMixin.from_pretrained
]๋ ๊ตณ์ด ํด๋น ํ์ผ๋ค์ ๋ค์ ๋ค์ด๋ก๋ํ์ง ์์ผ๋ฉฐ, ๊ทธ์ ์บ์์ ์๋ ์ต์ ํ์ผ๋ค์ ์ฌ์ฌ์ฉํฉ๋๋ค.
๋ชจ๋ธ์ subfolder
์ธ์์ ๋ช
์๋ ํ์ ํด๋๋ก๋ถํฐ ๋ก๋๋ฉ๋๋ค. ์๋ฅผ ๋ค์ด runwayml/stable-diffusion-v1-5
์ UNet ๋ชจ๋ธ์ ๊ฐ์ค์น๋ unet
ํด๋์ ์ ์ฅ๋์ด ์์ต๋๋ค.
from diffusers import UNet2DConditionModel
repo_id = "runwayml/stable-diffusion-v1-5"
model = UNet2DConditionModel.from_pretrained(repo_id, subfolder="unet")
ํน์ ํด๋น ๋ชจ๋ธ์ ๋ฆฌํฌ์งํ ๋ฆฌ๋ก๋ถํฐ ๋ค์ด๋ ํธ๋ก ๊ฐ์ ธ์ค๋ ๊ฒ ์ญ์ ๊ฐ๋ฅํฉ๋๋ค.
from diffusers import UNet2DModel
repo_id = "google/ddpm-cifar10-32"
model = UNet2DModel.from_pretrained(repo_id)
๋ํ ์์ ๋ดค๋ variant
์ธ์๋ฅผ ๋ช
์ํจ์ผ๋ก์จ, Non-EMA๋ fp16
์ ๊ฐ์ค์น๋ฅผ ๊ฐ์ ธ์ค๋ ๊ฒ ์ญ์ ๊ฐ๋ฅํฉ๋๋ค.
from diffusers import UNet2DConditionModel
model = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet", variant="non-ema")
model.save_pretrained("./local-unet", variant="non-ema")
์ค์ผ์ค๋ฌ
์ค์ผ์ค๋ฌ๋ค์ [SchedulerMixin.from_pretrained
] ๋ฉ์๋๋ฅผ ํตํด ๋ถ๋ฌ์ฌ ์ ์์ต๋๋ค. ๋ชจ๋ธ๊ณผ ๋ฌ๋ฆฌ ์ค์ผ์ค๋ฌ๋ ๋ณ๋์ ๊ฐ์ค์น๋ฅผ ๊ฐ์ง ์์ผ๋ฉฐ, ๋ฐ๋ผ์ ๋น์ฐํ ๋ณ๋์ ํ์ต๊ณผ์ ์ ์๊ตฌํ์ง ์์ต๋๋ค. ์ด๋ฌํ ์ค์ผ์ค๋ฌ๋ค์ (ํด๋น ์ค์ผ์ค๋ฌ ํ์ํด๋์) configration ํ์ผ์ ํตํด ์ ์๋ฉ๋๋ค.
์ฌ๋ฌ๊ฐ์ ์ค์ผ์ค๋ฌ๋ฅผ ๋ถ๋ฌ์จ๋ค๊ณ ํด์ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์๋ชจํ๋ ๊ฒ์ ์๋๋ฉฐ, ๋ค์ํ ์ค์ผ์ค๋ฌ๋ค์ ๋์ผํ ์ค์ผ์ค๋ฌ configration์ ์ ์ฉํ๋ ๊ฒ ์ญ์ ๊ฐ๋ฅํฉ๋๋ค. ๋ค์ ์์ ์ฝ๋์์ ๋ถ๋ฌ์ค๋ ์ค์ผ์ค๋ฌ๋ค์ ๋ชจ๋ [StableDiffusionPipeline
]๊ณผ ํธํ๋๋๋ฐ, ์ด๋ ๊ณง ํด๋น ์ค์ผ์ค๋ฌ๋ค์ ๋์ผํ ์ค์ผ์ค๋ฌ configration ํ์ผ์ ์ ์ฉํ ์ ์์์ ์๋ฏธํฉ๋๋ค.
from diffusers import StableDiffusionPipeline
from diffusers import (
DDPMScheduler,
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
repo_id = "runwayml/stable-diffusion-v1-5"
ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler")
pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler")
lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")
# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler_anc`, `euler`
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)
DiffusionPipeline์ ๋ํด ์์๋ณด๊ธฐ
ํด๋์ค ๋ฉ์๋๋ก์ [DiffusionPipeline.from_pretrained
]์ 2๊ฐ์ง๋ฅผ ๋ด๋นํฉ๋๋ค.
- ์ฒซ์งธ๋ก,
from_pretrained
๋ฉ์๋๋ ์ต์ ๋ฒ์ ์ ํ์ดํ๋ผ์ธ์ ๋ค์ด๋ก๋ํ๊ณ , ์บ์์ ์ ์ฅํฉ๋๋ค. ์ด๋ฏธ ๋ก์ปฌ ์บ์์ ์ต์ ๋ฒ์ ์ ํ์ดํ๋ผ์ธ์ด ์ ์ฅ๋์ด ์๋ค๋ฉด, [DiffusionPipeline.from_pretrained
]์ ํด๋น ํ์ผ๋ค์ ๋ค์ ๋ค์ด๋ก๋ํ์ง ์๊ณ , ๋ก์ปฌ ์บ์์ ์ ์ฅ๋์ด ์๋ ํ์ดํ๋ผ์ธ์ ๋ถ๋ฌ์ต๋๋ค. model_index.json
ํ์ผ์ ํตํด ์ฒดํฌํฌ์ธํธ์ ๋์๋๋ ์ ํฉํ ํ์ดํ๋ผ์ธ ํด๋์ค๋ก ๋ถ๋ฌ์ต๋๋ค.
ํ์ดํ๋ผ์ธ์ ํด๋ ๊ตฌ์กฐ๋ ํด๋น ํ์ดํ๋ผ์ธ ํด๋์ค์ ๊ตฌ์กฐ์ ์ง์ ์ ์ผ๋ก ์ผ์นํฉ๋๋ค. ์๋ฅผ ๋ค์ด [StableDiffusionPipeline
] ํด๋์ค๋ runwayml/stable-diffusion-v1-5
๋ฆฌํฌ์งํ ๋ฆฌ์ ๋์๋๋ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ต๋๋ค.
from diffusers import DiffusionPipeline
repo_id = "runwayml/stable-diffusion-v1-5"
pipeline = DiffusionPipeline.from_pretrained(repo_id)
print(pipeline)
์์ ์ฝ๋ ์ถ๋ ฅ ๊ฒฐ๊ณผ๋ฅผ ํ์ธํด๋ณด๋ฉด, pipeline
์ [StableDiffusionPipeline
]์ ์ธ์คํด์ค์ด๋ฉฐ, ๋ค์๊ณผ ๊ฐ์ด ์ด 7๊ฐ์ ์ปดํฌ๋ํธ๋ก ๊ตฌ์ฑ๋๋ค๋ ๊ฒ์ ์ ์ ์์ต๋๋ค.
"feature_extractor"
: [~transformers.CLIPFeatureExtractor
]์ ์ธ์คํด์ค"safety_checker"
: ์ ํดํ ์ปจํ ์ธ ๋ฅผ ์คํฌ๋ฆฌ๋ํ๊ธฐ ์ํ ์ปดํฌ๋ํธ"scheduler"
: [PNDMScheduler
]์ ์ธ์คํด์ค"text_encoder"
: [~transformers.CLIPTextModel
]์ ์ธ์คํด์ค"tokenizer"
: a [~transformers.CLIPTokenizer
]์ ์ธ์คํด์ค"unet"
: [UNet2DConditionModel
]์ ์ธ์คํด์ค"vae"
[AutoencoderKL
]์ ์ธ์คํด์ค
StableDiffusionPipeline {
"feature_extractor": [
"transformers",
"CLIPImageProcessor"
],
"safety_checker": [
"stable_diffusion",
"StableDiffusionSafetyChecker"
],
"scheduler": [
"diffusers",
"PNDMScheduler"
],
"text_encoder": [
"transformers",
"CLIPTextModel"
],
"tokenizer": [
"transformers",
"CLIPTokenizer"
],
"unet": [
"diffusers",
"UNet2DConditionModel"
],
"vae": [
"diffusers",
"AutoencoderKL"
]
}
ํ์ดํ๋ผ์ธ ์ธ์คํด์ค์ ์ปดํฌ๋ํธ๋ค์ runwayml/stable-diffusion-v1-5
์ ํด๋ ๊ตฌ์กฐ์ ๋น๊ตํด๋ณผ ๊ฒฝ์ฐ, ๊ฐ๊ฐ์ ์ปดํฌ๋ํธ๋ง๋ค ๋ณ๋์ ํด๋๊ฐ ์์์ ํ์ธํ ์ ์์ต๋๋ค.
.
โโโ feature_extractor
โ โโโ preprocessor_config.json
โโโ model_index.json
โโโ safety_checker
โ โโโ config.json
โ โโโ pytorch_model.bin
โโโ scheduler
โ โโโ scheduler_config.json
โโโ text_encoder
โ โโโ config.json
โ โโโ pytorch_model.bin
โโโ tokenizer
โ โโโ merges.txt
โ โโโ special_tokens_map.json
โ โโโ tokenizer_config.json
โ โโโ vocab.json
โโโ unet
โ โโโ config.json
โ โโโ diffusion_pytorch_model.bin
โโโ vae
โโโ config.json
โโโ diffusion_pytorch_model.bin
๋ํ ๊ฐ๊ฐ์ ์ปดํฌ๋ํธ๋ค์ ํ์ดํ๋ผ์ธ ์ธ์คํด์ค์ ์์ฑ์ผ๋ก์จ ์ฐธ์กฐํ ์ ์์ต๋๋ค.
pipeline.tokenizer
CLIPTokenizer(
name_or_path="/root/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/39593d5650112b4cc580433f6b0435385882d819/tokenizer",
vocab_size=49408,
model_max_length=77,
is_fast=False,
padding_side="right",
truncation_side="right",
special_tokens={
"bos_token": AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
"eos_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
"unk_token": AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True),
"pad_token": "<|endoftext|>",
},
)
๋ชจ๋ ํ์ดํ๋ผ์ธ์ model_index.json
ํ์ผ์ ํตํด [DiffusionPipeline
]์ ๋ค์๊ณผ ๊ฐ์ ์ ๋ณด๋ฅผ ์ ๋ฌํฉ๋๋ค.
_class_name
๋ ์ด๋ค ํ์ดํ๋ผ์ธ ํด๋์ค๋ฅผ ์ฌ์ฉํด์ผ ํ๋์ง์ ๋ํด ์๋ ค์ค๋๋ค._diffusers_version
๋ ์ด๋ค ๋ฒ์ ์ ๋ํจ์ ์ค๋ก ํ์ดํ๋ผ์ธ ์์ ๋ชจ๋ธ๋ค์ด ๋ง๋ค์ด์ก๋์ง๋ฅผ ์๋ ค์ค๋๋ค.- ๊ทธ ๋ค์์ ๊ฐ๊ฐ์ ์ปดํฌ๋ํธ๋ค์ด ์ด๋ค ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ์ด๋ค ํด๋์ค๋ก ๋ง๋ค์ด์ก๋์ง์ ๋ํด ์๋ ค์ค๋๋ค. (์๋ ์์์์
"feature_extractor" : ["transformers", "CLIPImageProcessor"]
์ ๊ฒฝ์ฐ,feature_extractor
์ปดํฌ๋ํธ๋transformers
๋ผ์ด๋ธ๋ฌ๋ฆฌ์CLIPImageProcessor
ํด๋์ค๋ฅผ ํตํด ๋ง๋ค์ด์ก๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.)
{
"_class_name": "StableDiffusionPipeline",
"_diffusers_version": "0.6.0",
"feature_extractor": [
"transformers",
"CLIPImageProcessor"
],
"safety_checker": [
"stable_diffusion",
"StableDiffusionSafetyChecker"
],
"scheduler": [
"diffusers",
"PNDMScheduler"
],
"text_encoder": [
"transformers",
"CLIPTextModel"
],
"tokenizer": [
"transformers",
"CLIPTokenizer"
],
"unet": [
"diffusers",
"UNet2DConditionModel"
],
"vae": [
"diffusers",
"AutoencoderKL"
]
}