SillyTavern-Extras1
/
modules
/voice_conversion
/fairseq
/distributed
/tpu_distributed_data_parallel.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch | |
from torch import nn | |
from fairseq.distributed import utils | |
class TPUDistributedDataParallel(nn.Module):
    """Thin data-parallel wrapper that reduces gradients through torch_xla.

    The forward pass is delegated unchanged to the wrapped module. Gradient
    synchronisation is not automatic: callers invoke ``all_reduce_grads``
    explicitly once gradients have been computed.
    """

    def __init__(self, module, process_group):
        """Store the wrapped module and look up the group's world size.

        Args:
            module: the ``nn.Module`` whose forward pass and parameters are
                wrapped.
            process_group: handle passed to ``utils.get_world_size``; its
                element at index 1 is later forwarded to ``xm.all_reduce`` as
                ``groups``.
                NOTE(review): presumably a ``("tpu", replica_groups)`` pair --
                confirm against the caller.
        """
        super().__init__()
        self.module = module
        self.process_group = process_group
        self.world_size = utils.get_world_size(process_group)

    def forward(self, *inputs, **kwargs):
        """Call the wrapped module with the given arguments and return its result."""
        wrapped = self.module
        return wrapped(*inputs, **kwargs)

    def all_reduce_grads(self):
        """Sum-reduce the gradients of all trainable parameters, scaled by 1/world_size.

        Parameters with ``requires_grad=False`` are skipped. A trainable
        parameter with no gradient yet gets a zero gradient so that it still
        takes part in the reduction.

        Raises:
            RuntimeError: if any collected gradient itself requires grad.
        """
        grads_to_reduce = []
        for param in self.parameters():
            if param.requires_grad:
                if param.grad is None:
                    param.grad = torch.zeros_like(param)
                grad = param.grad
                if grad.requires_grad:
                    raise RuntimeError(
                        "TPUDistributedDataParallel only works with gradients that don't "
                        "require grad"
                    )
                grads_to_reduce.append(grad)

        # Imported lazily so the module can be imported (and the checks above
        # can run) on machines without torch_xla installed.
        import torch_xla.core.xla_model as xm

        averaging_scale = 1.0 / self.world_size
        replica_groups = self.process_group[1]
        xm.all_reduce(
            "sum",
            grads_to_reduce,
            scale=averaging_scale,
            groups=replica_groups,
        )