svjack's picture
Upload 1392 files
43b7e92 verified
|
raw
history blame
3.34 kB

Deterministic(๊ฒฐ์ •์ ) ์ƒ์„ฑ์„ ํ†ตํ•œ ์ด๋ฏธ์ง€ ํ’ˆ์งˆ ๊ฐœ์„ 

์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€์˜ ํ’ˆ์งˆ์„ ๊ฐœ์„ ํ•˜๋Š” ์ผ๋ฐ˜์ ์ธ ๋ฐฉ๋ฒ•์€ ๊ฒฐ์ •์  batch(๋ฐฐ์น˜) ์ƒ์„ฑ์„ ์‚ฌ์šฉํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ด ๋ฐฉ๋ฒ•์€ ์ด๋ฏธ์ง€ batch(๋ฐฐ์น˜)๋ฅผ ์ƒ์„ฑํ•˜๊ณ  ๋‘ ๋ฒˆ์งธ ์ถ”๋ก  ๋ผ์šด๋“œ์—์„œ ๋” ์ž์„ธํ•œ ํ”„๋กฌํ”„ํŠธ์™€ ํ•จ๊ป˜ ๊ฐœ์„ ํ•  ์ด๋ฏธ์ง€ ํ•˜๋‚˜๋ฅผ ์„ ํƒํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ํ•ต์‹ฌ์€ ์ผ๊ด„ ์ด๋ฏธ์ง€ ์ƒ์„ฑ์„ ์œ„ํ•ด ํŒŒ์ดํ”„๋ผ์ธ์— torch.Generator ๋ชฉ๋ก์„ ์ „๋‹ฌํ•˜๊ณ , ๊ฐ Generator๋ฅผ ์‹œ๋“œ์— ์—ฐ๊ฒฐํ•˜์—ฌ ์ด๋ฏธ์ง€์— ์žฌ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

์˜ˆ๋ฅผ ๋“ค์–ด runwayml/stable-diffusion-v1-5๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋‹ค์Œ ํ”„๋กฌํ”„ํŠธ์˜ ์—ฌ๋Ÿฌ ๋ฒ„์ „์„ ์ƒ์„ฑํ•ด ๋ด…์‹œ๋‹ค.

prompt = "Labrador in the style of Vermeer"

(๊ฐ€๋Šฅํ•˜๋‹ค๋ฉด) ํŒŒ์ดํ”„๋ผ์ธ์„ [DiffusionPipeline.from_pretrained]๋กœ ์ธ์Šคํ„ด์Šคํ™”ํ•˜์—ฌ GPU์— ๋ฐฐ์น˜ํ•ฉ๋‹ˆ๋‹ค.

>>> from diffusers import DiffusionPipeline

>>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")

์ด์ œ ๋„ค ๊ฐœ์˜ ์„œ๋กœ ๋‹ค๋ฅธ Generator๋ฅผ ์ •์˜ํ•˜๊ณ  ๊ฐ Generator์— ์‹œ๋“œ(0 ~ 3)๋ฅผ ํ• ๋‹นํ•˜์—ฌ ๋‚˜์ค‘์— ํŠน์ • ์ด๋ฏธ์ง€์— ๋Œ€ํ•ด Generator๋ฅผ ์žฌ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•ฉ๋‹ˆ๋‹ค.

>>> import torch

>>> generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]

์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๊ณ  ์‚ดํŽด๋ด…๋‹ˆ๋‹ค.

>>> images = pipe(prompt, generator=generator, num_images_per_prompt=4).images
>>> images

img

์ด ์˜ˆ์ œ์—์„œ๋Š” ์ฒซ ๋ฒˆ์งธ ์ด๋ฏธ์ง€๋ฅผ ๊ฐœ์„ ํ–ˆ์ง€๋งŒ ์‹ค์ œ๋กœ๋Š” ์›ํ•˜๋Š” ๋ชจ๋“  ์ด๋ฏธ์ง€๋ฅผ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค(์‹ฌ์ง€์–ด ๋‘ ๊ฐœ์˜ ๋ˆˆ์ด ์žˆ๋Š” ์ด๋ฏธ์ง€๋„!). ์ฒซ ๋ฒˆ์งธ ์ด๋ฏธ์ง€์—์„œ๋Š” ์‹œ๋“œ๊ฐ€ '0'์ธ '์ƒ์„ฑ๊ธฐ'๋ฅผ ์‚ฌ์šฉํ–ˆ๊ธฐ ๋•Œ๋ฌธ์— ๋‘ ๋ฒˆ์งธ ์ถ”๋ก  ๋ผ์šด๋“œ์—์„œ๋Š” ์ด '์ƒ์„ฑ๊ธฐ'๋ฅผ ์žฌ์‚ฌ์šฉํ•  ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ด๋ฏธ์ง€์˜ ํ’ˆ์งˆ์„ ๊ฐœ์„ ํ•˜๋ ค๋ฉด ํ”„๋กฌํ”„ํŠธ์— ๋ช‡ ๊ฐ€์ง€ ํ…์ŠคํŠธ๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค:

prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]

์‹œ๋“œ๊ฐ€ 0์ธ ์ œ๋„ˆ๋ ˆ์ดํ„ฐ 4๊ฐœ๋ฅผ ์ƒ์„ฑํ•˜๊ณ , ์ด์ „ ๋ผ์šด๋“œ์˜ ์ฒซ ๋ฒˆ์งธ ์ด๋ฏธ์ง€์ฒ˜๋Ÿผ ๋ณด์ด๋Š” ๋‹ค๋ฅธ ์ด๋ฏธ์ง€ batch(๋ฐฐ์น˜)๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค!

>>> images = pipe(prompt, generator=generator).images
>>> images

img