# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn

from fairseq.distributed import utils


class TPUDistributedDataParallel(nn.Module):
    """Thin wrapper that adds an XLA gradient all-reduce to a module.

    The forward pass is delegated untouched to the wrapped module;
    gradient synchronization happens only when ``all_reduce_grads`` is
    called explicitly.
    """

    def __init__(self, module, process_group):
        """Store the wrapped module and its process group.

        Args:
            module: the module to wrap; called directly by ``forward``.
            process_group: handed to ``utils.get_world_size`` to obtain
                the world size; its element at index 1 is later passed
                as ``groups`` to ``xm.all_reduce``.
                NOTE(review): presumably a tuple whose second item holds
                the XLA replica groups -- confirm against the caller.
        """
        super().__init__()
        self.module = module
        self.process_group = process_group
        self.world_size = utils.get_world_size(process_group)

    def forward(self, *inputs, **kwargs):
        """Call the wrapped module with the given arguments, unchanged."""
        wrapped = self.module
        return wrapped(*inputs, **kwargs)

    def all_reduce_grads(self):
        """Sum-reduce every trainable parameter's gradient, scaled by 1/world_size.

        Trainable parameters that have no gradient yet are given a zero
        gradient first, so each one contributes a tensor to the reduction.

        Raises:
            RuntimeError: if any collected gradient itself requires grad.
        """
        # Lazy filter: parameters are still visited (and zero-filled) one
        # at a time, in order, exactly as in a plain loop with `continue`.
        trainable = (param for param in self.parameters() if param.requires_grad)

        grads = []
        for param in trainable:
            if param.grad is None:
                param.grad = torch.zeros_like(param)
            grad = param.grad
            if grad.requires_grad:
                raise RuntimeError(
                    "TPUDistributedDataParallel only works with gradients that don't "
                    "require grad"
                )
            grads.append(grad)

        # Imported here rather than at file level, so torch_xla is only
        # needed once this method is actually called.
        import torch_xla.core.xla_model as xm

        # A "sum" reduction with scale 1/world_size.
        # NOTE(review): this relies on xm.all_reduce updating the listed
        # tensors in place (the return value is discarded) -- confirm
        # against the installed torch_xla version.
        inverse_world_size = 1.0 / self.world_size
        replica_groups = self.process_group[1]
        xm.all_reduce(
            "sum",
            grads,
            scale=inverse_world_size,
            groups=replica_groups,
        )