diffusers-sdxl-controlnet / docs /source /ko /using-diffusers /stable_diffusion_jax_how_to.md
svjack's picture
Upload 1392 files
43b7e92 verified
|
raw
history blame
12.8 kB

JAX / Flaxμ—μ„œμ˜ 🧨 Stable Diffusion!

[[open-in-colab]]

πŸ€— Hugging Face [Diffusers] (https://github.com/huggingface/diffusers) λŠ” 버전 0.5.1λΆ€ν„° Flaxλ₯Ό μ§€μ›ν•©λ‹ˆλ‹€! 이λ₯Ό 톡해 Colab, Kaggle, Google Cloud Platformμ—μ„œ μ‚¬μš©ν•  수 μžˆλŠ” κ²ƒμ²˜λŸΌ Google TPUμ—μ„œ μ΄ˆκ³ μ† 좔둠이 κ°€λŠ₯ν•©λ‹ˆλ‹€.

이 λ…ΈνŠΈλΆμ€ JAX / Flaxλ₯Ό μ‚¬μš©ν•΄ 좔둠을 μ‹€ν–‰ν•˜λŠ” 방법을 λ³΄μ—¬μ€λ‹ˆλ‹€. Stable Diffusion의 μž‘λ™ 방식에 λŒ€ν•œ μžμ„Έν•œ λ‚΄μš©μ„ μ›ν•˜κ±°λ‚˜ GPUμ—μ„œ μ‹€ν–‰ν•˜λ €λ©΄ 이 [λ…ΈνŠΈλΆ] ](https://huggingface.co/docs/diffusers/stable_diffusion)을 μ°Έμ‘°ν•˜μ„Έμš”.

λ¨Όμ €, TPU λ°±μ—”λ“œλ₯Ό μ‚¬μš©ν•˜κ³  μžˆλŠ”μ§€ ν™•μΈν•©λ‹ˆλ‹€. Colabμ—μ„œ 이 λ…ΈνŠΈλΆμ„ μ‹€ν–‰ν•˜λŠ” 경우, λ©”λ‰΄μ—μ„œ λŸ°νƒ€μž„μ„ μ„ νƒν•œ λ‹€μŒ "λŸ°νƒ€μž„ μœ ν˜• λ³€κ²½" μ˜΅μ…˜μ„ μ„ νƒν•œ λ‹€μŒ ν•˜λ“œμ›¨μ–΄ 가속기 μ„€μ •μ—μ„œ TPUλ₯Ό μ„ νƒν•©λ‹ˆλ‹€.

JAXλŠ” TPU μ „μš©μ€ μ•„λ‹ˆμ§€λ§Œ 각 TPU μ„œλ²„μ—λŠ” 8개의 TPU 가속기가 λ³‘λ ¬λ‘œ μž‘λ™ν•˜κΈ° λ•Œλ¬Έμ— ν•΄λ‹Ή ν•˜λ“œμ›¨μ–΄μ—μ„œ 더 빛을 λ°œν•œλ‹€λŠ” 점은 μ•Œμ•„λ‘μ„Έμš”.

Setup

λ¨Όμ € diffusersκ°€ μ„€μΉ˜λ˜μ–΄ μžˆλŠ”μ§€ ν™•μΈν•©λ‹ˆλ‹€.

!pip install jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
!pip install diffusers
import jax.tools.colab_tpu

jax.tools.colab_tpu.setup_tpu()
import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
assert (
    "TPU" in device_type
), "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
Found 8 JAX devices of type Cloud TPU.

그런 λ‹€μŒ λͺ¨λ“  dependenciesλ₯Ό κ°€μ Έμ˜΅λ‹ˆλ‹€.

import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline

λͺ¨λΈ 뢈러였기

TPU μž₯μΉ˜λŠ” 효율적인 half-float μœ ν˜•μΈ bfloat16을 μ§€μ›ν•©λ‹ˆλ‹€. ν…ŒμŠ€νŠΈμ—λŠ” 이 μœ ν˜•μ„ μ‚¬μš©ν•˜μ§€λ§Œ λŒ€μ‹  float32λ₯Ό μ‚¬μš©ν•˜μ—¬ 전체 정밀도(full precision)λ₯Ό μ‚¬μš©ν•  μˆ˜λ„ μžˆμŠ΅λ‹ˆλ‹€.

dtype = jnp.bfloat16

FlaxλŠ” ν•¨μˆ˜ν˜• ν”„λ ˆμž„μ›Œν¬μ΄λ―€λ‘œ λͺ¨λΈμ€ λ¬΄μƒνƒœ(stateless)ν˜•μ΄λ©° λ§€κ°œλ³€μˆ˜λŠ” λͺ¨λΈ 외뢀에 μ €μž₯λ©λ‹ˆλ‹€. μ‚¬μ „ν•™μŠ΅λœ Flax νŒŒμ΄ν”„λΌμΈμ„ 뢈러였면 νŒŒμ΄ν”„λΌμΈ μžμ²΄μ™€ λͺ¨λΈ κ°€μ€‘μΉ˜(λ˜λŠ” λ§€κ°œλ³€μˆ˜)κ°€ λͺ¨λ‘ λ°˜ν™˜λ©λ‹ˆλ‹€. μ €ν¬λŠ” bf16 λ²„μ „μ˜ κ°€μ€‘μΉ˜λ₯Ό μ‚¬μš©ν•˜κ³  μžˆμœΌλ―€λ‘œ μœ ν˜• κ²½κ³ κ°€ ν‘œμ‹œλ˜μ§€λ§Œ λ¬΄μ‹œν•΄λ„ λ©λ‹ˆλ‹€.

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)

