# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
[Copied from detectron2]
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""

import functools
import logging
import numpy as np
import pickle
import torch
import torch.distributed as dist

# A torch process group which only includes processes that are on the same
# machine as the current process. This variable is set when processes are
# spawned by `launch()` in "engine/launch.py".
_LOCAL_PROCESS_GROUP = None


def get_world_size() -> int:
    """
    Returns:
        int: number of processes in the default process group, or 1 when
        torch.distributed is unavailable or not initialized.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank() -> int:
    """
    Returns:
        int: rank of the current process in the default process group, or 0
        when torch.distributed is unavailable or not initialized.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)


def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    # NOTE(review): unlike get_local_rank(), there is no check that
    # _LOCAL_PROCESS_GROUP was set. With group=None torch uses the default
    # (global) group, so this would return the global world size.
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)


def is_main_process() -> bool:
    """Return True on the process with global rank 0 (or when not distributed)."""
    return get_rank() == 0


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks
    The result is cached.
    """
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD


def _serialize_to_tensor(data, group):
    """
    Pickle ``data`` into a 1-D uint8 tensor.

    The tensor is placed on the device matching ``group``'s backend: CPU for
    gloo, CUDA for nccl.
    """
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cpu" if backend == "gloo" else "cuda")

    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        logger = logging.getLogger(__name__)
        logger.warning(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                get_rank(), len(buffer) / (1024 ** 3), device
            )
        )
    # bytearray gives numpy a writable buffer, so torch.from_numpy does not warn
    # about read-only memory. This also avoids the deprecated
    # torch.ByteStorage.from_buffer API.
    tensor = torch.from_numpy(np.frombuffer(bytearray(buffer), dtype=np.uint8))
    return tensor.to(device=device)


def _deserialize_from_tensor(tensor, size):
    """Unpickle the first ``size`` bytes of a uint8 ``tensor``, ignoring trailing padding."""
    # Slice before moving to host so only the payload (not the padding) is copied.
    buffer = tensor[:size].cpu().numpy().tobytes()
    # NOTE: pickle.loads is only safe because the bytes come from peer ranks of
    # the same job; never feed this helper untrusted data.
    return pickle.loads(buffer)


def _pad_to_largest_tensor(tensor, group):
    """
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    local_size = tensor.numel()
    local_size_tensor = torch.tensor([local_size], dtype=torch.int64, device=tensor.device)
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
    ]
    dist.all_gather(size_list, local_size_tensor, group=group)
    size_list = [int(size.item()) for size in size_list]

    max_size = max(size_list)

    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    if local_size != max_size:
        padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor


def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    tensor = _serialize_to_tensor(data, group)

    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # receiving Tensor from all ranks
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)

    return [
        _deserialize_from_tensor(received, size)
        for size, received in zip(size_list, tensor_list)
    ]


def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.

    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group=group) == 1:
        return [data]
    # NOTE(review): `rank` is relative to `group`, but dist.gather interprets
    # `dst` as a global rank; the two only agree for the default/global group.
    rank = dist.get_rank(group=group)

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)

    # receiving Tensor from all ranks
    if rank == dst:
        max_size = max(size_list)
        tensor_list = [
            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
        ]
        dist.gather(tensor, tensor_list, dst=dst, group=group)

        return [
            _deserialize_from_tensor(received, size)
            for size, received in zip(size_list, tensor_list)
        ]
    else:
        dist.gather(tensor, [], dst=dst, group=group)
        return []


def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.
            If workers need a shared RNG, they can use this shared seed to
            create one.

    All workers must call this function, otherwise it will deadlock.
    """
    ints = np.random.randint(2 ** 31)
    all_ints = all_gather(ints)
    return all_ints[0]


def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the reduced results.

    Args:
        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
        average (bool): whether to do average or sum

    Returns:
        a dict with the same keys as input_dict, after reduction. Only the
        values on rank 0 hold the reduced result.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    if not input_dict:
        # torch.stack would fail on an empty list; nothing to reduce.
        return {}
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict