# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
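"""Shard a fairseq optimizer's state across data-parallel workers using
fairscale's OSS (optimizer state sharding)."""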
from typing import Any, Dict
from fairseq.distributed import utils
try:
    from fairscale.optim import OSS

    _has_fairscale = True
except ImportError:
    _has_fairscale = False
def shard_(optimizer, group):
    if not _has_fairscale:
        raise ImportError(
            "\n\nPlease install the fairscale package:\n\n pip install fairscale"
        )
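    # Subclass fairscale's OSS so it cooperates with fairseq's fp16
    # wrappers and distributed checkpoint loading.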
    class FairseqOSS(OSS):
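        # Flag checked by fairseq's fp16 optimizer wrappers; returning True
        # skips a state-dict loading hack that does not apply once optimizer
        # state is sharded.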
        @property
        def disable_mem_eff_fp16_loading_hack(self):
            return True
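        # Delegate ``supports_*`` capability queries (e.g.
        # ``supports_memory_efficient_fp16``) to the wrapped torch optimizer.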
        def __getattr__(self, name):
            if name.startswith("supports") and hasattr(self.optim, name):
                return getattr(self.optim, name)
            raise AttributeError(
                "'FairseqOSS' object has no attribute {0!r}".format(name)
            )
        def broadcast_global_state_dict(
            self, state_dict: Dict[str, Any]
        ) -> Dict[str, Any]:
            """
            Broadcast the entire state_dict from rank 0 to all other ranks;
            each rank is responsible for loading its own partition of the data.
            """
            return utils.broadcast_object(
                state_dict,
                src_rank=0,
                group=self.group,
            )
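    # Swap the wrapped torch optimizer for a sharded OSS instance, which
    # re-instantiates the same optimizer class over each rank's shard of
    # the parameters.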
    torch_optimizer = optimizer.optimizer
    optim_cls = type(torch_optimizer)

    optimizer.optimizer = FairseqOSS(
        torch_optimizer.param_groups,
        optim_cls,
        group=group,
        **optimizer.optimizer_config
    )
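
# Minimal usage sketch (assumptions: ``cfg`` and ``params`` come from the
# surrounding training setup; in fairseq itself the trainer calls ``shard_``
# when ZeRO optimizer state sharding is enabled):
#
#     from fairseq import optim
#     from fairseq.distributed import utils as dist_utils
#
#     optimizer = optim.build_optimizer(cfg.optimizer, params)
#     group = dist_utils.get_data_parallel_group()
#     optim.shard_(optimizer, group)
#     # optimizer.optimizer is now a FairseqOSS wrapping the original class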