μΆ”λ‘ 

TPUμ—λŠ” 일반적으둜 8개의 λ””λ°”μ΄μŠ€κ°€ λ³‘λ ¬λ‘œ μž‘λ™ν•˜λ―€λ‘œ λ³΄μœ ν•œ λ””λ°”μ΄μŠ€ 수만큼 ν”„λ‘¬ν”„νŠΈλ₯Ό λ³΅μ œν•©λ‹ˆλ‹€. 그런 λ‹€μŒ 각각 ν•˜λ‚˜μ˜ 이미지 생성을 λ‹΄λ‹Ήν•˜λŠ” 8개의 λ””λ°”μ΄μŠ€μ—μ„œ ν•œ λ²ˆμ— 좔둠을 μˆ˜ν–‰ν•©λ‹ˆλ‹€. λ”°λΌμ„œ ν•˜λ‚˜μ˜ 칩이 ν•˜λ‚˜μ˜ 이미지λ₯Ό μƒμ„±ν•˜λŠ” 데 κ±Έλ¦¬λŠ” μ‹œκ°„κ³Ό λ™μΌν•œ μ‹œκ°„μ— 8개의 이미지λ₯Ό 얻을 수 μžˆμŠ΅λ‹ˆλ‹€.

ν”„λ‘¬ν”„νŠΈλ₯Ό λ³΅μ œν•˜κ³  λ‚˜λ©΄ νŒŒμ΄ν”„λΌμΈμ˜ prepare_inputs ν•¨μˆ˜λ₯Ό ν˜ΈμΆœν•˜μ—¬ ν† ν°ν™”λœ ν…μŠ€νŠΈ IDλ₯Ό μ–»μŠ΅λ‹ˆλ‹€. ν† ν°ν™”λœ ν…μŠ€νŠΈμ˜ κΈΈμ΄λŠ” κΈ°λ³Έ CLIP ν…μŠ€νŠΈ λͺ¨λΈμ˜ ꡬ성에 따라 77ν† ν°μœΌλ‘œ μ„€μ •λ©λ‹ˆλ‹€.

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
(8, 77)

볡사(Replication) 및 μ •λ ¬ν™”

λͺ¨λΈ λ§€κ°œλ³€μˆ˜μ™€ μž…λ ₯값은 μš°λ¦¬κ°€ λ³΄μœ ν•œ 8개의 병렬 μž₯μΉ˜μ— 볡사(Replication)λ˜μ–΄μ•Ό ν•©λ‹ˆλ‹€. λ§€κ°œλ³€μˆ˜ λ”•μ…”λ„ˆλ¦¬λŠ” flax.jax_utils.replicate(λ”•μ…”λ„ˆλ¦¬λ₯Ό μˆœνšŒν•˜λ©° κ°€μ€‘μΉ˜μ˜ λͺ¨μ–‘을 λ³€κ²½ν•˜μ—¬ 8번 λ°˜λ³΅ν•˜λŠ” ν•¨μˆ˜)λ₯Ό μ‚¬μš©ν•˜μ—¬ λ³΅μ‚¬λ©λ‹ˆλ‹€. 배열은 shardλ₯Ό μ‚¬μš©ν•˜μ—¬ λ³΅μ œλ©λ‹ˆλ‹€.

