import logging
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import torch.optim
from fairseq.dataclass import FairseqDataclass
from fairseq.optim import FairseqOptimizer, register_optimizer, _build_optimizer
from fairseq.optim.lr_scheduler import FairseqLRScheduler, build_lr_scheduler
from omegaconf import II, open_dict


logger = logging.getLogger(__name__)


@dataclass
class OptimizerAndSchedulerConfig(FairseqDataclass):
    optimizer: Any = None
    lr_scheduler: Optional[Any] = None
    lr: List = II("optimization.lr")
    lr_float: Optional[float] = None


@dataclass
class CompositeOptimizerConfig(FairseqDataclass):
    groups: Dict[str, Any] = field(
        default_factory=lambda: {},
        metadata={
            "help": "optimizer name -> optimizer OptimizerAndSchedulerConfig. "
            "Configures a different optimizer and (optionally) lr scheduler for each parameter group"
        },
    )

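
# Illustrative YAML sketch (an assumption for clarity, not copied from fairseq
# docs): each key under ``groups`` names a parameter group and maps to an
# OptimizerAndSchedulerConfig. The group names and the adam/cosine choices
# below are hypothetical.
#
#   optimizer:
#     _name: composite
#     groups:
#       default:
#         lr_float: 0.001
#         optimizer:
#           _name: adam
#         lr_scheduler:
#           _name: cosine
#       encoder:
#         lr_float: 0.0005
#         optimizer:
#           _name: adam
#         lr_scheduler:
#           _name: cosine
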
@register_optimizer("composite", dataclass=CompositeOptimizerConfig)
class FairseqCompositeOptimizer(FairseqOptimizer):
    """Wraps one optimizer (and optionally one lr scheduler) per parameter group."""

    optimizers: Dict[str, FairseqOptimizer] = {}
    lr_schedulers: Dict[str, FairseqLRScheduler] = {}
    lr_scheduler: FairseqLRScheduler = None
    _optimizer: torch.optim.Optimizer

    def __init__(self, cfg: CompositeOptimizerConfig, params):
        super().__init__(cfg)

        assert (
            len(params) > 1
        ), "Composite optimizer only works when there are multiple parameter groups (try fp16_no_flatten_grads: true)"

        # group parameters by their "param_group" attribute, defaulting to "default"
        grouped_params = defaultdict(list)
        for p in params:
            group = getattr(p, "param_group", "default")
            grouped_params[group].append(p)

        assert grouped_params.keys() == cfg.groups.keys(), (
            f"Parameter groups {grouped_params.keys()} and optimizer groups {cfg.groups.keys()} are not the same! "
            "Try setting 'param_group' on your parameters in the model."
        )
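
        # Illustrative sketch (assumption: the attribute name "param_group" comes
        # from the getattr above, but the tagging code and group names here are
        # hypothetical). Model code can tag parameters before the optimizer is
        # built, e.g.
        #
        #   for name, p in model.named_parameters():
        #       p.param_group = "encoder" if name.startswith("encoder") else "default"
        #
        # Untagged parameters fall into the "default" group, and each group found
        # here gets its own optimizer (and optional lr scheduler) built below.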
        for group, group_params in grouped_params.items():
            group_cfg = cfg.groups[group]
            with open_dict(group_cfg):
                if group_cfg.lr_float is not None:
                    group_cfg.optimizer.lr = [group_cfg.lr_float]
                    group_cfg.lr_scheduler.lr = [group_cfg.lr_float]
                else:
                    group_cfg.optimizer.lr = group_cfg.lr
                    group_cfg.lr_scheduler.lr = group_cfg.lr
            self.optimizers[group] = _build_optimizer(group_cfg.optimizer, group_params)
            if group_cfg.lr_scheduler is not None:
                self.lr_schedulers[group] = build_lr_scheduler(
                    group_cfg.lr_scheduler, self.optimizers[group]
                )

        if len(self.lr_schedulers) > 0:
            assert len(self.lr_schedulers) == len(self.optimizers), (
                "Please provide an lr scheduler for each optimizer to use pass_through scheduler. "
                f"Optimizers: {self.optimizers}; Lr scheds: {self.lr_schedulers}"
            )
            self.lr_scheduler = CompositeLRScheduler(self.lr_schedulers)

        self._optimizer = CompositeOptimizer(self.optimizers)
    @property
    def supports_groups(self):
        return True

    @property
    def param_groups(self):
        for opt in self.optimizers.values():
            for group in opt.param_groups:
                yield group

    def get_lr(self):
        """Return the current learning rate."""
        k = (
            "default"
            if "default" in self.optimizers
            else next(iter(self.optimizers.keys()))
        )
        return self.optimizers[k].param_groups[0]["lr"]

    def state_dict(self):
        """Return the state dict of each constituent optimizer, keyed by group."""
        return {k: s.state_dict() for k, s in self.optimizers.items()}

    def load_state_dict(self, state_dict, optimizer_overrides=None):
        """Load per-group optimizer state dicts produced by :func:`state_dict`."""
        for k, state in state_dict.items():
            if k not in self.optimizers:
                # ignore state for groups that are not present in this run
                continue

            overrides = (
                optimizer_overrides[k]
                if isinstance(optimizer_overrides, dict) and k in optimizer_overrides
                else None
            )
            self.optimizers[k].load_state_dict(state, optimizer_overrides=overrides)

class CompositeOptimizer(torch.optim.Optimizer):
    def __init__(self, optimizers: Dict[str, FairseqOptimizer]):
        self.optimizers = optimizers

    @property
    def supports_memory_efficient_fp16(self):
        return all(o.supports_memory_efficient_fp16 for o in self.optimizers.values())

    @property
    def supports_flat_params(self):
        return all(o.supports_flat_params for o in self.optimizers.values())

    def step(self, closure=None, groups=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
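
        # The ``groups`` filter below steps only the named parameter groups. An
        # illustrative call (the group name is hypothetical): step(groups={"default"})
        # updates the "default" optimizer and leaves every other group untouched.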

        for k, opt in self.optimizers.items():
            if groups is None or k in groups:
                opt.step()

        return loss

    def zero_grad(self):
        for opt in self.optimizers.values():
            opt.zero_grad()


class CompositeLRScheduler(FairseqLRScheduler):
    def __init__(self, lr_schedulers):
        super().__init__(None, None)

        self.lr_schedulers = lr_schedulers

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {k: s.state_dict() for k, s in self.lr_schedulers.items()}

    def load_state_dict(self, state_dict):
        """Load an LR scheduler state dict."""
        for k, state in state_dict.items():
            self.lr_schedulers[k].load_state_dict(state)

    def step_begin_epoch(self, epoch):
        """Update the learning rate at the beginning of the given epoch."""
        for s in self.lr_schedulers.values():
            s.step_begin_epoch(epoch)

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        for s in self.lr_schedulers.values():
            s.step(epoch)

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        return {k: s.step_update(num_updates) for k, s in self.lr_schedulers.items()}
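
# Illustrative return value of CompositeLRScheduler.step_update() (group names
# assumed for the example): {"default": 0.001, "encoder": 0.0005}. Callers get
# one current learning rate per parameter group rather than a single float.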