""" Copyright (c) Microsoft Corporation. Licensed under the MIT license. distributed API using Horovod Modified from OpenNMT's native pytorch distributed utils (https://github.com/OpenNMT/OpenNMT-py) """ import math import pickle import torch import torch.distributed as dist from time import time from torch.autograd import Function from torch.utils.data.distributed import DistributedSampler class ddp_allgather_with_grads(Function): @staticmethod def forward(ctx, x): tmp_input = x.cuda() size = torch.tensor(tmp_input.shape[0]).cuda() size_list = [torch.zeros_like(size) for i in range(dist.get_world_size())] dist.all_gather(size_list, size) max_size = max(size_list).item() padding_size = max_size - size if padding_size > 0 : padding_tensor = torch.zeros(padding_size,*tmp_input.shape[1:]).to(tmp_input) tmp_input = torch.cat((tmp_input, padding_tensor), dim = 0) tmp_list = [torch.zeros_like(tmp_input) for i in range(dist.get_world_size())] dist.all_gather(tmp_list, tmp_input) ctx.size = size_list output = [] for t, s in zip(tmp_list, size_list): output.append(t[:s]) output = torch.cat(output,dim=0) output.requires_grad = True return output @staticmethod def backward(ctx, grad_output): grad_x = None if grad_output is not None: grad_output.detach() #grad_x = grad_output.chunk(dist.get_world_size(),dim=0)[dist.get_rank()] start = sum(ctx.size[:dist.get_rank()]) end = start + ctx.size[dist.get_rank()] grad_x = grad_output[start:end] return grad_x def ddp_allgather(input): tmp_input = input.cuda() size = torch.tensor(tmp_input.shape[0]).cuda() size_list = [torch.zeros_like(size) for i in range(dist.get_world_size())] dist.all_gather(size_list, size) max_size = max(size_list).item() padding_size = max_size - size if padding_size > 0 : padding_tensor = torch.zeros(padding_size,*tmp_input.shape[1:]).to(tmp_input) tmp_input = torch.cat((tmp_input, padding_tensor), dim = 0) tmp_list = [torch.zeros_like(tmp_input) for i in range(dist.get_world_size())] dist.all_gather(tmp_list, tmp_input) output = [] for t, s in zip(tmp_list, size_list): output.append(t[:s]) output = torch.cat(output,dim=0) return output def _encode(enc, max_size, use_max_size=False): enc_size = len(enc) enc_byte = max(math.floor(math.log(max_size, 256)+1), 1) if use_max_size: # this is used for broadcasting buffer_ = torch.cuda.ByteTensor(max_size+enc_byte) else: buffer_ = torch.cuda.ByteTensor(enc_size+enc_byte) remainder = enc_size for i in range(enc_byte): base = 256 ** (enc_byte-i-1) buffer_[i] = remainder // base remainder %= base buffer_[enc_byte:enc_byte+enc_size] = torch.ByteTensor(list(enc)) return buffer_, enc_byte def _decode(buffer_, enc_byte): size = sum(256 ** (enc_byte-i-1) * buffer_[i].item() for i in range(enc_byte)) bytes_list = bytes(buffer_[enc_byte:enc_byte+size].tolist()) shift = size + enc_byte return bytes_list, shift _BUFFER_SIZE = 4096 def all_gather_list(data): """Gathers arbitrary data from all nodes into a list.""" enc = pickle.dumps(data) enc_size = len(enc) max_size = ddp_allgather(torch.tensor([enc_size]).cuda()).max().item() in_buffer, enc_byte = _encode(enc, max_size) out_buffer = ddp_allgather(in_buffer[:enc_byte+enc_size]) results = [] for _ in range(dist.get_world_size()): bytes_list, shift = _decode(out_buffer, enc_byte) out_buffer = out_buffer[shift:] result = pickle.loads(bytes_list) results.append(result) return results def any_broadcast(data, root_rank): """broadcast arbitrary data from root_rank to all nodes.""" enc = pickle.dumps(data) max_size = 
    buffer_, enc_byte = _encode(enc, max_size, use_max_size=True)

    dist.broadcast(buffer_, root_rank)

    bytes_list, _ = _decode(buffer_, enc_byte)
    result = pickle.loads(bytes_list)
    return result


class DistributedSampler_wopadding(DistributedSampler):
    """DistributedSampler that does not pad indices to a length divisible by num_replicas."""

    def __iter__(self):
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if self.drop_last:
            indices = indices[:self.total_size]
        # assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:len(indices):self.num_replicas]
        # assert len(indices) == self.num_samples

        return iter(indices)


class GatherLayer(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    this implementation does not cut the gradients as torch.distributed.all_gather does.
    """

    @staticmethod
    def forward(ctx, x):
        output = [
            torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
        ]
        torch.distributed.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # Sum the gradients contributed by every rank, then return this rank's slice.
        all_gradients = torch.stack(grads)
        torch.distributed.all_reduce(all_gradients)
        return all_gradients[torch.distributed.get_rank()]


def all_gather_with_grad(tensors):
    """
    Performs all_gather operation on the provided tensors.
    Graph remains connected for backward grad computation.
    """
    world_size = torch.distributed.get_world_size()
    # There is no need for reduction in the single-process case.
    if world_size == 1:
        return tensors

    tensor_all = GatherLayer.apply(tensors)
    return torch.cat(tensor_all, dim=0)
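

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): a minimal
# example of how these helpers might be exercised. It assumes the script is
# launched with ``torchrun --nproc_per_node=N`` so torch.distributed can
# initialize from environment variables, and that every rank has a CUDA
# device; the "nccl" backend and the tensor shapes below are assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Variable-length tensors: ddp_allgather pads each rank's tensor along
    # dim 0 to the longest size, gathers, then strips the padding again.
    local = torch.randn(dist.get_rank() + 1, 4).cuda()
    gathered = ddp_allgather(local)
    print(f"rank {dist.get_rank()}: gathered shape {tuple(gathered.shape)}")

    # all_gather_with_grad keeps the autograd graph connected, so a loss on
    # the gathered tensor still back-propagates into the local input.
    feat = torch.randn(2, 8, device="cuda", requires_grad=True)
    all_feat = all_gather_with_grad(feat)
    all_feat.sum().backward()

    # all_gather_list / any_broadcast move arbitrary picklable Python objects.
    metas = all_gather_list({"rank": dist.get_rank()})
    shared = any_broadcast({"step": 0}, root_rank=0)
    if dist.get_rank() == 0:
        print(metas, shared)

    dist.destroy_process_group()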