sat3density/imaginaire/utils/distributed.py
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import functools
import ctypes

import torch
import torch.distributed as dist


def init_dist(local_rank, backend='nccl', **kwargs):
    r"""Initialize distributed training.

    Args:
        local_rank (int): GPU index for the current process.
        backend (str): torch.distributed backend, ``'nccl'`` by default.
    """
    if dist.is_available():
        if dist.is_initialized():
            return torch.cuda.current_device()
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend=backend, init_method='env://', **kwargs)

    # Increase the L2 fetch granularity for faster speed.
    _libcudart = ctypes.CDLL('libcudart.so')
    # Set device limit on the current device.
    # cudaLimitMaxL2FetchGranularity = 0x05
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    # assert pValue.contents.value == 128
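
# Usage note (illustrative, not part of the original module): with a launcher
# such as torchrun, every process receives its GPU index through the LOCAL_RANK
# environment variable, so a typical entry point would call, before building
# models or data loaders:
#
#     import os
#     init_dist(int(os.environ['LOCAL_RANK']))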


def get_rank():
    r"""Get the rank of the current process."""
    rank = 0
    if dist.is_available():
        if dist.is_initialized():
            rank = dist.get_rank()
    return rank


def get_world_size():
    r"""Get the world size, i.e. the number of processes (GPUs) in this job."""
    world_size = 1
    if dist.is_available():
        if dist.is_initialized():
            world_size = dist.get_world_size()
    return world_size


def master_only(func):
    r"""Decorator that runs the wrapped function only on the master (rank 0) process."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        r"""Run ``func`` on rank 0; return None on all other ranks."""
        if get_rank() == 0:
            return func(*args, **kwargs)
        else:
            return None
    return wrapper
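
# Usage note (illustrative): the decorator is meant for rank-0-only side effects
# such as logging or checkpointing, e.g. a hypothetical helper:
#
#     @master_only
#     def save_checkpoint(model, path):
#         torch.save(model.state_dict(), path)
#
# Non-master ranks simply receive None from the wrapped call.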


def is_master():
    r"""Check whether the current process is the master (rank 0)."""
    return get_rank() == 0


def is_local_master():
    r"""Check whether the current process is the master on its node (local GPU 0)."""
    return torch.cuda.current_device() == 0


@master_only
def master_only_print(*args):
    r"""Print only on the master process."""
    print(*args)


def dist_reduce_tensor(tensor, rank=0, reduce='mean'):
    r"""Reduce the tensor to the given destination rank (mean or sum)."""
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    with torch.no_grad():
        dist.reduce(tensor, dst=rank)
        if get_rank() == rank:
            if reduce == 'mean':
                tensor /= world_size
            elif reduce == 'sum':
                pass
            else:
                raise NotImplementedError
    return tensor


def dist_all_reduce_tensor(tensor, reduce='mean'):
    r"""All-reduce the tensor across all ranks (mean or sum)."""
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    with torch.no_grad():
        dist.all_reduce(tensor)
        if reduce == 'mean':
            tensor /= world_size
        elif reduce == 'sum':
            pass
        else:
            raise NotImplementedError
    return tensor
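
# Usage note (illustrative): a common pattern is to average a detached copy of a
# scalar loss across ranks before logging, e.g.
#
#     loss_avg = dist_all_reduce_tensor(loss.detach().clone())
#
# The clone matters because dist.all_reduce modifies the tensor in place; on a
# single process the function simply returns its input unchanged.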


def dist_all_gather_tensor(tensor):
    r"""Gather copies of the tensor from all ranks and return them as a list."""
    world_size = get_world_size()
    if world_size < 2:
        return [tensor]
    tensor_list = [
        torch.ones_like(tensor) for _ in range(world_size)]
    with torch.no_grad():
        dist.all_gather(tensor_list, tensor)
    return tensor_list
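

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module. It assumes the
    # script is launched with one process per GPU by a launcher that populates
    # the standard torch.distributed environment variables, e.g.
    #   torchrun --nproc_per_node=<num_gpus> distributed.py
    import os

    _local_rank = int(os.environ.get('LOCAL_RANK', 0))
    init_dist(_local_rank)
    # Each rank contributes its own rank id; the all-reduce averages them.
    _value = torch.ones(1, device='cuda') * get_rank()
    _mean_rank = dist_all_reduce_tensor(_value)
    master_only_print('world_size =', get_world_size(),
                      'mean rank =', _mean_rank.item())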