p_params = replicate(params)
prompt_ids = shard(prompt_ids)
prompt_ids.shape
(8, 1, 77)

이 shape은 8개의 λ””λ°”μ΄μŠ€ 각각이 shape (1, 77)의 jnp 배열을 μž…λ ₯κ°’μœΌλ‘œ λ°›λŠ”λ‹€λŠ” μ˜λ―Έμž…λ‹ˆλ‹€. 즉 1은 λ””λ°”μ΄μŠ€λ‹Ή batch(배치) ν¬κΈ°μž…λ‹ˆλ‹€. λ©”λͺ¨λ¦¬κ°€ μΆ©λΆ„ν•œ TPUμ—μ„œλŠ” ν•œ λ²ˆμ— μ—¬λŸ¬ 이미지(μΉ©λ‹Ή)λ₯Ό μƒμ„±ν•˜λ €λŠ” 경우 1보닀 클 수 μžˆμŠ΅λ‹ˆλ‹€.

이미지λ₯Ό 생성할 μ€€λΉ„κ°€ 거의 μ™„λ£Œλ˜μ—ˆμŠ΅λ‹ˆλ‹€! 이제 생성 ν•¨μˆ˜μ— 전달할 λ‚œμˆ˜ μƒμ„±κΈ°λ§Œ λ§Œλ“€λ©΄ λ©λ‹ˆλ‹€. 이것은 λ‚œμˆ˜λ₯Ό λ‹€λ£¨λŠ” λͺ¨λ“  ν•¨μˆ˜μ— λ‚œμˆ˜ 생성기가 μžˆμ–΄μ•Ό ν•œλ‹€λŠ”, λ‚œμˆ˜μ— λŒ€ν•΄ 맀우 μ§„μ§€ν•˜κ³  독단적인 Flax의 ν‘œμ€€ μ ˆμ°¨μž…λ‹ˆλ‹€. μ΄λ ‡κ²Œ ν•˜λ©΄ μ—¬λŸ¬ λΆ„μ‚°λœ κΈ°κΈ°μ—μ„œ ν›ˆλ ¨ν•  λ•Œμ—λ„ μž¬ν˜„μ„±μ΄ 보μž₯λ©λ‹ˆλ‹€.

μ•„λž˜ 헬퍼 ν•¨μˆ˜λŠ” μ‹œλ“œλ₯Ό μ‚¬μš©ν•˜μ—¬ λ‚œμˆ˜ 생성기λ₯Ό μ΄ˆκΈ°ν™”ν•©λ‹ˆλ‹€. λ™μΌν•œ μ‹œλ“œλ₯Ό μ‚¬μš©ν•˜λŠ” ν•œ μ •ν™•νžˆ λ™μΌν•œ κ²°κ³Όλ₯Ό 얻을 수 μžˆμŠ΅λ‹ˆλ‹€. λ‚˜μ€‘μ— λ…ΈνŠΈλΆμ—μ„œ κ²°κ³Όλ₯Ό 탐색할 λ•Œμ—” λ‹€λ₯Έ μ‹œλ“œλ₯Ό 자유둭게 μ‚¬μš©ν•˜μ„Έμš”.

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

rngλ₯Ό 얻은 λ‹€μŒ 8번 'λΆ„ν• 'ν•˜μ—¬ 각 λ””λ°”μ΄μŠ€κ°€ λ‹€λ₯Έ μ œλ„ˆλ ˆμ΄ν„°λ₯Ό μˆ˜μ‹ ν•˜λ„λ‘ ν•©λ‹ˆλ‹€. λ”°λΌμ„œ 각 λ””λ°”μ΄μŠ€λ§ˆλ‹€ λ‹€λ₯Έ 이미지가 μƒμ„±λ˜λ©° 전체 ν”„λ‘œμ„ΈμŠ€λ₯Ό μž¬ν˜„ν•  수 μžˆμŠ΅λ‹ˆλ‹€.

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

