# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import functools
import ctypes

import torch
import torch.distributed as dist


def init_dist(local_rank, backend='nccl', **kwargs):
    r"""Initialize distributed training."""
    if dist.is_available():
        if dist.is_initialized():
            return torch.cuda.current_device()
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend=backend, init_method='env://', **kwargs)

    # Increase the L2 fetch granularity for faster speed.
    _libcudart = ctypes.CDLL('libcudart.so')
    # Set the device limit on the current device.
    # cudaLimitMaxL2FetchGranularity = 0x05
    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
    # assert pValue.contents.value == 128
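

# Hedged usage sketch (not part of the original module): when the job is
# launched with `torchrun --nproc_per_node=<N> script.py`, each process gets
# its local rank through the LOCAL_RANK environment variable and can
# initialize distributed training as below. `_example_setup` is an
# illustrative name, not an API defined by this file.
def _example_setup():
    import os
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    # Only initialize the process group when running under a launcher that
    # populates the env:// rendezvous variables.
    if torch.cuda.is_available() and 'WORLD_SIZE' in os.environ:
        init_dist(local_rank)
    return local_rank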


def get_rank():
    r"""Get rank of the thread."""
    rank = 0
    if dist.is_available():
        if dist.is_initialized():
            rank = dist.get_rank()
    return rank


def get_world_size():
    r"""Get the world size, i.e. the number of GPUs available in this job."""
    world_size = 1
    if dist.is_available():
        if dist.is_initialized():
            world_size = dist.get_world_size()
    return world_size


def master_only(func):
    r"""Apply this function only to the master GPU."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        r"""Simple function wrapper for the master function."""
        if get_rank() == 0:
            return func(*args, **kwargs)
        else:
            return None
    return wrapper


def is_master():
    r"""Check if the current process is the master process."""
    return get_rank() == 0


def is_local_master():
    r"""Check if the current process is the master process on its node."""
    return torch.cuda.current_device() == 0


@master_only
def master_only_print(*args):
    r"""Master-only print."""
    print(*args)
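

# Hedged usage sketch (illustrative, not in the original file): any function
# that should run on a single process, such as writing a checkpoint, can be
# wrapped with @master_only. `_example_save_checkpoint` and its arguments are
# hypothetical names.
@master_only
def _example_save_checkpoint(model, path):
    # Only rank 0 reaches this body; all other ranks receive None.
    torch.save(model.state_dict(), path)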


def dist_reduce_tensor(tensor, rank=0, reduce='mean'):
    r"""Reduce the tensor to the given rank (rank 0 by default)."""
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    with torch.no_grad():
        dist.reduce(tensor, dst=rank)
        if get_rank() == rank:
            if reduce == 'mean':
                tensor /= world_size
            elif reduce == 'sum':
                pass
            else:
                raise NotImplementedError
    return tensor


def dist_all_reduce_tensor(tensor, reduce='mean'):
    r"""Reduce the tensor across all ranks."""
    world_size = get_world_size()
    if world_size < 2:
        return tensor
    with torch.no_grad():
        dist.all_reduce(tensor)
        if reduce == 'mean':
            tensor /= world_size
        elif reduce == 'sum':
            pass
        else:
            raise NotImplementedError
    return tensor
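

# Hedged usage sketch (illustrative): average a scalar training loss across
# all ranks so that every process sees the same value; on a single-GPU run
# the helper is a no-op. `_example_log_loss` is a hypothetical name.
def _example_log_loss(loss):
    avg_loss = dist_all_reduce_tensor(loss.detach(), reduce='mean')
    master_only_print('average loss:', avg_loss.item())
    return avg_loss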


def dist_all_gather_tensor(tensor):
    r"""Gather the tensor from every rank to every rank."""
    world_size = get_world_size()
    if world_size < 2:
        return [tensor]
    tensor_list = [
        torch.ones_like(tensor) for _ in range(world_size)]
    with torch.no_grad():
        dist.all_gather(tensor_list, tensor)
    return tensor_list
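

# Hedged usage sketch (illustrative): gather a per-rank metric tensor onto
# every rank, e.g. to compute a global statistic. All gathered tensors must
# have the same shape. `_example_gather_metric` is a hypothetical name.
def _example_gather_metric(metric):
    gathered = dist_all_gather_tensor(metric)
    # Stack into a single tensor of shape (world_size, *metric.shape).
    return torch.stack(gathered, dim=0)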