# SillyTavern-Extras1/modules/voice_conversion/fairseq/distributed/legacy_distributed_data_parallel.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
A modified version of the legacy DistributedDataParallel module that uses c10d
communication primitives. This version is simpler than the latest PyTorch
version and is useful for debugging. Notably, it does not overlap gradient
communication with the backward pass, which makes it slower but more robust
than the PyTorch version.

This version also supports the *no_sync* context manager, which allows faster
training with `--update-freq`.
"""
from collections import OrderedDict
from contextlib import contextmanager

import torch
from torch import nn

from fairseq.distributed import utils
class LegacyDistributedDataParallel(nn.Module):
    """Implements distributed data parallelism at the module level.

    A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
    This version uses a c10d process group for communication and does not
    broadcast buffers. Gradients are NOT reduced automatically: callers must
    invoke :meth:`all_reduce_grads` after the backward pass.

    Args:
        module (~torch.nn.Module): module to be parallelized
        process_group: the c10d process group to be used for distributed data
            parallel all-reduction.
        buffer_size (int, optional): number of elements to buffer before
            performing all-reduce (default: 256M).
    """

    def __init__(self, module, process_group, buffer_size=2**28):
        super().__init__()

        self.module = module
        self.process_group = process_group
        self.world_size = utils.get_world_size(self.process_group)

        # Never use a bigger buffer than the number of model params
        self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters()))
        # Flat staging buffer used to bucket small gradients; allocated lazily
        # on the first call to ``all_reduce_grads``.
        self.buffer = None

        # We can also forcibly accumulate grads locally and only do the
        # all-reduce at some later time (toggled by ``no_sync``).
        self.accumulate_grads = False

        # Make per-device lists of parameters (insertion order preserved), so
        # each bucket only ever mixes parameters living on the same device.
        paramlists = OrderedDict()
        for param in self.module.parameters():
            paramlists.setdefault(param.device, []).append(param)
        self.per_device_params = list(paramlists.values())

    @contextmanager
    def no_sync(self):
        """A context manager to disable gradient synchronization.

        While active, :meth:`all_reduce_grads` returns without communicating,
        so gradients keep accumulating locally. The previous setting is
        restored on exit, even if the body of the ``with`` statement raises.
        """
        old_accumulate_grads = self.accumulate_grads
        self.accumulate_grads = True
        try:
            yield
        finally:
            self.accumulate_grads = old_accumulate_grads

    def forward(self, *inputs, **kwargs):
        """Delegate directly to the wrapped module."""
        return self.module(*inputs, **kwargs)

    def all_reduce_grads(self):
        """
        This function must be called explicitly after backward to reduce
        gradients. There is no automatic hook like c10d.

        Gradients are divided by the world size and passed to
        ``utils.all_reduce``; small gradients are packed into a shared flat
        buffer (bucketed) and large ones are reduced individually. It does
        nothing while inside :meth:`no_sync` or when the module has no
        parameters.
        """

        def all_reduce_params(params):
            # Reduce the gradients of ``params`` across the process group and
            # write the result back into each ``p.grad``.
            buffer = self.buffer
            nonzero_buffer = False
            if len(params) > 1:
                # Pack every grad contiguously into the shared flat buffer.
                offset = 0
                for p in params:
                    sz = p.numel()
                    if p.grad is not None:
                        buffer[offset : offset + sz].copy_(p.grad.data.view(-1))
                        nonzero_buffer = True
                    else:
                        buffer[offset : offset + sz].zero_()
                    offset += sz
            else:
                # we only have a single grad to all-reduce
                p = params[0]
                if p.grad is not None:
                    # Reduce in place on the grad tensor itself; no copy needed.
                    buffer = p.grad.data
                    nonzero_buffer = True
                elif p.numel() <= self.buffer.numel():
                    buffer = buffer[: p.numel()]
                    buffer.zero_()
                else:
                    buffer = torch.zeros_like(p)

            if nonzero_buffer:
                # Pre-divide by the world size (presumably utils.all_reduce
                # sums, which would yield the mean gradient -- confirm).
                buffer.div_(self.world_size)

            utils.all_reduce(buffer, self.process_group)

            # copy all-reduced grads back into their original place
            offset = 0
            for p in params:
                sz = p.numel()
                if p.grad is not None:
                    p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))
                else:
                    p.grad = buffer[offset : offset + sz].view_as(p).clone()
                offset += sz

        def reduction_fn():
            # This function only needs to be called once
            if self.accumulate_grads:
                # Inside ``no_sync``: keep accumulating locally, no communication.
                return

            if not self.per_device_params:
                # The module has no parameters: nothing to reduce, and no
                # parameter to allocate the staging buffer from.
                return

            if self.buffer is None:
                # Lazily allocate an uninitialized flat buffer with the dtype
                # and device of the first parameter.
                # NOTE(review): grads of parameters with other dtypes/devices
                # are copied into this buffer; confirm that is intended.
                self.buffer = next(self.module.parameters()).new_empty(
                    (self.buffer_size,)
                )

            for params in self.per_device_params:
                # All-reduce the gradients in buckets
                offset = 0
                buffered_params = []
                for param in params:
                    if not param.requires_grad:
                        continue
                    if param.grad is None:
                        # Materialize a zero grad so this parameter still
                        # takes part in the reduction.
                        param.grad = torch.zeros_like(param)

                    if hasattr(param, "expert"):
                        # Skip gradient sync for unshared parameters
                        continue

                    if param.grad.requires_grad:
                        raise RuntimeError(
                            "DistributedDataParallel only works "
                            "with gradients that don't require "
                            "grad"
                        )

                    sz = param.numel()
                    if sz > self.buffer.numel():
                        # all-reduce big params directly
                        all_reduce_params([param])
                    else:
                        if offset + sz > self.buffer.numel():
                            # Current bucket is full: flush it before adding.
                            all_reduce_params(buffered_params)
                            offset = 0
                            buffered_params.clear()
                        buffered_params.append(param)
                        offset += sz

                # Flush the final, partially filled bucket.
                if buffered_params:
                    all_reduce_params(buffered_params)

        reduction_fn()