import torch

from fairseq import utils
from fairseq.dataclass.utils import gen_parser_from_dataclass


class FairseqOptimizer(object):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

    @classmethod
    def add_args(cls, parser):
        """Add optimizer-specific arguments to the parser."""
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

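    # Illustrative sketch (not code from this module): a concrete optimizer exposes
    # a config dataclass so that `add_args` has something to generate arguments
    # from; in fairseq this is commonly wired up via the optimizer registry, e.g.
    #
    #     @register_optimizer("my_sgd", dataclass=MySGDConfig)
    #     class MySGD(FairseqOptimizer):
    #         ...
    #
    # `MySGD` and `MySGDConfig` are hypothetical names used only for illustration.
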
    @property
    def optimizer(self):
        """Return a torch.optim.optimizer.Optimizer instance."""
        if not hasattr(self, "_optimizer"):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer):
        """Reset optimizer instance."""
        if not hasattr(self, "_optimizer"):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
        self._optimizer = optimizer

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        raise NotImplementedError

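    # Illustrative sketch: a subclass returns the kwargs that should take precedence
    # over whatever was stored in the checkpoint, for example
    #
    #     @property
    #     def optimizer_config(self):
    #         return {
    #             "lr": self.cfg.lr[0],
    #             "momentum": self.cfg.momentum,
    #             "weight_decay": self.cfg.weight_decay,
    #         }
    #
    # The `momentum` and `weight_decay` config fields are assumptions made for this
    # example; they are not provided by the base class.
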
    @property
    def params(self):
        """Return an iterable of the parameters held by the optimizer."""
        for param_group in self.param_groups:
            for p in param_group["params"]:
                yield p

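    # Illustrative usage: `params` yields the underlying tensors, so e.g. the total
    # number of optimized elements can be computed as
    #
    #     num_elements = sum(p.numel() for p in optimizer.params)
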
    @property
    def param_groups(self):
        """Return the wrapped optimizer's parameter groups."""
        return self.optimizer.param_groups

    def __getstate__(self):
        return self._optimizer.__getstate__()

    def get_lr(self):
        """Return the current learning rate."""
        return self.param_groups[0]["lr"]

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.param_groups:
            param_group["lr"] = lr

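    # Illustrative usage: an LR schedule is typically applied by recomputing the
    # rate and pushing it into every param group (`base_lr` and `num_updates` are
    # assumed names, and the warmup formula is only an example):
    #
    #     optimizer.set_lr(base_lr * min(1.0, num_updates / 4000))
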
    def state_dict(self):
        """Return the optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        self.optimizer.load_state_dict(state_dict)

        if optimizer_overrides is not None and len(optimizer_overrides) > 0:
            # override learning rate, momentum, etc. with the passed-in values
            for group in self.param_groups:
                group.update(optimizer_overrides)

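    # Illustrative usage: resume from a checkpoint while forcing a fresh learning
    # rate (the `"last_optimizer_state"` key is an assumption for this example):
    #
    #     state = checkpoint["last_optimizer_state"]
    #     optimizer.load_state_dict(state, optimizer_overrides={"lr": 1e-4})
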
    def backward(self, loss):
        """Computes the sum of gradients of the given tensor w.r.t. graph leaves."""
        loss.backward()

    def all_reduce_grads(self, module):
        """Manually all-reduce gradients (if required)."""
        if hasattr(module, "all_reduce_grads"):
            module.all_reduce_grads()

    def multiply_grads(self, c):
        """Multiplies grads by a constant *c*."""
        for p in self.params:
            if p.grad is not None:
                if torch.is_tensor(c):
                    # move a tensor-valued constant onto the grad's device before scaling
                    c = c.to(p.grad.device)
                p.grad.data.mul_(c)

    def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
        """Clips gradient norm."""
        return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn)

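    # Illustrative usage: the returned value is the total gradient norm (computed
    # before clipping), which training loops commonly log, e.g.
    #
    #     gnorm = optimizer.clip_grad_norm(max_norm=0.1)
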
    def step(self, closure=None, scale=1.0, groups=None):
        """Performs a single optimization step."""
        if self.supports_step_with_scale:
            if self.supports_groups:
                self.optimizer.step(closure, scale=scale, groups=groups)
            else:
                self.optimizer.step(closure, scale=scale)
        else:
            if scale != 1.0:
                # the wrapped optimizer cannot unscale internally, so unscale the
                # gradients here before stepping
                self.multiply_grads(1.0 / scale)
            if self.supports_groups:
                self.optimizer.step(closure, groups=groups)
            else:
                self.optimizer.step(closure)

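    # Illustrative sketch of how these methods fit together in one update
    # (`loss` and `sample_size` are assumed to come from the surrounding trainer):
    #
    #     optimizer.backward(loss)
    #     optimizer.multiply_grads(1.0 / sample_size)
    #     optimizer.clip_grad_norm(max_norm=0.1)
    #     optimizer.step()
    #     optimizer.zero_grad()
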
    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for p in self.params:
            # setting grads to None releases the memory instead of just zeroing it
            p.grad = None
        self.optimizer.zero_grad()

    @property
    def supports_memory_efficient_fp16(self):
        if hasattr(self.optimizer, "supports_memory_efficient_fp16"):
            return self.optimizer.supports_memory_efficient_fp16
        return False

    @property
    def supports_step_with_scale(self):
        if hasattr(self.optimizer, "supports_step_with_scale"):
            return self.optimizer.supports_step_with_scale
        return False

    @property
    def supports_groups(self):
        if hasattr(self.optimizer, "supports_groups"):
            return self.optimizer.supports_groups
        return False

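    # Illustrative sketch: a wrapped optimizer opts into the fast paths above by
    # exposing matching attributes; the class below is hypothetical.
    #
    #     class MyFusedAdam(torch.optim.Optimizer):
    #         supports_memory_efficient_fp16 = True
    #         supports_step_with_scale = True
    #         supports_groups = False
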
    @property
    def supports_flat_params(self):
        """
        Whether the optimizer supports collapsing of the model
        parameters/gradients into a single contiguous Tensor.
        """
        if hasattr(self.optimizer, "supports_flat_params"):
            return self.optimizer.supports_flat_params
        return False

    def average_params(self):
        """No-op by default; subclasses that average parameters across workers may override this."""
        pass

    def broadcast_global_state_dict(self, state_dict):
        """
        Broadcasts a global state dict to all ranks.
        Useful for optimizers that shard state between ranks.
        """
        if hasattr(self.optimizer, "broadcast_global_state_dict"):
            return self.optimizer.broadcast_global_state_dict(state_dict)
        else:
            return state_dict


class LegacyFairseqOptimizer(FairseqOptimizer):
    """Optimizer wrapper for legacy optimizers that are still configured with an
    argparse `args` namespace instead of a dataclass `cfg`."""

    def __init__(self, args):
        self.args = args
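
# Illustrative sketch of the subclass contract described above: a concrete
# optimizer builds `self._optimizer` and defines `optimizer_config`. The class
# name, registration name, and config fields below are hypothetical.
#
#     @register_optimizer("example_sgd", dataclass=ExampleSGDConfig)
#     class ExampleSGD(FairseqOptimizer):
#         def __init__(self, cfg, params):
#             super().__init__(cfg)
#             self._optimizer = torch.optim.SGD(params, **self.optimizer_config)
#
#         @property
#         def optimizer_config(self):
#             return {"lr": self.cfg.lr[0], "momentum": self.cfg.momentum}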