File size: 1,285 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from torch import nn
from fairseq.distributed import utils
class TPUDistributedDataParallel(nn.Module):
    """Thin ``nn.Module`` wrapper that all-reduces gradients through torch_xla.

    The wrapped module runs the forward pass unchanged; gradient
    synchronisation only happens when :meth:`all_reduce_grads` is called
    explicitly.
    """

    def __init__(self, module, process_group):
        """Store the wrapped module and its process group.

        Args:
            module: the ``nn.Module`` whose forward pass and parameters are
                wrapped.
            process_group: handle passed to ``utils.get_world_size``; its
                element at index 1 is later forwarded to ``xm.all_reduce``
                as ``groups``.
                NOTE(review): presumably a tuple whose second item is the
                replica-group list -- confirm against the code that builds it.
        """
        super().__init__()
        self.module = module
        self.process_group = process_group
        # Cached once; used as the divisor for the all-reduce scale.
        self.world_size = utils.get_world_size(self.process_group)

    def forward(self, *inputs, **kwargs):
        """Call the wrapped module with the given arguments and return its result."""
        return self.module(*inputs, **kwargs)

    def all_reduce_grads(self):
        """Sum-reduce the gradients of all trainable parameters, scaled by 1/world_size.

        Parameters with ``requires_grad == False`` are ignored. A trainable
        parameter that has no gradient yet is given a zero gradient first, so
        it still takes part in the reduction.

        Raises:
            RuntimeError: if any collected gradient itself requires grad.
        """
        grads_to_sync = []
        for param in self.parameters():
            if param.requires_grad:
                if param.grad is None:
                    # Fill in a zero gradient so this parameter is not left
                    # out of the list handed to the all-reduce.
                    param.grad = torch.zeros_like(param)
                grad = param.grad
                if grad.requires_grad:
                    raise RuntimeError(
                        "TPUDistributedDataParallel only works with gradients that don't "
                        "require grad"
                    )
                grads_to_sync.append(grad)

        # Imported here rather than at module level, so merely importing this
        # file does not require torch_xla.
        import torch_xla.core.xla_model as xm

        reduce_fn = xm.all_reduce
        mean_scale = 1.0 / self.world_size
        replica_groups = self.process_group[1]
        # The return value is not used; the call is relied on to update the
        # gradient tensors themselves.
        reduce_fn(
            "sum",
            grads_to_sync,
            scale=mean_scale,
            groups=replica_groups,
        )
|