How to contribute a community pipeline

💡 See GitHub Issue #841 for more context on why community pipelines are being added: they let everyone share their work easily without being slowed down.

Community pipelines let you add any extra features you like on top of [DiffusionPipeline]. The main benefit of building on top of DiffusionPipeline is that anyone can load and use your pipeline by adding just one argument, which makes it very easy for the community to access.

This guide shows how to create a community pipeline and explains how it works. To keep things simple, you'll create a "one-step" pipeline in which the UNet performs a single forward pass and the scheduler is called once.

ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™”

Start by creating a one_step_unet.py file for your community pipeline. In this file, create a pipeline class that inherits from [DiffusionPipeline] so that it can load model weights and the scheduler configuration from the Hub. The one-step pipeline needs a UNet and a scheduler, so add both as arguments to the __init__ function:

from diffusers import DiffusionPipeline
import torch


class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
    def __init__(self, unet, scheduler):
        super().__init__()

ํŒŒ์ดํ”„๋ผ์ธ๊ณผ ๊ทธ ๊ตฌ์„ฑ์š”์†Œ(unet and scheduler)๋ฅผ [~DiffusionPipeline.save_pretrained]์œผ๋กœ ์ €์žฅํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜๋ ค๋ฉด register_modules ํ•จ์ˆ˜์— ์ถ”๊ฐ€ํ•˜์„ธ์š”:

  from diffusers import DiffusionPipeline
  import torch

  class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
      def __init__(self, unet, scheduler):
          super().__init__()

+         self.register_modules(unet=unet, scheduler=scheduler)
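
To make the role of registration more concrete, here is a minimal toy sketch, not the Diffusers implementation. It assumes a simplified base class (ToyPipeline, a hypothetical name) whose register_modules records each component's class name in a config dictionary and exposes the component as an attribute, so that a save routine has a list of components to iterate over. All class names in it are stand-ins invented for illustration.

```python
class ToyPipeline:
    """Toy stand-in for a pipeline base class (not the Diffusers implementation)."""

    def __init__(self):
        # Maps component name -> class name, loosely mimicking a saved pipeline index.
        self.config = {}

    def register_modules(self, **modules):
        for name, module in modules.items():
            # Record which class backs each component ...
            self.config[name] = type(module).__name__
            # ... and expose the component as an attribute (self.unet, self.scheduler).
            setattr(self, name, module)

    def components_to_save(self):
        # A save routine can only iterate over components that were registered.
        return {name: getattr(self, name) for name in self.config}


class ToyUNet:
    """Hypothetical placeholder for a UNet."""


class ToyScheduler:
    """Hypothetical placeholder for a scheduler."""


class ToyOneStepPipeline(ToyPipeline):
    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)


pipe = ToyOneStepPipeline(ToyUNet(), ToyScheduler())
print(pipe.config)  # {'unet': 'ToyUNet', 'scheduler': 'ToyScheduler'}
```

In this toy model, a component that is never passed to register_modules would be missing from config and therefore skipped when saving, which is the reason for registering both unet and scheduler.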

Now that the initialization step is done, you can move on to the forward pass! 🔥

Define the forward pass

In the forward pass (which we recommend defining as __call__), you have complete creative freedom to add whatever feature you like. For our amazing one-step pipeline, create a random image and call the unet and scheduler only once by setting timestep=1:

  from diffusers import DiffusionPipeline
  import torch


  class UnetSchedulerOneForwardPipeline(DiffusionPipeline):
      def __init__(self, unet, scheduler):
          super().__init__()

          self.register_modules(unet=unet, scheduler=scheduler)

+     def __call__(self):
+         image = torch.randn(
+             (1, self.unet.config.in_channels, self.unet.config.sample_size, self.unet.config.sample_size),
+         )
+         timestep = 1

+         model_output = self.unet(image, timestep).sample
+         scheduler_output = self.scheduler.step(model_output, timestep, image).prev_sample

+         return scheduler_output

That's it! 🚀 You can now run this pipeline by passing a unet and a scheduler to it:

from diffusers import DDPMScheduler, UNet2DModel

scheduler = DDPMScheduler()
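# sample_size defaults to None, which the torch.randn call in __call__ cannot use, so set it explicitly
unet = UNet2DModel(sample_size=32)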

pipeline = UnetSchedulerOneForwardPipeline(unet=unet, scheduler=scheduler)

output = pipeline()

Even better, you can load existing weights into the pipeline as long as the pipeline structure is identical. For example, you can load the google/ddpm-cifar10-32 weights into the one-step pipeline:

pipeline = UnetSchedulerOneForwardPipeline.from_pretrained("google/ddpm-cifar10-32")

output = pipeline()

ํŒŒ์ดํ”„๋ผ์ธ ๊ณต์œ 

Open a Pull Request on the 🧨 Diffusers repository to add your awesome pipeline in one_step_unet.py to the examples/community subfolder.

๋ณ‘ํ•ฉ์ด ๋˜๋ฉด, diffusers >= 0.4.0์ด ์„ค์น˜๋œ ์‚ฌ์šฉ์ž๋ผ๋ฉด ๋ˆ„๊ตฌ๋‚˜ custom_pipeline ์ธ์ˆ˜์— ์ง€์ •ํ•˜์—ฌ ์ด ํŒŒ์ดํ”„๋ผ์ธ์„ ๋งˆ์ˆ ์ฒ˜๋Ÿผ ๐Ÿช„ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค:

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="one_step_unet")
pipe()

Another way to share your community pipeline is to upload the one_step_unet.py file directly to your preferred model repository on the Hub. Instead of specifying the one_step_unet.py file, pass the model repository id to the custom_pipeline argument:

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cifar10-32", custom_pipeline="stevhliu/one_step_unet")

๋‹ค์Œ ํ‘œ์—์„œ ๋‘ ๊ฐ€์ง€ ๊ณต์œ  ์›Œํฌํ”Œ๋กœ์šฐ๋ฅผ ๋น„๊ตํ•˜์—ฌ ์ž์‹ ์—๊ฒŒ ๊ฐ€์žฅ ์ ํ•ฉํ•œ ์˜ต์…˜์„ ๊ฒฐ์ •ํ•˜๋Š” ๋ฐ ๋„์›€์ด ๋˜๋Š” ์ •๋ณด๋ฅผ ํ™•์ธํ•˜์„ธ์š”:

| | GitHub community pipeline | HF Hub community pipeline |
|---|---|---|
| Usage | Same | Same |
| Review process | Open a Pull Request on GitHub and go through the Diffusers team's review before merging; this can be slow. | Upload directly to a Hub repository without any review; this is the fastest workflow. |
| Visibility | Included in the official Diffusers repository and documentation. | Included on your HF Hub profile; visibility depends on your own usage and promotion. |

💡 You can use any package you want in your community pipeline file. As long as the user has it installed, everything works fine. The pipeline class is detected automatically, so make sure the file contains exactly one class that inherits from DiffusionPipeline.
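
The "exactly one class" requirement can be illustrated with a small standard-library sketch. This is an assumption about how such automatic detection might work in principle, not the library's actual code: BasePipeline stands in for DiffusionPipeline, find_single_pipeline_class is a hypothetical helper, and the in-memory module plays the role of one_step_unet.py.

```python
import inspect
import types


class BasePipeline:
    """Hypothetical stand-in for DiffusionPipeline."""


def find_single_pipeline_class(module, base=BasePipeline):
    # Collect every class in the module that subclasses the base (excluding the base itself).
    candidates = [
        obj
        for _, obj in inspect.getmembers(module, inspect.isclass)
        if issubclass(obj, base) and obj is not base
    ]
    # Detection is only unambiguous when exactly one candidate exists.
    if len(candidates) != 1:
        names = sorted(c.__name__ for c in candidates)
        raise ValueError(f"Expected exactly one pipeline class, found {names}")
    return candidates[0]


class UnetSchedulerOneForwardPipeline(BasePipeline):
    """Placeholder for the community pipeline class."""


# Build an in-memory module that plays the role of one_step_unet.py.
community_module = types.ModuleType("one_step_unet")
community_module.BasePipeline = BasePipeline
community_module.UnetSchedulerOneForwardPipeline = UnetSchedulerOneForwardPipeline

found = find_single_pipeline_class(community_module)
print(found.__name__)  # UnetSchedulerOneForwardPipeline
```

With a second subclass in the same module, this sketch raises a ValueError because it can no longer tell which class is meant to be the pipeline.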

How do community pipelines work?

A community pipeline is a class that inherits from [DiffusionPipeline], which means:

  • It can be loaded with the [custom_pipeline] argument.
  • The model weights and scheduler configuration are loaded from [pretrained_model_name_or_path].
  • The code that implements a feature in the community pipeline is defined in a pipeline.py file.

Sometimes you can't load all of the pipeline's component weights from an official repository. In that case, pass the other components directly to the pipeline:

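import torch
from diffusers import DiffusionPipeline, PNDMScheduler
from transformers import CLIPImageProcessor, CLIPModel

model_id = "CompVis/stable-diffusion-v1-4"
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"

# CLIPImageProcessor supersedes the deprecated CLIPFeatureExtractor
feature_extractor = CLIPImageProcessor.from_pretrained(clip_model_id)
clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)

# `scheduler` was used without being defined; loading the checkpoint's own
# scheduler configuration here is an assumption made so the snippet runs
scheduler = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")

pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    custom_pipeline="clip_guided_stable_diffusion",
    clip_model=clip_model,
    feature_extractor=feature_extractor,
    scheduler=scheduler,
    torch_dtype=torch.float16,
)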

The magic behind community pipelines is contained in the following code. It allows a community pipeline to be loaded from GitHub or the Hub, and makes it available to all 🧨 Diffusers packages.

# 2. Load the pipeline class; if a custom module is used, load it from the Hub
# if loading from an explicit class, use that class
if custom_pipeline is not None:
    pipeline_class = get_class_from_dynamic_module(
        custom_pipeline, module_file=CUSTOM_PIPELINE_FILE_NAME, cache_dir=custom_pipeline
    )
elif cls != DiffusionPipeline:
    pipeline_class = cls
else:
    diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
    pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
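
The three branches of the excerpt above can be exercised in isolation with stand-ins. The following is a sketch under stated assumptions rather than the library code: BasePipeline and OneStepPipeline are hypothetical placeholders, load_custom replaces get_class_from_dynamic_module with a plain dictionary lookup, and package_name points at the standard-library collections module so that the name-based lookup in the last branch can run without Diffusers installed.

```python
import importlib


class BasePipeline:
    """Hypothetical stand-in for DiffusionPipeline."""


class OneStepPipeline(BasePipeline):
    """Hypothetical user-defined pipeline subclass."""


def resolve_pipeline_class(cls, config_dict, custom_pipeline=None,
                           load_custom=None, package_name="collections"):
    if custom_pipeline is not None:
        # Branch 1: a custom pipeline was requested, so delegate to a loader.
        return load_custom(custom_pipeline)
    if cls is not BasePipeline:
        # Branch 2: the caller used an explicit subclass, so use it directly.
        return cls
    # Branch 3: called on the base class; look the class up by the name in the config.
    package = importlib.import_module(package_name)
    return getattr(package, config_dict["_class_name"])


# A dictionary stands in for the remote source of custom pipeline code.
fake_hub = {"one_step_unet": OneStepPipeline}

from_custom = resolve_pipeline_class(
    BasePipeline, {}, custom_pipeline="one_step_unet", load_custom=fake_hub.__getitem__
)
from_subclass = resolve_pipeline_class(OneStepPipeline, {})
from_config = resolve_pipeline_class(BasePipeline, {"_class_name": "OrderedDict"})

print(from_custom.__name__, from_subclass.__name__, from_config.__name__)
# OneStepPipeline OneStepPipeline OrderedDict
```

The order of the checks matters: custom_pipeline takes priority over an explicit subclass, and the config-based lookup is only the fallback.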