JAX μ½”λ“œλŠ” 맀우 λΉ λ₯΄κ²Œ μ‹€ν–‰λ˜λŠ” 효율적인 ν‘œν˜„μœΌλ‘œ μ»΄νŒŒμΌν•  수 μžˆμŠ΅λ‹ˆλ‹€. ν•˜μ§€λ§Œ 후속 ν˜ΈμΆœμ—μ„œ λͺ¨λ“  μž…λ ₯이 λ™μΌν•œ λͺ¨μ–‘을 갖도둝 ν•΄μ•Ό ν•˜λ©°, 그렇지 μ•ŠμœΌλ©΄ JAXκ°€ μ½”λ“œλ₯Ό λ‹€μ‹œ μ»΄νŒŒμΌν•΄μ•Ό ν•˜λ―€λ‘œ μ΅œμ ν™”λœ 속도λ₯Ό ν™œμš©ν•  수 μ—†μŠ΅λ‹ˆλ‹€.

jit = Trueλ₯Ό 인수둜 μ „λ‹¬ν•˜λ©΄ Flax νŒŒμ΄ν”„λΌμΈμ΄ μ½”λ“œλ₯Ό μ»΄νŒŒμΌν•  수 μžˆμŠ΅λ‹ˆλ‹€. λ˜ν•œ λͺ¨λΈμ΄ μ‚¬μš© κ°€λŠ₯ν•œ 8개의 λ””λ°”μ΄μŠ€μ—μ„œ λ³‘λ ¬λ‘œ μ‹€ν–‰λ˜λ„λ‘ 보μž₯ν•©λ‹ˆλ‹€.

λ‹€μŒ 셀을 처음 μ‹€ν–‰ν•˜λ©΄ μ»΄νŒŒμΌν•˜λŠ” 데 μ‹œκ°„μ΄ 였래 κ±Έλ¦¬μ§€λ§Œ 이후 호좜(μž…λ ₯이 λ‹€λ₯Έ κ²½μš°μ—λ„)은 훨씬 λΉ¨λΌμ§‘λ‹ˆλ‹€. 예λ₯Ό λ“€μ–΄, ν…ŒμŠ€νŠΈν–ˆμ„ λ•Œ TPU v2-8μ—μ„œ μ»΄νŒŒμΌν•˜λŠ” 데 1λΆ„ 이상 κ±Έλ¦¬μ§€λ§Œ 이후 μΆ”λ‘  μ‹€ν–‰μ—λŠ” μ•½ 7μ΄ˆκ°€ κ±Έλ¦½λ‹ˆλ‹€.

%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
Wall time: 1min 29s

λ°˜ν™˜λœ λ°°μ—΄μ˜ shape은 (8, 1, 512, 512, 3)μž…λ‹ˆλ‹€. 이λ₯Ό μž¬κ΅¬μ„±ν•˜μ—¬ 두 번째 차원을 μ œκ±°ν•˜κ³  512 Γ— 512 Γ— 3의 이미지 8개λ₯Ό 얻은 λ‹€μŒ PIL둜 λ³€ν™˜ν•©λ‹ˆλ‹€.

images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

μ‹œκ°ν™”

이미지λ₯Ό κ·Έλ¦¬λ“œμ— ν‘œμ‹œν•˜λŠ” λ„μš°λ―Έ ν•¨μˆ˜λ₯Ό λ§Œλ“€μ–΄ λ³΄κ² μŠ΅λ‹ˆλ‹€.

def image_grid(imgs, rows, cols):
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
image_grid(images, 2, 4)

img

λ‹€λ₯Έ ν”„λ‘¬ν”„νŠΈ μ‚¬μš©

λͺ¨λ“  λ””λ°”μ΄μŠ€μ—μ„œ λ™μΌν•œ ν”„λ‘¬ν”„νŠΈλ₯Ό λ³΅μ œν•  ν•„μš”λŠ” μ—†μŠ΅λ‹ˆλ‹€. ν”„λ‘¬ν”„νŠΈ 2개λ₯Ό 각각 4λ²ˆμ”© μƒμ„±ν•˜κ±°λ‚˜ ν•œ λ²ˆμ— 8개의 μ„œλ‘œ λ‹€λ₯Έ ν”„λ‘¬ν”„νŠΈλ₯Ό μƒμ„±ν•˜λŠ” λ“± μ›ν•˜λŠ” 것은 무엇이든 ν•  수 μžˆμŠ΅λ‹ˆλ‹€. ν•œλ²ˆ ν•΄λ³΄μ„Έμš”!

λ¨Όμ € μž…λ ₯ μ€€λΉ„ μ½”λ“œλ₯Ό νŽΈλ¦¬ν•œ ν•¨μˆ˜λ‘œ λ¦¬νŒ©ν„°λ§ν•˜κ² μŠ΅λ‹ˆλ‹€:

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

image_grid(images, 2, 4)

img

병렬화(parallelization)λŠ” μ–΄λ–»κ²Œ μž‘λ™ν•˜λŠ”κ°€?

μ•žμ„œ diffusers Flax νŒŒμ΄ν”„λΌμΈμ΄ λͺ¨λΈμ„ μžλ™μœΌλ‘œ μ»΄νŒŒμΌν•˜κ³  μ‚¬μš© κ°€λŠ₯ν•œ λͺ¨λ“  κΈ°κΈ°μ—μ„œ λ³‘λ ¬λ‘œ μ‹€ν–‰ν•œλ‹€κ³  λ§μ”€λ“œλ ΈμŠ΅λ‹ˆλ‹€. 이제 κ·Έ ν”„λ‘œμ„ΈμŠ€λ₯Ό κ°„λž΅ν•˜κ²Œ μ‚΄νŽ΄λ³΄κ³  μž‘λ™ 방식을 λ³΄μ—¬λ“œλ¦¬κ² μŠ΅λ‹ˆλ‹€.

JAX λ³‘λ ¬ν™”λŠ” μ—¬λŸ¬ 가지 λ°©λ²•μœΌλ‘œ μˆ˜ν–‰ν•  수 μžˆμŠ΅λ‹ˆλ‹€. κ°€μž₯ μ‰¬μš΄ 방법은 jax.pmap ν•¨μˆ˜λ₯Ό μ‚¬μš©ν•˜μ—¬ 단일 ν”„λ‘œκ·Έλž¨, 닀쀑 데이터(SPMD) 병렬화λ₯Ό λ‹¬μ„±ν•˜λŠ” κ²ƒμž…λ‹ˆλ‹€. 즉, λ™μΌν•œ μ½”λ“œμ˜ 볡사본을 각각 λ‹€λ₯Έ 데이터 μž…λ ₯에 λŒ€ν•΄ μ—¬λŸ¬ 개 μ‹€ν–‰ν•˜λŠ” κ²ƒμž…λ‹ˆλ‹€. 더 μ •κ΅ν•œ μ ‘κ·Ό 방식도 κ°€λŠ₯ν•˜λ―€λ‘œ 관심이 μžˆμœΌμ‹œλ‹€λ©΄ JAX λ¬Έμ„œμ™€ pjit νŽ˜μ΄μ§€μ—μ„œ 이 주제λ₯Ό μ‚΄νŽ΄λ³΄μ‹œκΈ° λ°”λžλ‹ˆλ‹€!

jax.pmap은 두 가지 κΈ°λŠ₯을 μˆ˜ν–‰ν•©λ‹ˆλ‹€:

  • jax.jit()λ₯Ό ν˜ΈμΆœν•œ κ²ƒμ²˜λŸΌ μ½”λ“œλ₯Ό 컴파일(λ˜λŠ” jit)ν•©λ‹ˆλ‹€. 이 μž‘μ—…μ€ pmap을 ν˜ΈμΆœν•  λ•Œκ°€ μ•„λ‹ˆλΌ pmapped ν•¨μˆ˜κ°€ 처음 호좜될 λ•Œ μˆ˜ν–‰λ©λ‹ˆλ‹€.
  • 컴파일된 μ½”λ“œκ°€ μ‚¬μš© κ°€λŠ₯ν•œ λͺ¨λ“  κΈ°κΈ°μ—μ„œ λ³‘λ ¬λ‘œ μ‹€ν–‰λ˜λ„λ‘ ν•©λ‹ˆλ‹€.

