|
from typing import Callable, Tuple, List, Any, Union |
|
from easydict import EasyDict |
|
|
|
import os |
|
import numpy as np |
|
import torch |
|
import torch.distributed as dist |
|
|
|
from .default_helper import error_wrapper |
|
|
|
|
|
|
|
|
|
def get_rank() -> int:
    """
    Overview:
        Get the rank of the current process in the total world_size.
        Falls back to 0 when the default process group is not initialized.
    """
    return error_wrapper(dist.get_rank, 0)()
|
|
|
|
|
def get_world_size() -> int:
    """
    Overview:
        Get the world_size, i.e. the total number of processes in data parallel training.
        Falls back to 1 when the default process group is not initialized.
    """
    return error_wrapper(dist.get_world_size, 1)()
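
# A minimal usage sketch (illustrative only): both helpers are safe to call before
# ``dist_init`` thanks to the ``error_wrapper`` defaults, e.g.:
#
#     rank, world_size = get_rank(), get_world_size()
#     if rank == 0:
#         print(f"running with {world_size} process(es)")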
|
|
|
|
|
broadcast = dist.broadcast |
|
allgather = dist.all_gather |
|
broadcast_object_list = dist.broadcast_object_list |
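
# The names above are direct re-exports of ``torch.distributed`` primitives.
# A hedged example (assumes an initialized process group; ``meta_dict`` is a
# hypothetical picklable payload):
#
#     obj_list = [meta_dict if get_rank() == 0 else None]
#     broadcast_object_list(obj_list, src=0)
#     meta_dict = obj_list[0]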
|
|
|
|
|
def allreduce(x: torch.Tensor) -> None:
    """
    Overview:
        All-reduce the tensor ``x`` across the whole world and divide it by the world size,
        i.e. average it in place.
    Arguments:
        - x (:obj:`torch.Tensor`): the tensor to be reduced
    """
    dist.all_reduce(x)
    x.div_(get_world_size())
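
# A minimal usage sketch (illustrative, not part of the module API): averaging the
# gradients of a hypothetical ``model`` across processes after ``backward()``:
#
#     for param in model.parameters():
#         if param.grad is not None:
#             allreduce(param.grad.data)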
|
|
|
|
|
def allreduce_async(name: str, x: torch.Tensor) -> None:
    """
    Overview:
        All-reduce the tensor ``x`` asynchronously: ``x`` is first divided by the world size,
        then a non-blocking all-reduce (sum) is launched, so the content is the average once
        the communication completes.
    Arguments:
        - name (:obj:`str`): the name of the tensor (not used in this implementation)
        - x (:obj:`torch.Tensor`): the tensor to be reduced
    """
    x.div_(get_world_size())
    dist.all_reduce(x, async_op=True)
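
# Illustrative note: the call above returns before communication finishes, so the
# caller should make sure the communication has completed (e.g. a device
# ``synchronize()`` for the NCCL backend) before reading ``x``. A hedged sketch:
#
#     allreduce_async('grad', grad_tensor)   # hypothetical tensor name/value
#     ...                                    # overlap other work here
#     synchronize()                          # ensure the averaged value is ready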
|
|
|
|
|
def reduce_data(x: Union[int, float, torch.Tensor], dst: int) -> Union[int, float, torch.Tensor]:
    """
    Overview:
        Reduce the scalar or tensor ``x`` to the destination process (rank) ``dst``.
    Arguments:
        - x (:obj:`Union[int, float, torch.Tensor]`): the scalar or tensor to be reduced
        - dst (:obj:`int`): the rank of the destination process
    Returns:
        - reduced (:obj:`Union[int, float, torch.Tensor]`): the reduced value, only meaningful on rank ``dst``
    """
    if np.isscalar(x):
        # Scalars are wrapped in a CUDA tensor so that the NCCL backend can reduce them.
        x_tensor = torch.as_tensor([x]).cuda()
        dist.reduce(x_tensor, dst)
        return x_tensor.item()
    elif isinstance(x, torch.Tensor):
        dist.reduce(x, dst)
        return x
    else:
        raise TypeError("not supported type: {}".format(type(x)))
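
# A minimal usage sketch (illustrative; assumes CUDA and an initialized NCCL group):
#
#     local_reward = 3.0                        # hypothetical per-process statistic
#     total = reduce_data(local_reward, dst=0)  # the summed result is valid on rank 0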
|
|
|
|
|
def allreduce_data(x: Union[int, float, torch.Tensor], op: str) -> Union[int, float, torch.Tensor]:
    """
    Overview:
        All-reduce the scalar or tensor ``x`` across the whole world.
    Arguments:
        - x (:obj:`Union[int, float, torch.Tensor]`): the scalar or tensor to be reduced
        - op (:obj:`str`): the operation to perform on data, support ``['sum', 'avg']``
    Returns:
        - reduced (:obj:`Union[int, float, torch.Tensor]`): the reduced value, identical on every process
    """
    assert op in ['sum', 'avg'], op
    if np.isscalar(x):
        # Scalars are wrapped in a CUDA tensor so that the NCCL backend can reduce them.
        x_tensor = torch.as_tensor([x]).cuda()
        dist.all_reduce(x_tensor)
        if op == 'avg':
            x_tensor.div_(get_world_size())
        return x_tensor.item()
    elif isinstance(x, torch.Tensor):
        dist.all_reduce(x)
        if op == 'avg':
            x.div_(get_world_size())
        return x
    else:
        raise TypeError("not supported type: {}".format(type(x)))
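
# A minimal usage sketch (illustrative only): computing the mean loss over all
# processes for logging; ``loss`` is a hypothetical per-process float:
#
#     mean_loss = allreduce_data(loss, 'avg')
#     if get_rank() == 0:
#         print(f"mean loss across processes: {mean_loss:.4f}")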
|
|
|
|
|
synchronize = torch.cuda.synchronize |
|
|
|
|
|
def get_group(group_size: int) -> List:
    """
    Overview:
        Divide all processes into groups of ``group_size`` ranks each and return the
        process group that the current rank belongs to.
    Arguments:
        - group_size (:obj:`int`): the number of ranks in each group; if ``None``, the whole world forms one group
    """
    rank = get_rank()
    world_size = get_world_size()
    if group_size is None:
        group_size = world_size
    assert (world_size % group_size == 0)
    return simple_group_split(world_size, rank, world_size // group_size)
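
# Illustrative example (assumes 8 processes): ``get_group(4)`` builds two process
# groups, ranks [0, 1, 2, 3] and [4, 5, 6, 7], and returns the one that contains
# the calling rank.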
|
|
|
|
|
def dist_mode(func: Callable) -> Callable:
    """
    Overview:
        Wrap a function so that distributed training is initialized before, and finalized
        after, each call.
    Arguments:
        - func (:obj:`Callable`): the function to be wrapped
    """

    def wrapper(*args, **kwargs):
        dist_init()
        func(*args, **kwargs)
        dist_finalize()

    return wrapper
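
# A minimal usage sketch (illustrative only): decorating a hypothetical training
# entry point so that the process group is set up and torn down automatically:
#
#     @dist_mode
#     def main_train(cfg):
#         ...   # build the model, wrap it with DistributedDataParallel, run training
#
#     main_train(cfg)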
|
|
|
|
|
def dist_init(backend: str = 'nccl',
              addr: str = None,
              port: str = None,
              rank: int = None,
              world_size: int = None) -> Tuple[int, int]:
    """
    Overview:
        Initialize the distributed training setting.
    Arguments:
        - backend (:obj:`str`): The backend of the distributed training, support ``['nccl', 'gloo']``
        - addr (:obj:`str`): The address of the master node, default to ``$MASTER_ADDR`` or ``localhost``
        - port (:obj:`str`): The port of the master node, default to ``$MASTER_PORT`` or ``10314``
        - rank (:obj:`int`): The rank of the current process, default to ``$SLURM_LOCALID`` or ``$RANK``
        - world_size (:obj:`int`): The total number of processes, default to ``$SLURM_NTASKS`` or ``$WORLD_SIZE``
    Returns:
        - rank (:obj:`int`): The rank of the current process
        - world_size (:obj:`int`): The total number of processes
    """
    assert backend in ['nccl', 'gloo'], backend
    os.environ['MASTER_ADDR'] = addr or os.environ.get('MASTER_ADDR', "localhost")
    os.environ['MASTER_PORT'] = port or os.environ.get('MASTER_PORT', "10314")

    # Fall back to SLURM/launcher environment variables when rank/world_size are not given explicitly.
    if rank is None:
        local_id = os.environ.get('SLURM_LOCALID', os.environ.get('RANK', None))
        if local_id is None:
            raise RuntimeError("please indicate rank explicitly in dist_init method")
        else:
            rank = int(local_id)
    if world_size is None:
        ntasks = os.environ.get('SLURM_NTASKS', os.environ.get('WORLD_SIZE', None))
        if ntasks is None:
            raise RuntimeError("please indicate world_size explicitly in dist_init method")
        else:
            world_size = int(ntasks)

    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)

    # Bind the current process to a GPU in round-robin fashion.
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    world_size = get_world_size()
    rank = get_rank()
    return rank, world_size
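
# A minimal launch sketch (illustrative; the script name is hypothetical). With
# ``torchrun``, RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are provided through the
# environment, so ``dist_init()`` can be called without explicit arguments:
#
#     torchrun --nproc_per_node=4 train_ddp.py
#
# and inside ``train_ddp.py``:
#
#     rank, world_size = dist_init(backend='nccl')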
|
|
|
|
|
def dist_finalize() -> None:
    """
    Overview:
        Finalize distributed training resources. Currently a no-op.
    """
    pass
|
|
|
|
|
class DDPContext:
    """
    Overview:
        A context manager that sets up and tears down ``torch.distributed`` (DDP) training.
    Interfaces:
        ``__init__``, ``__enter__``, ``__exit__``
    """

    def __init__(self) -> None:
        """
        Overview:
            Initialize the ``DDPContext``.
        """
        pass

    def __enter__(self) -> None:
        """
        Overview:
            Initialize the distributed training setting via ``dist_init``.
        """
        dist_init()

    def __exit__(self, *args, **kwargs) -> Any:
        """
        Overview:
            Finalize the distributed training resources via ``dist_finalize``.
        """
        dist_finalize()
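
# A minimal usage sketch (illustrative only):
#
#     with DDPContext():
#         rank, world_size = get_rank(), get_world_size()
#         ...   # build the model and run distributed training here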
|
|
|
|
|
def simple_group_split(world_size: int, rank: int, num_groups: int) -> List:
    """
    Overview:
        Split all ``world_size`` ranks into ``num_groups`` equally sized process groups
        and return the group that ``rank`` belongs to.
    Arguments:
        - world_size (:obj:`int`): The world size
        - rank (:obj:`int`): The rank of the current process
        - num_groups (:obj:`int`): The number of groups

    .. note::
        If ``world_size`` is not divisible by ``num_groups``, ``np.split`` raises
        ``ValueError: array split does not result in an equal division``.
    """
    groups = []
    rank_list = np.split(np.arange(world_size), num_groups)
    rank_list = [list(map(int, x)) for x in rank_list]
    for i in range(num_groups):
        groups.append(dist.new_group(rank_list[i]))
    group_size = world_size // num_groups
    return groups[rank // group_size]
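
# Illustrative example: with ``world_size=8`` and ``num_groups=2``, every process
# creates the groups [0, 1, 2, 3] and [4, 5, 6, 7] (``dist.new_group`` must be
# called on all ranks), and a process with rank 5 gets the second group back.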
|
|
|
|
|
def to_ddp_config(cfg: EasyDict) -> EasyDict:
    """
    Overview:
        Convert a config to its DDP counterpart by dividing batch sizes and collection
        sizes by the world size, so that the aggregate amount across processes matches
        the original config.
    Arguments:
        - cfg (:obj:`EasyDict`): The config to be converted
    Returns:
        - cfg (:obj:`EasyDict`): The converted config
    """
    w = get_world_size()
    if 'batch_size' in cfg.policy:
        cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / w))
    if 'batch_size' in cfg.policy.learn:
        cfg.policy.learn.batch_size = int(np.ceil(cfg.policy.learn.batch_size / w))
    if 'n_sample' in cfg.policy.collect:
        cfg.policy.collect.n_sample = int(np.ceil(cfg.policy.collect.n_sample / w))
    if 'n_episode' in cfg.policy.collect:
        cfg.policy.collect.n_episode = int(np.ceil(cfg.policy.collect.n_episode / w))
    return cfg
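
# Illustrative example (field names follow the config layout assumed above): with
# 4 processes and ``cfg.policy.learn.batch_size = 256``, ``to_ddp_config`` sets the
# per-process batch size to 64, keeping the effective global batch size at 256.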
|
|