import logging
import os
import signal
import threading

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

from fairseq.distributed import (
    DistributedTimeoutWrapper,
    LegacyDistributedDataParallel,
    ModuleProxyWrapper,
    TPUDistributedDataParallel,
)


logger = logging.getLogger(__name__)


_SLOWMO_DDP_DISABLED = False
try:
    from fairscale.experimental.nn.data_parallel import (
        SlowMoBaseAlgorithm,
        SlowMoDistributedDataParallel,
    )
except ImportError:
    _SLOWMO_DDP_DISABLED = True


def DistributedFairseqModel(args, model, process_group, device):
    """
    Wrap a *model* to support distributed data parallel training.

    This is similar to the built-in DistributedDataParallel, but allows
    additional configuration of the DistributedDataParallel class to
    use, and also provides easier access to the wrapped model by
    forwarding requests for missing attributes to the wrapped model.

    Args:
        args (argparse.Namespace): fairseq args
        model (BaseFairseqModel): model to wrap
        process_group: the c10d process group to be used for distributed data
            parallel all-reduction.
        device: device to move model to
    """
    assert isinstance(model, nn.Module)
    if args.tpu:
        wrapped_model = TPUDistributedDataParallel(
            module=model.to(device),
            process_group=process_group,
        )
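        # forward missing getattr and state_dict/load_state_dict to orig model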
        wrapped_model = ModuleProxyWrapper(wrapped_model)
    elif args.ddp_backend in {"c10d", "pytorch_ddp"}:
        wrapped_model = DistributedDataParallel(
            module=model.to(device),
            device_ids=[args.device_id],
            output_device=args.device_id,
            broadcast_buffers=args.broadcast_buffers,
            bucket_cap_mb=args.bucket_cap_mb,
            process_group=process_group,
            find_unused_parameters=args.find_unused_parameters,
            gradient_as_bucket_view=args.gradient_as_bucket_view,
        )
        if args.ddp_comm_hook == "fp16":
            logger.info("enable fp16 communication hook in DDP")
            try:
                from torch.distributed.algorithms.ddp_comm_hooks import (
                    DDPCommHookType,
                    register_ddp_comm_hook,
                )
            except ImportError:
                logger.error(
                    "Could not import from torch.distributed.algorithms.ddp_comm_hooks; "
                    "you may need to update your pytorch version"
                )
                raise

            register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, wrapped_model)
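
        # forward missing getattr and state_dict/load_state_dict to orig model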
        wrapped_model = ModuleProxyWrapper(wrapped_model)
    elif args.ddp_backend in {"no_c10d", "legacy_ddp"}:
        wrapped_model = LegacyDistributedDataParallel(
            module=model.to(device),
            buffer_size=2**28,
            process_group=process_group,
        )
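        # forward missing getattr and state_dict/load_state_dict to orig model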
        wrapped_model = ModuleProxyWrapper(wrapped_model)
    elif args.ddp_backend == "slowmo":
        if _SLOWMO_DDP_DISABLED:
            raise ImportError(
                "Cannot find SlowMoDistributedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            )
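
        # pick a default --slowmo-momentum based on the data-parallel world size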
        if args.slowmo_momentum is None:
            if args.distributed_world_size <= 16:
                args.slowmo_momentum = 0.0
            elif args.distributed_world_size <= 32:
                args.slowmo_momentum = 0.2
            elif args.distributed_world_size <= 64:
                args.slowmo_momentum = 0.5
            else:
                args.slowmo_momentum = 0.6
        slowmo_base_algorithm = SlowMoBaseAlgorithm[args.slowmo_base_algorithm.upper()]

        wrapped_model = SlowMoDistributedDataParallel(
            module=model.to(device),
            broadcast_buffers=args.broadcast_buffers,
            nprocs_per_node=args.nprocs_per_node,
            slowmo_momentum=args.slowmo_momentum,
            slowmo_base_algorithm=slowmo_base_algorithm,
            localsgd_frequency=args.localsgd_frequency,
        )
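        # forward missing getattr and state_dict/load_state_dict to orig model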
        wrapped_model = ModuleProxyWrapper(wrapped_model)
    elif args.ddp_backend == "fully_sharded":
        try:
            from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
        except ImportError:
            raise ImportError(
                "Cannot find FullyShardedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            )
        assert isinstance(model, FSDP), "expected model to already be wrapped in FSDP"
        wrapped_model = model
        if args.memory_efficient_fp16:
            wrapped_model = wrapped_model.half()
        if not args.cpu_offload:
            wrapped_model = wrapped_model.to(device=device)
    else:
        raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)
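
    # kill hung distributed jobs after a timeout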
    if getattr(args, "heartbeat_timeout", -1) > 0:
        wrapped_model = DistributedTimeoutWrapper(
            wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1)
        )

    return wrapped_model