μž‘λ™ 방식을 λ³΄μ—¬λ“œλ¦¬κΈ° μœ„ν•΄ 이미지 생성을 μ‹€ν–‰ν•˜λŠ” λΉ„κ³΅κ°œ λ©”μ„œλ“œμΈ νŒŒμ΄ν”„λΌμΈμ˜ _generate λ©”μ„œλ“œλ₯Ό pmapν•©λ‹ˆλ‹€. 이 λ©”μ„œλ“œλŠ” ν–₯ν›„ Diffusers λ¦΄λ¦¬μŠ€μ—μ„œ 이름이 λ³€κ²½λ˜κ±°λ‚˜ 제거될 수 μžˆλ‹€λŠ” 점에 μœ μ˜ν•˜μ„Έμš”.

p_generate = pmap(pipeline._generate)

pmap을 μ‚¬μš©ν•œ ν›„ μ€€λΉ„λœ ν•¨μˆ˜ p_generateλŠ” κ°œλ…μ μœΌλ‘œ λ‹€μŒμ„ μˆ˜ν–‰ν•©λ‹ˆλ‹€:

  • 각 μž₯μΉ˜μ—μ„œ κΈ°λ³Έ ν•¨μˆ˜ pipeline._generate의 볡사본을 ν˜ΈμΆœν•©λ‹ˆλ‹€.
  • 각 μž₯μΉ˜μ— μž…λ ₯ 인수의 λ‹€λ₯Έ 뢀뢄을 λ³΄λƒ…λ‹ˆλ‹€. 이것이 λ°”λ‘œ 샀딩이 μ‚¬μš©λ˜λŠ” μ΄μœ μž…λ‹ˆλ‹€. 이 경우 prompt_ids의 shape은 (8, 1, 77, 768)μž…λ‹ˆλ‹€. 이 배열은 8개둜 λΆ„ν• λ˜κ³  _generate의 각 볡사본은 (1, 77, 768)의 shape을 가진 μž…λ ₯을 λ°›κ²Œ λ©λ‹ˆλ‹€.

λ³‘λ ¬λ‘œ ν˜ΈμΆœλœλ‹€λŠ” 사싀을 μ™„μ „νžˆ λ¬΄μ‹œν•˜κ³  _generateλ₯Ό μ½”λ”©ν•  수 μžˆμŠ΅λ‹ˆλ‹€. batch(배치) 크기(이 μ˜ˆμ œμ—μ„œλŠ” 1)와 μ½”λ“œμ— μ ν•©ν•œ μ°¨μ›λ§Œ μ‹ κ²½ μ“°λ©΄ 되며, λ³‘λ ¬λ‘œ μž‘λ™ν•˜κΈ° μœ„ν•΄ 아무것도 λ³€κ²½ν•  ν•„μš”κ°€ μ—†μŠ΅λ‹ˆλ‹€.

νŒŒμ΄ν”„λΌμΈ ν˜ΈμΆœμ„ μ‚¬μš©ν•  λ•Œμ™€ λ§ˆμ°¬κ°€μ§€λ‘œ, λ‹€μŒ 셀을 처음 μ‹€ν–‰ν•  λ•ŒλŠ” μ‹œκ°„μ΄ κ±Έλ¦¬μ§€λ§Œ κ·Έ μ΄ν›„μ—λŠ” 훨씬 λΉ¨λΌμ§‘λ‹ˆλ‹€.

%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
Wall time: 1min 15s
images.shape
(8, 1, 512, 512, 3)

JAXλŠ” 비동기 λ””μŠ€νŒ¨μΉ˜λ₯Ό μ‚¬μš©ν•˜κ³  κ°€λŠ₯ν•œ ν•œ 빨리 μ œμ–΄κΆŒμ„ Python 루프에 λ°˜ν™˜ν•˜κΈ° λ•Œλ¬Έμ— μΆ”λ‘  μ‹œκ°„μ„ μ •ν™•ν•˜κ²Œ μΈ‘μ •ν•˜κΈ° μœ„ν•΄ block_until_ready()λ₯Ό μ‚¬μš©ν•©λ‹ˆλ‹€. 아직 κ΅¬μ²΄ν™”λ˜μ§€ μ•Šμ€ 계산 κ²°κ³Όλ₯Ό μ‚¬μš©ν•˜λ €λŠ” 경우 μžλ™μœΌλ‘œ 차단이 μˆ˜ν–‰λ˜λ―€λ‘œ μ½”λ“œμ—μ„œ 이 ν•¨μˆ˜λ₯Ό μ‚¬μš©ν•  ν•„μš”κ°€ μ—†μŠ΅λ‹ˆλ‹€.