# DI-engine: ding/utils/linklink_dist_helper.py
from functools import lru_cache
from typing import Callable, Tuple, List, Any
import numpy as np
import torch
from .default_helper import error_wrapper
from .fake_linklink import FakeLink
from .import_helper import try_import_link


@lru_cache()
def get_link():
    return try_import_link()


@lru_cache()
def is_fake_link():
    return isinstance(get_link(), FakeLink)


def get_rank() -> int:
    """
    Overview:
        Get the rank of the current ``linklink`` process; return 0 if ``FakeLink`` is used.
    .. note::
        Refer to ``import_helper.try_import_link`` and ``linklink.get_rank``.
    """
    if is_fake_link():
        return 0
    return error_wrapper(get_link().get_rank, 0, "[WARNING]: call linklink error, return default_ret.")()


def get_world_size() -> int:
    """
    Overview:
        Get the ``world_size`` of the ``linklink`` process group; return 1 if ``FakeLink`` is used.
    .. note::
        Refer to ``import_helper.try_import_link`` and ``linklink.get_world_size``.
    """
    if is_fake_link():
        return 1
    return error_wrapper(get_link().get_world_size, 1, "[WARNING]: call linklink error, return default_ret.")()
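

# A minimal usage sketch added for illustration (``_example_rank_guard`` is a hypothetical
# helper, not part of the original API): combine the two query helpers above to run code on
# the master process only. Under ``FakeLink`` this degrades to rank 0 in a world of size 1.
def _example_rank_guard() -> None:
    rank, world_size = get_rank(), get_world_size()
    if rank == 0:
        print('rank 0 of {} process(es) reporting'.format(world_size))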


def broadcast(value: torch.Tensor, rank: int) -> None:
    """
    Overview:
        Call ``linklink.broadcast``; raise ``NotImplementedError`` when ``FakeLink`` is used.
    Arguments:
        - value (:obj:`torch.Tensor`): The value to broadcast.
        - rank (:obj:`int`): The source rank to broadcast from.
    """
    if is_fake_link():
        raise NotImplementedError
    get_link().broadcast(value, rank)
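

# A minimal usage sketch added for illustration (``_example_broadcast_init`` is a hypothetical
# helper, not part of the original API): rank 0 draws the initial values and every other
# process receives a copy, so all ranks start from the same tensor.
def _example_broadcast_init(shape: Tuple[int, ...] = (4, )) -> torch.Tensor:
    value = torch.randn(shape) if get_rank() == 0 else torch.zeros(shape)
    broadcast(value, 0)
    return value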


def allreduce(data: torch.Tensor, op: str = 'sum') -> None:
    """
    Overview:
        Call ``linklink.allreduce`` on the data in place; for ``op='sum'`` the result is
        additionally divided by the world size, i.e. it becomes the cross-process mean.
    Arguments:
        - data (:obj:`torch.Tensor`): The data to reduce.
        - op (:obj:`str`): The reduce operation, supports ``['sum', 'max']``.
    """
    link_op_map = {'sum': get_link().allreduceOp_t.Sum, 'max': get_link().allreduceOp_t.Max}
    if op not in link_op_map.keys():
        raise KeyError("not support allreduce op type: {}".format(op))
    else:
        link_op = link_op_map[op]
    if is_fake_link():
        return data
    get_link().allreduce(data, reduce_op=link_op)
    if op == 'sum':
        data.div_(get_world_size())
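

# A minimal usage sketch added for illustration (``_example_allreduce_mean`` is a hypothetical
# helper, not part of the original API): average a scalar metric across processes. Because
# ``allreduce`` divides by the world size after summing, ``data`` ends up holding the mean.
def _example_allreduce_mean(local_value: float) -> float:
    data = torch.tensor([local_value], dtype=torch.float32)
    allreduce(data, op='sum')
    return data.item()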


def allreduce_async(data: torch.Tensor, op: str = 'sum') -> None:
    """
    Overview:
        Call ``linklink.allreduce_async`` on the data; for ``op='sum'`` the data is divided
        by the world size before the asynchronous reduce is issued.
    Arguments:
        - data (:obj:`torch.Tensor`): The data to reduce.
        - op (:obj:`str`): The reduce operation, supports ``['sum', 'max']``.
    """
    link_op_map = {'sum': get_link().allreduceOp_t.Sum, 'max': get_link().allreduceOp_t.Max}
    if op not in link_op_map.keys():
        raise KeyError("not support allreduce op type: {}".format(op))
    else:
        link_op = link_op_map[op]
    if is_fake_link():
        return data
    if op == 'sum':
        data.div_(get_world_size())
    get_link().allreduce_async(data, reduce_op=link_op)
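

# A minimal usage sketch added for illustration (``_example_allreduce_async_mean`` is a
# hypothetical helper, not part of the original API): the asynchronous variant returns
# immediately, so the caller waits via ``synchronize`` (defined below) before reading ``data``.
def _example_allreduce_async_mean(data: torch.Tensor) -> torch.Tensor:
    allreduce_async(data, op='sum')
    synchronize()
    return data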


def get_group(group_size: int) -> List:
    """
    Overview:
        Split all ranks into groups of ``group_size`` ranks each and return the group
        that the current rank belongs to.
    Arguments:
        - group_size (:obj:`int`): The number of ranks per group; ``None`` means a single
          group containing the whole world.
    """
    rank = get_rank()
    world_size = get_world_size()
    if group_size is None:
        group_size = world_size
    assert (world_size % group_size == 0)
    return simple_group_split(world_size, rank, world_size // group_size)
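

# A minimal usage sketch added for illustration (``_example_get_my_group`` is a hypothetical
# helper, not part of the original API): with 8 processes and ``group_size=4``, ranks 0-3
# and ranks 4-7 form two groups, and each process gets back the handle of its own group.
def _example_get_my_group(group_size: int = 4):
    return get_group(group_size)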


def dist_mode(func: Callable) -> Callable:
    """
    Overview:
        Wrap the function so that it automatically initializes the distributed environment
        before each call and finalizes it afterwards.
    Arguments:
        - func (:obj:`Callable`): The function to wrap.
    """

    def wrapper(*args, **kwargs):
        dist_init()
        func(*args, **kwargs)
        dist_finalize()

    return wrapper
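

# A minimal usage sketch added for illustration (``_example_main`` is a hypothetical entry
# point, not part of the original API): decorating a function with ``dist_mode`` makes every
# invocation run inside a ``dist_init``/``dist_finalize`` pair.
@dist_mode
def _example_main() -> None:
    print('running on rank {} of {}'.format(get_rank(), get_world_size()))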


def dist_init(method: str = 'slurm', device_id: int = 0) -> Tuple[int, int]:
    """
    Overview:
        Initialize the distributed environment and bind the current process to a GPU.
    Arguments:
        - method (:obj:`str`): The launch method, supports ``['slurm', 'single_node']``.
        - device_id (:obj:`int`): The default device when using the ``single_node`` method.
    Returns:
        - rank (:obj:`int`): The rank of the current process.
        - world_size (:obj:`int`): The total number of processes.
    """
    get_link().initialize()
    world_size = get_link().get_world_size()
    rank = get_link().get_rank()
    if method == 'slurm':
        # proc_id = int(os.environ['SLURM_PROCID'])
        # ntasks = int(os.environ['SLURM_NTASKS'])
        # node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
    elif method == 'single_node':
        torch.cuda.set_device(device_id)
    return rank, world_size
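

# A minimal usage sketch added for illustration (``_example_single_node_setup`` is a
# hypothetical helper, not part of the original API): manual setup on one machine, pinning
# this process to a chosen GPU before any CUDA tensors are created.
def _example_single_node_setup(device_id: int = 0) -> Tuple[int, int]:
    rank, world_size = dist_init(method='single_node', device_id=device_id)
    return rank, world_size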


def dist_finalize() -> None:
    """
    Overview:
        Finalize ``linklink``, see ``linklink.finalize()``.
    """
    get_link().finalize()


class DistContext:
    """
    Overview:
        A context manager for the ``linklink`` distributed environment.
    Interfaces:
        ``__init__``, ``__enter__``, ``__exit__``
    """

    def __init__(self) -> None:
        """
        Overview:
            Initialize the ``DistContext``.
        """
        pass

    def __enter__(self) -> None:
        """
        Overview:
            Initialize the ``linklink`` distributed environment.
        """
        dist_init()

    def __exit__(self, *args, **kwargs) -> Any:
        """
        Overview:
            Finalize the ``linklink`` distributed environment.
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__exit__`` function.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__exit__`` function.
        """
        dist_finalize()
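

# A minimal usage sketch added for illustration (``_example_with_dist_context`` is a
# hypothetical helper, not part of the original API): the context-manager form calls
# ``dist_init`` on entry and ``dist_finalize`` on exit, even if the body raises.
def _example_with_dist_context() -> None:
    with DistContext():
        data = torch.ones(1)
        allreduce(data)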


def simple_group_split(world_size: int, rank: int, num_groups: int) -> List:
    """
    Overview:
        Split the ranks into ``num_groups`` groups according to ``world_size`` and ``rank``,
        and return the group that the given rank belongs to.
    Arguments:
        - world_size (:obj:`int`): The world size.
        - rank (:obj:`int`): The rank.
        - num_groups (:obj:`int`): The number of groups.
    .. note::
        If ``world_size`` is not divisible by ``num_groups``, ``np.split`` raises
        ``array split does not result in an equal division``.
    """
    groups = []
    rank_list = np.split(np.arange(world_size), num_groups)
    rank_list = [list(map(int, x)) for x in rank_list]
    for i in range(num_groups):
        groups.append(get_link().new_group(rank_list[i]))
    group_size = world_size // num_groups
    return groups[rank // group_size]
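

# A minimal usage sketch added for illustration (``_example_rank_partition`` is a hypothetical
# helper, not part of the original API): it reproduces only the rank partition computed above,
# without the ``new_group`` calls, e.g. 8 ranks and 4 groups give [[0, 1], [2, 3], [4, 5], [6, 7]].
def _example_rank_partition(world_size: int = 8, num_groups: int = 4) -> List[List[int]]:
    return [list(map(int, x)) for x in np.split(np.arange(world_size), num_groups)]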


def synchronize():
    """
    Overview:
        Synchronize the process, see ``linklink.synchronize``.
    """
    get_link().synchronize()