Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Helpers for distributed training. | |
""" | |
import datetime | |
import io | |
import os | |
import socket | |
import blobfile as bf | |
from pdb import set_trace as st | |
# from mpi4py import MPI | |
import torch as th | |
import torch.distributed as dist | |
# Cluster layout knob: set to the number of GPUs on each node. In
# multi-process runs dev() maps rank r to device cuda:(r % GPUS_PER_NODE).
GPUS_PER_NODE = 8
# Retry budget for distributed setup (not referenced by the code in this chunk).
SETUP_RETRY_COUNT = 3
def get_rank():
    """Return this process's rank in the default process group.

    Falls back to 0 when torch.distributed is unavailable in this build or
    no process group has been initialised yet.
    """
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if distributed_ready else 0
def synchronize():
    """Barrier across all ranks of the default process group.

    Does nothing when torch.distributed is unavailable, uninitialised, or
    the group contains a single process.
    """
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() != 1:
        dist.barrier()
def get_world_size():
    """Return the number of processes in the default process group.

    Reports 1 when torch.distributed is unavailable or not initialised.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
def setup_dist(args, backend=None, timeout=datetime.timedelta(seconds=54000)):
    """
    Setup a distributed process group.

    Initialises the default process group with ``init_method='env://'``, so
    the launcher is expected to provide MASTER_ADDR / MASTER_PORT (and RANK /
    WORLD_SIZE) in the environment. Returns immediately if a default group
    already exists.

    Args:
        args: namespace carrying a ``local_rank`` attribute; it is only used
            in the log line printed after initialisation.
        backend: torch.distributed backend name. ``None`` (default) selects
            ``'nccl'`` when CUDA is available and ``'gloo'`` otherwise, so
            CPU-only runs no longer fail inside NCCL initialisation.
        timeout: timeout applied to collectives; defaults to the previously
            hard-coded 54000 seconds (15 hours).
    """
    if dist.is_initialized():
        return
    if backend is None:
        backend = "nccl" if th.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, init_method="env://", timeout=timeout)
    print(f"{args.local_rank=} init complete")
    # NOTE(review): an earlier version called synchronize() here and saw
    # "extra memory on rank 0". A common cause is running an NCCL collective
    # before th.cuda.set_device(local_rank), which makes every rank create a
    # context on cuda:0 — confirm before re-enabling a barrier here.
    th.cuda.empty_cache()
def cleanup():
    """Destroy the default process group if one exists.

    Safe to call when torch.distributed is unavailable or was never
    initialised; previously ``destroy_process_group()`` raised in that case.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def dev():
    """
    Return the torch.device this process should use.

    CPU when CUDA is unavailable. With CUDA, multi-process runs get
    ``cuda:(rank % GPUS_PER_NODE)``; a single-process run gets the plain
    ``cuda`` (current default) device.
    """
    if not th.cuda.is_available():
        return th.device("cpu")
    if get_world_size() <= 1:
        return th.device("cuda")
    return th.device(f"cuda:{get_rank() % GPUS_PER_NODE}")
def load_state_dict(path, **kwargs):
    """
    Load a PyTorch checkpoint with ``torch.load``.

    Every rank reads the file itself: the earlier MPI scheme (rank 0 reads
    and broadcasts chunks to the others) is no longer used, so ``path`` must
    be readable from every process.

    Args:
        path: file path or file-like object accepted by ``torch.load``.
        **kwargs: forwarded unchanged to ``torch.load`` (e.g. ``map_location``).

    Returns:
        The object deserialised from ``path``.
    """
    # NOTE(review): torch.load unpickles arbitrary objects unless
    # weights_only=True is in effect; only load checkpoints from trusted sources.
    return th.load(path, **kwargs)
def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.

    Each tensor is overwritten in place with rank 0's copy via
    ``dist.broadcast``. When no process group is initialised there is nothing
    to synchronise, so the call returns immediately.

    A failing broadcast is reported (by the tensor's position in ``params``)
    and the loop moves on, keeping the original best-effort behaviour. The
    old handler printed an undefined name ``k`` and so raised NameError.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    # In-place writes into leaf tensors that require grad must not be tracked.
    with th.no_grad():
        for idx, p in enumerate(params):
            try:
                dist.broadcast(p, 0)
            except Exception as e:  # best-effort: report and continue
                print(f"sync_params: broadcast of param {idx} failed: {e}")
def _find_free_port():
    """
    Return a TCP port number that was unused at the time of the call.

    Binds a temporary IPv4 socket to port 0 so the OS picks a free port, then
    closes it. Another process may still claim the port before the caller
    uses it.
    """
    # The context manager closes the socket even if bind() fails, and avoids
    # the unbound-name error the old try/finally hit when socket() itself raised.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # SO_REUSEADDR only affects bind() when set before it.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        return s.getsockname()[1]
_num_moments = 3             # Moments kept per statistic: [count, sum, sum of squares].
_reduce_dtype = th.float32   # dtype used for the initial per-tensor reduction.
_counter_dtype = th.float64  # dtype used for the internal running counters.
_rank = 0                    # Rank of this process.
_sync_device = None          # Device for inter-process communication; None = single-process.
_sync_called = False         # Flipped once _sync() has run.
_counters = {}               # Running per-device counters updated by report(): name => device => torch.Tensor
_cumulative = {}             # Cumulative CPU counters updated by _sync(): name => torch.Tensor


def init_multiprocessing(rank, sync_device):
    r"""Enable multi-process collection for this module's training statistics.

    Call after ``torch.distributed.init_process_group()`` and before the
    statistics are first synchronised (the original notes say before
    ``Collector.update()``). Not needed when collection stays single-process.

    Args:
        rank: Rank of the current process.
        sync_device: Device used for inter-process communication, typically
            ``torch.device('cuda', rank)``; ``None`` disables multi-process
            collection.

    Raises:
        AssertionError: if ``_sync()`` has already run (assert-based, so the
            check disappears under ``python -O``).
    """
    global _rank, _sync_device
    assert not _sync_called
    _rank, _sync_device = rank, sync_device