# Distributed inference with multiple GPUs

On distributed setups, you can run inference across multiple GPUs with 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) or [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html), which is useful for generating multiple prompts in parallel.

This guide shows you how to use 🤗 Accelerate and PyTorch Distributed for distributed inference.
## 🤗 Accelerate
🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) is a library designed to make it easy to train or run inference across distributed setups. It simplifies the process of setting up the distributed environment, allowing you to focus on your PyTorch code.

To start, create a Python file and initialize an [`accelerate.PartialState`] to create a distributed environment; your setup is automatically detected, so you don't need to explicitly define the `rank` or `world_size`. Move the [`DiffusionPipeline`] to `distributed_state.device` to assign a GPU to each process.

Now use the [`~accelerate.PartialState.split_between_processes`] utility as a context manager to automatically distribute the prompts between the number of processes.
```py
import torch
from accelerate import PartialState
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
distributed_state = PartialState()
pipeline.to(distributed_state.device)

# Each process receives one prompt and saves its own image.
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
    result = pipeline(prompt).images[0]
    result.save(f"result_{distributed_state.process_index}.png")
```
Use the `--num_processes` argument to specify the number of GPUs to use, and call `accelerate launch` to run the script:
```bash
accelerate launch --num_processes=2 run_distributed.py
```
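The prompt list does not have to match the number of processes. As a minimal sketch of the same pattern (the four prompts and output filenames below are illustrative and not part of the original example), `split_between_processes` hands each process a sublist that you can simply loop over:

```py
import torch
from accelerate import PartialState
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
distributed_state = PartialState()
pipeline.to(distributed_state.device)

# With 2 processes, each one receives 2 of the 4 prompts below.
prompts = ["a dog", "a cat", "a bird", "a fish"]
with distributed_state.split_between_processes(prompts) as subset:
    for i, prompt in enumerate(subset):
        image = pipeline(prompt).images[0]
        image.save(f"result_{distributed_state.process_index}_{i}.png")
```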
<Tip>

To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.

</Tip>
## PyTorch Distributed
PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html), which enables data parallelism.

To start, create a Python file and import `torch.distributed` and `torch.multiprocessing` to set up the distributed process group and to spawn a process for inference on each GPU. You should also initialize a [`DiffusionPipeline`]. Later, the pipeline is moved to `rank` and `get_rank` assigns a GPU to each process, so that each process handles a different prompt:
```py
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from diffusers import DiffusionPipeline

# Load the pipeline once; each spawned process moves it to its own GPU later.
sd = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
```
You'll want to create a function to run inference. [`init_process_group`] handles creating a distributed environment with the type of backend to use, the `rank` of the current process, and the `world_size`, or the number of participating processes. If you're running inference in parallel over 2 GPUs, then the `world_size` is 2.
```py
def run_inference(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    sd.to(rank)

    # Each process generates an image for a different prompt.
    if torch.distributed.get_rank() == 0:
        prompt = "a dog"
    elif torch.distributed.get_rank() == 1:
        prompt = "a cat"

    image = sd(prompt).images[0]
    image.save(f"./{prompt.replace(' ', '_')}.png")
```
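As a variation on the function above (a sketch that reuses `sd`, `dist`, and `torch` from the earlier snippet; it is not part of the original guide), you can index a list of prompts by rank, which scales more naturally if you later add GPUs, and call `dist.destroy_process_group()` to clean up once the process is done:

```py
def run_inference(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    sd.to(rank)

    # Select one prompt per process by rank instead of branching explicitly.
    prompts = ["a dog", "a cat"]
    prompt = prompts[dist.get_rank()]

    image = sd(prompt).images[0]
    image.save(f"./{prompt.replace(' ', '_')}.png")

    # Tear down the process group once this process has finished.
    dist.destroy_process_group()
```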
To run the distributed inference, call [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to run the `run_inference` function on the number of GPUs defined in `world_size`:
```py
def main():
    world_size = 2
    mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()
```
Once you've completed the inference script, use the `--nproc_per_node` argument to specify the number of GPUs to use and call `torchrun` to run the script:
```bash
torchrun --nproc_per_node=2 run_distributed.py
```
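Since `main` already spawns one worker per GPU with `mp.spawn`, the script can also be launched as a plain Python program; in that case, set the `MASTER_ADDR` and `MASTER_PORT` environment variables that the default `env://` initialization of [`init_process_group`] reads (the values below are placeholders for a single-machine run):

```bash
MASTER_ADDR=localhost MASTER_PORT=29500 python run_distributed.py
```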