LN3Diff / guided_diffusion /dist_util.py
NIRVANALAN
release file
87c126b
"""
Helpers for distributed training.
"""
import datetime
import io
import os
import socket
import blobfile as bf
from pdb import set_trace as st
# from mpi4py import MPI
import torch as th
import torch.distributed as dist
# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8
SETUP_RETRY_COUNT = 3
def get_rank():
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def synchronize():
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
dist.barrier()
def get_world_size():
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
def setup_dist(args):
"""
Setup a distributed process group.
"""
if dist.is_initialized():
return
# print(f"{os.environ['MASTER_ADDR']=} {args.master_port=}")
# dist.init_process_group(backend='nccl', init_method='env://', rank=args.local_rank, world_size=th.cuda.device_count(), timeout=datetime.timedelta(seconds=5400))
# st() no mark
dist.init_process_group(backend='nccl', init_method='env://', timeout=datetime.timedelta(seconds=54000))
print(f"{args.local_rank=} init complete")
# synchronize() # extra memory on rank 0, why?
th.cuda.empty_cache()
def cleanup():
dist.destroy_process_group()
def dev():
"""
Get the device to use for torch.distributed.
"""
if th.cuda.is_available():
if get_world_size() > 1:
return th.device(f"cuda:{get_rank() % GPUS_PER_NODE}")
return th.device(f"cuda")
return th.device("cpu")
# def load_state_dict(path, submodule_name='', **kwargs):
def load_state_dict(path, **kwargs):
"""
Load a PyTorch file without redundant fetches across MPI ranks.
"""
# chunk_size = 2 ** 30 # MPI has a relatively small size limit
# if get_rank() == 0:
# with bf.BlobFile(path, "rb") as f:
# data = f.read()
# num_chunks = len(data) // chunk_size
# if len(data) % chunk_size:
# num_chunks += 1
# MPI.COMM_WORLD.bcast(num_chunks)
# for i in range(0, len(data), chunk_size):
# MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
# else:
# num_chunks = MPI.COMM_WORLD.bcast(None)
# data = bytes()
# for _ in range(num_chunks):
# data += MPI.COMM_WORLD.bcast(None)
# return th.load(io.BytesIO(data), **kwargs)
# with open(path) as f:
ckpt = th.load(path, **kwargs)
# if submodule_name != '':
# assert submodule_name in ckpt
# return ckpt[submodule_name]
# else:
return ckpt
def sync_params(params):
"""
Synchronize a sequence of Tensors across ranks from rank 0.
"""
# for k, p in params:
for p in params:
with th.no_grad():
try:
dist.broadcast(p, 0)
except Exception as e:
print(k, e)
# print(e)
def _find_free_port():
try:
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
finally:
s.close()
_num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares]
_reduce_dtype = th.float32 # Data type to use for initial per-tensor reduction.
_counter_dtype = th.float64 # Data type to use for the internal counters.
_rank = 0 # Rank of the current process.
_sync_device = None # Device to use for multiprocess communication. None = single-process.
_sync_called = False # Has _sync() been called yet?
_counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
def init_multiprocessing(rank, sync_device):
r"""Initializes `utils.torch_utils.training_stats` for collecting statistics
across multiple processes.
This function must be called after
`torch.distributed.init_process_group()` and before `Collector.update()`.
The call is not necessary if multi-process collection is not needed.
Args:
rank: Rank of the current process.
sync_device: PyTorch device to use for inter-process
communication, or None to disable multi-process
collection. Typically `torch.device('cuda', rank)`.
"""
global _rank, _sync_device
assert not _sync_called
_rank = rank
_sync_device = sync_device