from functools import lru_cache
from typing import Callable, Tuple, List, Any

import numpy as np
import torch

from .default_helper import error_wrapper
from .fake_linklink import FakeLink
from .import_helper import try_import_link


@lru_cache()
def get_link():
    return try_import_link()


@lru_cache()
def is_fake_link():
    return isinstance(get_link(), FakeLink)


def get_rank() -> int:
    """
    Overview:
        Get the rank of the current process in ``linklink``, return 0 if ``FakeLink`` is used.

    .. note::
        Reference ``import_helper.try_import_link`` and ``linklink.get_rank``.
    """
    if is_fake_link():
        return 0
    return error_wrapper(get_link().get_rank, 0, "[WARNING]: call linklink error, return default_ret.")()


def get_world_size() -> int:
    """
    Overview:
        Get the ``world_size`` (number of processes) of ``linklink``, return 1 if ``FakeLink`` is used.

    .. note::
        Reference ``import_helper.try_import_link`` and ``linklink.get_world_size``.
    """
    if is_fake_link():
        return 1
    return error_wrapper(get_link().get_world_size, 1, "[WARNING]: call linklink error, return default_ret.")()


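# Illustrative usage (a minimal sketch, not part of the original module): restrict
# console output to the rank-0 process.
#
#   if get_rank() == 0:
#       print('world size: {}'.format(get_world_size()))

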
def broadcast(value: torch.Tensor, rank: int) -> None:
    """
    Overview:
        Use ``linklink.broadcast``; raise an error when using ``FakeLink``.
    Arguments:
        - value (:obj:`torch.Tensor`): the tensor to broadcast
        - rank (:obj:`int`): the root rank of the broadcast
    """
    if is_fake_link():
        raise NotImplementedError
    get_link().broadcast(value, rank)


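# Illustrative usage (a minimal sketch, assuming the distributed backend has been
# initialized via dist_init()): send rank 0's tensor to every other process.
#
#   tensor = torch.zeros(4)
#   if get_rank() == 0:
#       tensor += 1.
#   broadcast(tensor, 0)   # afterwards every process holds rank 0's values

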
def allreduce(data: torch.Tensor, op: str = 'sum') -> None:
    """
    Overview:
        Call ``linklink.allreduce`` on the data.
    Arguments:
        - data (:obj:`torch.Tensor`): the data to reduce (modified in place)
        - op (:obj:`str`): the operation to perform on data, support ``['sum', 'max']``
    """
    link_op_map = {'sum': get_link().allreduceOp_t.Sum, 'max': get_link().allreduceOp_t.Max}
    if op not in link_op_map.keys():
        raise KeyError("not support allreduce op type: {}".format(op))
    else:
        link_op = link_op_map[op]
    if is_fake_link():
        return data
    get_link().allreduce(data, reduce_op=link_op)
    if op == 'sum':
        # normalize by world size so that 'sum' yields the mean across processes
        data.div_(get_world_size())


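# Illustrative usage (a minimal sketch; `model` is a hypothetical torch.nn.Module,
# and dist_init() is assumed to have been called): average gradients across
# processes after a backward pass.
#
#   for param in model.parameters():
#       if param.grad is not None:
#           allreduce(param.grad.data)   # in-place sum followed by division by world size

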
def allreduce_async(data: torch.Tensor, op: str = 'sum') -> None:
    """
    Overview:
        Call ``linklink.allreduce_async`` on the data.
    Arguments:
        - data (:obj:`torch.Tensor`): the data to reduce (modified in place)
        - op (:obj:`str`): the operation to perform on data, support ``['sum', 'max']``
    """
    link_op_map = {'sum': get_link().allreduceOp_t.Sum, 'max': get_link().allreduceOp_t.Max}
    if op not in link_op_map.keys():
        raise KeyError("not support allreduce op type: {}".format(op))
    else:
        link_op = link_op_map[op]
    if is_fake_link():
        return data
    if op == 'sum':
        # divide before launching the asynchronous reduction: the reduced result is not
        # available yet, so the local data is pre-scaled such that the eventual sum equals
        # the mean across processes
        data.div_(get_world_size())
    get_link().allreduce_async(data, reduce_op=link_op)


def get_group(group_size: int) -> List:
    """
    Overview:
        Split all ranks into groups of ``group_size`` processes each and return the group
        that the current rank belongs to.
    Arguments:
        - group_size (:obj:`int`): the size of each group
    """
    rank = get_rank()
    world_size = get_world_size()
    if group_size is None:
        group_size = world_size
    assert (world_size % group_size == 0)
    return simple_group_split(world_size, rank, world_size // group_size)


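# Illustrative example (a sketch of the call above): with world_size=8 and group_size=2,
# get_group builds 8 // 2 = 4 groups of two ranks each and returns the one containing
# the calling process, e.g. rank 5 falls into the group built from ranks [4, 5].

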
def dist_mode(func: Callable) -> Callable:
    """
    Overview:
        Wrap the function so that it automatically runs ``dist_init`` before each call
        and ``dist_finalize`` after it.
    Arguments:
        - func (:obj:`Callable`): the function to wrap
    """

    def wrapper(*args, **kwargs):
        dist_init()
        func(*args, **kwargs)
        dist_finalize()

    return wrapper


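# Illustrative usage (a minimal sketch; `main` is a hypothetical entry point): the
# decorated function runs between dist_init() and dist_finalize(). Note that the
# wrapper discards the wrapped function's return value.
#
#   @dist_mode
#   def main():
#       print('rank {} of {}'.format(get_rank(), get_world_size()))
#
#   main()

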
def dist_init(method: str = 'slurm', device_id: int = 0) -> Tuple[int, int]:
    """
    Overview:
        Initialize the distributed training environment.
    Arguments:
        - method (:obj:`str`): Support ``['slurm', 'single_node']``
        - device_id (:obj:`int`): Default device when using the ``single_node`` method
    """
    get_link().initialize()
    world_size = get_link().get_world_size()
    rank = get_link().get_rank()

    if method == 'slurm':
        # under slurm, map each process to a local GPU according to its rank
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
    elif method == 'single_node':
        torch.cuda.set_device(device_id)

    return rank, world_size


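# Illustrative usage under slurm (a minimal sketch, assuming one process per GPU is
# launched and a real linklink backend is available):
#
#   rank, world_size = dist_init(method='slurm')
#   tensor = torch.ones(1).cuda()
#   allreduce(tensor)      # 'sum' followed by division by world size, so tensor stays 1.0
#   dist_finalize()

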
def dist_finalize() -> None:
    """
    Overview:
        Finalize ``linklink``, see ``linklink.finalize()``.
    """
    get_link().finalize()


class DistContext:
    """
    Overview:
        A context manager for ``linklink`` distribution.
    Interfaces:
        ``__init__``, ``__enter__``, ``__exit__``
    """

    def __init__(self) -> None:
        """
        Overview:
            Initialize the ``DistContext``.
        """
        pass

    def __enter__(self) -> None:
        """
        Overview:
            Initialize ``linklink`` distribution.
        """
        dist_init()

    def __exit__(self, *args, **kwargs) -> Any:
        """
        Overview:
            Finalize ``linklink`` distribution.
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__exit__`` function.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__exit__`` function.
        """
        dist_finalize()


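# Illustrative usage (a minimal sketch, equivalent to calling dist_init() and
# dist_finalize() manually):
#
#   with DistContext():
#       rank, world_size = get_rank(), get_world_size()
#       # ... distributed code ...

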
def simple_group_split(world_size: int, rank: int, num_groups: int) -> List:
    """
    Overview:
        Split the ranks into ``num_groups`` groups according to ``world_size`` and ``rank``,
        and return the group that the given rank belongs to.
    Arguments:
        - world_size (:obj:`int`): The world size
        - rank (:obj:`int`): The rank
        - num_groups (:obj:`int`): The number of groups
    .. note::
        If ``world_size`` is not divisible by ``num_groups``, ``np.split`` raises
        ``ValueError: array split does not result in an equal division``.
    """
    groups = []
    rank_list = np.split(np.arange(world_size), num_groups)
    rank_list = [list(map(int, x)) for x in rank_list]
    for i in range(num_groups):
        # create a communication group for each contiguous slice of ranks
        groups.append(get_link().new_group(rank_list[i]))
    group_size = world_size // num_groups
    return groups[rank // group_size]


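# Illustrative example (a sketch of the partition above): with world_size=8 and
# num_groups=2, np.split produces the rank lists [0, 1, 2, 3] and [4, 5, 6, 7], and
# rank 5 receives the group created from [4, 5, 6, 7].

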
def synchronize():
    """
    Overview:
        Synchronize the processes.
    """
    get_link().synchronize()