# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# ------------------------------------------------------------------------
# Modified from https://github.com/pytorch/pytorch
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
import math | |
import warnings | |
import weakref | |
from collections import Counter | |
from functools import wraps | |
from typing import Callable, List, Optional, Sequence, Union | |
from torch.optim import Optimizer | |
from mmengine.logging import print_log | |
from mmengine.optim import BaseOptimWrapper | |
from mmengine.registry import PARAM_SCHEDULERS | |
INF = int(1e9) | |
OptimizerType = Union[BaseOptimWrapper, Optimizer] | |
class _ParamScheduler: | |
"""Base class for parameter schedulers. | |
    It should be inherited by all schedulers that schedule parameters in the
    optimizer's ``param_groups``. All subclasses should override
    ``_get_value()`` according to their own schedule strategy.
The implementation is motivated by | |
https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py. | |
Args: | |
optimizer (BaseOptimWrapper or Optimizer): Wrapped optimizer. | |
param_name (str): Name of the parameter to be adjusted, such as | |
``lr``, ``momentum``. | |
begin (int): Step at which to start updating the parameters. | |
Defaults to 0. | |
end (int): Step at which to stop updating the parameters. | |
Defaults to INF. | |
        last_step (int): The index of the last step. Used for resuming
            without a state dict. The default value ``-1`` means the ``step``
            function has never been called. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by | |
epochs. Defaults to True. | |
verbose (bool): Whether to print the value for each update. | |
Defaults to False. | |
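
    Examples:
        A minimal usage sketch for any concrete subclass (shown here with
        ``StepParamScheduler``; the toy ``torch.nn.Linear`` model and
        ``torch.optim.SGD`` optimizer are illustrative assumptions, not part
        of this module):

        >>> import torch
        >>> from torch.optim import SGD
        >>> optimizer = SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
        >>> scheduler = StepParamScheduler(
        ...     optimizer, param_name='lr', step_size=3, by_epoch=False)
        >>> for _ in range(9):
        ...     optimizer.step()   # update model parameters first
        ...     scheduler.step()   # then update the scheduled value
        >>> state = scheduler.state_dict()   # serializable state for resuming
        >>> scheduler.load_state_dict(state)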
""" # noqa: E501 | |
def __init__(self, | |
optimizer: OptimizerType, | |
param_name: str, | |
begin: int = 0, | |
end: int = INF, | |
last_step: int = -1, | |
by_epoch: bool = True, | |
verbose: bool = False): | |
# Attach optimizer | |
if not isinstance(optimizer, (Optimizer, BaseOptimWrapper)): | |
            raise TypeError('``optimizer`` should be an Optimizer, '
                            'but got {}'.format(type(optimizer).__name__))
self.optimizer = optimizer | |
self.param_name = param_name | |
if end <= begin: | |
raise ValueError('end should be larger than begin, but got' | |
' begin={}, end={}'.format(begin, end)) | |
self.begin = begin | |
self.end = end | |
self.by_epoch = by_epoch | |
assert isinstance(last_step, int) and last_step >= -1 | |
# Initialize valid step count and base values | |
if last_step == -1: | |
for group in optimizer.param_groups: | |
                # If the param has never been scheduled, record the current
                # value as the initial value.
group.setdefault(f'initial_{param_name}', group[param_name]) | |
else: | |
for i, group in enumerate(optimizer.param_groups): | |
if f'initial_{param_name}' not in group: | |
raise KeyError( | |
f"param 'initial_{param_name}' is not specified " | |
'in param_groups[{}] when resuming an optimizer'. | |
format(i)) | |
self.base_values = [ | |
group[f'initial_{param_name}'] for group in optimizer.param_groups | |
] | |
self.last_step = last_step | |
# Following https://github.com/pytorch/pytorch/issues/20124 | |
# We would like to ensure that `scheduler.step()` is called after | |
# `optimizer.step()` | |
def with_counter(method: Callable): | |
if getattr(method, '_with_counter', False): | |
# `optimizer.step()` has already been replaced, return. | |
return method | |
# Keep a weak reference to the optimizer instance to prevent | |
# cyclic references. | |
instance_ref = weakref.ref(method.__self__) # type: ignore | |
# Get the unbound method for the same purpose. | |
func = method.__func__ # type: ignore | |
cls = instance_ref().__class__ # type: ignore | |
del method | |
def wrapper(*args, **kwargs): | |
instance = instance_ref() | |
instance._global_step += 1 | |
wrapped = func.__get__(instance, cls) | |
return wrapped(*args, **kwargs) | |
# Note that the returned function here is no longer a bound method, | |
# so attributes like `__func__` and `__self__` no longer exist. | |
wrapper._with_counter = True # type: ignore | |
return wrapper | |
# add counter to optimizer | |
self.optimizer.step = with_counter(self.optimizer.step) # type: ignore | |
self.optimizer._global_step = -1 # type: ignore | |
self._global_step = -1 | |
self.verbose = verbose | |
self.step() | |
def state_dict(self) -> dict: | |
"""Returns the state of the scheduler as a :class:`dict`. | |
It contains an entry for every variable in self.__dict__ which is not | |
the optimizer. | |
Returns: | |
dict: scheduler state. | |
""" | |
return { | |
key: value | |
for key, value in self.__dict__.items() if key != 'optimizer' | |
} | |
def load_state_dict(self, state_dict: dict): | |
"""Loads the schedulers state. | |
Args: | |
state_dict (dict): scheduler state. Should be an object returned | |
from a call to :meth:`state_dict`. | |
""" | |
self.__dict__.update(state_dict) | |
def get_last_value(self): | |
"""Return the last computed value by current scheduler. | |
Returns: | |
list: A list of the last computed value of the optimizer's | |
``param_group``. | |
""" | |
return self._last_value | |
def _get_value(self): | |
"""Compute value using chainable form of the scheduler.""" | |
raise NotImplementedError | |
def print_value(self, is_verbose: bool, group: int, value: float): | |
"""Display the current parameter value. | |
Args: | |
is_verbose (bool): Whether to print the value. | |
group (int): The index of the current ``param_group``. | |
value (float): The parameter value. | |
""" | |
if is_verbose: | |
print_log( | |
f'Adjusting parameter value of group {group} to {value:.4e}.', | |
logger='current') | |
def step(self): | |
"""Adjusts the parameter value of each parameter group based on the | |
specified schedule.""" | |
# Raise a warning if old pattern is detected | |
# https://github.com/pytorch/pytorch/issues/20124 | |
if self._global_step == 0: | |
if not hasattr(self.optimizer.step, '_with_counter'): | |
warnings.warn( | |
'Seems like `optimizer.step()` has been overridden after ' | |
'parameter value scheduler initialization. Please, make ' | |
'sure to call `optimizer.step()` before ' | |
'`scheduler.step()`. See more details at ' | |
'https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate', # noqa: E501 | |
UserWarning) | |
            # Warn if `scheduler.step()` has already been called twice
            # (counting the call in `__init__`) before any `optimizer.step()`
elif self.optimizer._global_step < 0: | |
warnings.warn( | |
'Detected call of `scheduler.step()` before ' | |
'`optimizer.step()`. In PyTorch 1.1.0 and later, you ' | |
'should call them in the opposite order: ' | |
'`optimizer.step()` before `scheduler.step()`. ' | |
'Failure to do this will result in PyTorch skipping ' | |
'the first value of the parameter value schedule. ' | |
'See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate', # noqa: E501 | |
UserWarning) | |
self._global_step += 1 | |
# Compute parameter value per param group in the effective range | |
if self.begin <= self._global_step < self.end: | |
self.last_step += 1 | |
values = self._get_value() | |
for i, data in enumerate(zip(self.optimizer.param_groups, values)): | |
param_group, value = data | |
param_group[self.param_name] = value | |
self.print_value(self.verbose, i, value) | |
self._last_value = [ | |
group[self.param_name] for group in self.optimizer.param_groups | |
] | |
class StepParamScheduler(_ParamScheduler): | |
"""Decays the parameter value of each parameter group by gamma every | |
step_size epochs. Notice that such decay can happen simultaneously with | |
other changes to the parameter value from outside this scheduler. | |
Args: | |
optimizer (BaseOptimWrapper or Optimizer): Wrapped optimizer. | |
param_name (str): Name of the parameter to be adjusted, such as | |
``lr``, ``momentum``. | |
step_size (int): Period of parameter value decay. | |
gamma (float): Multiplicative factor of parameter value decay. | |
Defaults to 0.1. | |
begin (int): Step at which to start updating the parameters. | |
Defaults to 0. | |
end (int): Step at which to stop updating the parameters. | |
Defaults to INF. | |
last_step (int): The index of last step. Used for resume without | |
state dict. Defaults to -1. | |
by_epoch (bool): Whether the scheduled parameters are updated by | |
epochs. Defaults to True. | |
verbose (bool): Whether to print the value for each update. | |
Defaults to False. | |
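
    Examples:
        An illustrative sketch (the toy optimizer built from ``torch.nn`` and
        ``torch.optim.SGD`` is an assumption, not part of this module):

        >>> import torch
        >>> from torch.optim import SGD
        >>> optimizer = SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
        >>> scheduler = StepParamScheduler(
        ...     optimizer, param_name='lr', step_size=2, gamma=0.5,
        ...     by_epoch=False)
        >>> for _ in range(6):
        ...     optimizer.step()
        ...     scheduler.step()
        >>> # ``lr`` is multiplied by ``gamma`` every ``step_size`` steps:
        >>> # 0.1 -> 0.05 -> 0.025 -> 0.0125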
""" | |
def __init__(self, | |
optimizer: OptimizerType, | |
param_name: str, | |
step_size: int, | |
gamma: float = 0.1, | |
begin: int = 0, | |
end: int = INF, | |
last_step: int = -1, | |
by_epoch: bool = True, | |
verbose: bool = False): | |
self.step_size = step_size | |
self.gamma = gamma | |
super().__init__( | |
optimizer=optimizer, | |
param_name=param_name, | |
begin=begin, | |
end=end, | |
last_step=last_step, | |
by_epoch=by_epoch, | |
verbose=verbose) | |
    @classmethod
    def build_iter_from_epoch(cls,
*args, | |
step_size, | |
begin=0, | |
end=INF, | |
by_epoch=True, | |
epoch_length=None, | |
**kwargs): | |
"""Build an iter-based instance of this scheduler from an epoch-based | |
config.""" | |
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ | |
'be converted to iter-based.' | |
assert epoch_length is not None and epoch_length > 0, \ | |
f'`epoch_length` must be a positive integer, ' \ | |
f'but got {epoch_length}.' | |
by_epoch = False | |
step_size = step_size * epoch_length | |
begin = int(begin * epoch_length) | |
if end != INF: | |
end = int(end * epoch_length) | |
return cls( | |
*args, | |
step_size=step_size, | |
begin=begin, | |
end=end, | |
by_epoch=by_epoch, | |
**kwargs) | |
def _get_value(self): | |
"""Compute value using chainable form of the scheduler.""" | |
if (self.last_step == 0) or (self.last_step % self.step_size != 0): | |
return [ | |
group[self.param_name] for group in self.optimizer.param_groups | |
] | |
return [ | |
group[self.param_name] * self.gamma | |
for group in self.optimizer.param_groups | |
] | |
class MultiStepParamScheduler(_ParamScheduler): | |
"""Decays the specified parameter in each parameter group by gamma once the | |
    number of epochs reaches one of the milestones. Notice that such decay can
happen simultaneously with other changes to the parameter from outside this | |
scheduler. | |
Args: | |
optimizer (BaseOptimWrapper or Optimizer): Wrapped optimizer. | |
param_name (str): Name of the parameter to be adjusted, such as | |
``lr``, ``momentum``. | |
milestones (list): List of epoch indices. Must be increasing. | |
gamma (float): Multiplicative factor of parameter value decay. | |
Defaults to 0.1. | |
begin (int): Step at which to start updating the parameters. | |
Defaults to 0. | |
end (int): Step at which to stop updating the parameters. | |
Defaults to INF. | |
last_step (int): The index of last step. Used for resume without | |
state dict. Defaults to -1. | |
by_epoch (bool): Whether the scheduled parameters are updated by | |
epochs. Defaults to True. | |
verbose (bool): Whether to print the value for each update. | |
Defaults to False. | |
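
    Examples:
        An illustrative sketch (the toy optimizer is an assumption, not part
        of this module):

        >>> import torch
        >>> from torch.optim import SGD
        >>> optimizer = SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
        >>> scheduler = MultiStepParamScheduler(
        ...     optimizer, param_name='lr', milestones=[2, 5], gamma=0.1,
        ...     by_epoch=False)
        >>> for _ in range(6):
        ...     optimizer.step()
        ...     scheduler.step()
        >>> # ``lr`` stays at 0.1 until step 2, drops to 0.01, and drops
        >>> # again to 0.001 at step 5.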
""" | |
def __init__(self, | |
optimizer: OptimizerType, | |
param_name: str, | |
milestones: List[int], | |
gamma: float = 0.1, | |
last_step: int = -1, | |
begin: int = 0, | |
end: int = INF, | |
by_epoch: bool = True, | |
verbose: bool = False): | |
self.milestones = Counter(milestones) | |
self.gamma = gamma | |
super().__init__( | |
optimizer, | |
param_name=param_name, | |
begin=begin, | |
end=end, | |
last_step=last_step, | |
by_epoch=by_epoch, | |
verbose=verbose) | |
    @classmethod
    def build_iter_from_epoch(cls,
*args, | |
milestones, | |
begin=0, | |
end=INF, | |
by_epoch=True, | |
epoch_length=None, | |
**kwargs): | |
"""Build an iter-based instance of this scheduler from an epoch-based | |
config.""" | |
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ | |
'be converted to iter-based.' | |
assert epoch_length is not None and epoch_length > 0, \ | |
f'`epoch_length` must be a positive integer, ' \ | |
f'but got {epoch_length}.' | |
by_epoch = False | |
milestones = [i * epoch_length for i in milestones] | |
begin = int(begin * epoch_length) | |
if end != INF: | |
end = int(end * epoch_length) | |
return cls( | |
*args, | |
milestones=milestones, | |
begin=begin, | |
end=end, | |
by_epoch=by_epoch, | |
**kwargs) | |
def _get_value(self): | |
"""Compute value using chainable form of the scheduler.""" | |
if self.last_step not in self.milestones: | |
return [ | |
group[self.param_name] for group in self.optimizer.param_groups | |
] | |
return [ | |
group[self.param_name] * | |
self.gamma**self.milestones[self.last_step] | |
for group in self.optimizer.param_groups | |
] | |
class ConstantParamScheduler(_ParamScheduler): | |
"""Decays the parameter value of each parameter group by a small constant | |
    factor until the number of epochs reaches a pre-defined milestone: ``end``.
Notice that such decay can happen simultaneously with other changes to the | |
parameter value from outside this scheduler. | |
Args: | |
optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped | |
optimizer. | |
param_name (str): Name of the parameter to be adjusted, such as | |
``lr``, ``momentum``. | |
factor (float): The number we multiply parameter value until the | |
milestone. Defaults to 1./3. | |
begin (int): Step at which to start updating the parameters. | |
Defaults to 0. | |
end (int): Step at which to stop updating the parameters. | |
Defaults to INF. | |
last_step (int): The index of last step. Used for resume without | |
state dict. Defaults to -1. | |
by_epoch (bool): Whether the scheduled parameters are updated by | |
epochs. Defaults to True. | |
verbose (bool): Whether to print the value for each update. | |
Defaults to False. | |
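
    Examples:
        An illustrative sketch (the toy optimizer is an assumption, not part
        of this module):

        >>> import torch
        >>> from torch.optim import SGD
        >>> optimizer = SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
        >>> scheduler = ConstantParamScheduler(
        ...     optimizer, param_name='lr', factor=0.5, begin=0, end=5,
        ...     by_epoch=False)
        >>> for _ in range(6):
        ...     optimizer.step()
        ...     scheduler.step()
        >>> # ``lr`` is scaled by ``factor`` (0.1 -> 0.05) from the start and
        >>> # restored to the base value 0.1 at the end of the constant phase.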
""" | |
def __init__(self, | |
optimizer: OptimizerType, | |
param_name: str, | |
factor: float = 1.0 / 3, | |
begin: int = 0, | |
end: int = INF, | |
last_step: int = -1, | |
by_epoch: bool = True, | |
verbose: bool = False): | |
if factor > 1.0 or factor < 0: | |
raise ValueError( | |
                'Constant multiplicative factor should be between 0 and 1.')
self.factor = factor | |
self.total_iters = end - begin - 1 | |
super().__init__( | |
optimizer, | |
param_name=param_name, | |
begin=begin, | |
end=end, | |
last_step=last_step, | |
by_epoch=by_epoch, | |
verbose=verbose) | |
    @classmethod
    def build_iter_from_epoch(cls,
*args, | |
begin=0, | |
end=INF, | |
by_epoch=True, | |
epoch_length=None, | |
**kwargs): | |
"""Build an iter-based instance of this scheduler from an epoch-based | |
config.""" | |
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ | |
'be converted to iter-based.' | |
assert epoch_length is not None and epoch_length > 0, \ | |
f'`epoch_length` must be a positive integer, ' \ | |
f'but got {epoch_length}.' | |
by_epoch = False | |
begin = int(begin * epoch_length) | |
if end != INF: | |
end = int(end * epoch_length) | |
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) | |
def _get_value(self): | |
"""Compute value using chainable form of the scheduler.""" | |
if self.last_step == 0: | |
return [ | |
group[self.param_name] * self.factor | |
for group in self.optimizer.param_groups | |
] | |
        if self.last_step != self.total_iters:
return [ | |
group[self.param_name] for group in self.optimizer.param_groups | |
] | |
if self.last_step == self.total_iters: | |
return [ | |
group[self.param_name] * (1.0 / self.factor) | |
for group in self.optimizer.param_groups | |
] | |
class ExponentialParamScheduler(_ParamScheduler): | |
"""Decays the parameter value of each parameter group by gamma every epoch. | |
Args: | |
optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped | |
optimizer. | |
param_name (str): Name of the parameter to be adjusted, such as | |
``lr``, ``momentum``. | |
gamma (float): Multiplicative factor of parameter value decay. | |
begin (int): Step at which to start updating the parameters. | |
Defaults to 0. | |
end (int): Step at which to stop updating the parameters. | |
Defaults to INF. | |
last_step (int): The index of last step. Used for resume without | |
state dict. Defaults to -1. | |
by_epoch (bool): Whether the scheduled parameters are updated by | |
epochs. Defaults to True. | |
verbose (bool): Whether to print the value for each update. | |
Defaults to False. | |
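
    Examples:
        An illustrative sketch (the toy optimizer is an assumption, not part
        of this module):

        >>> import torch
        >>> from torch.optim import SGD
        >>> optimizer = SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
        >>> scheduler = ExponentialParamScheduler(
        ...     optimizer, param_name='lr', gamma=0.9, by_epoch=False)
        >>> for _ in range(5):
        ...     optimizer.step()
        ...     scheduler.step()
        >>> # ``lr`` follows 0.1 * 0.9 ** step.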
""" | |
def __init__(self, | |
optimizer: OptimizerType, | |
param_name: str, | |
gamma: float, | |
begin: int = 0, | |
end: int = INF, | |
last_step: int = -1, | |
by_epoch: bool = True, | |
verbose: bool = False): | |
self.gamma = gamma | |
super().__init__( | |
optimizer, | |
param_name=param_name, | |
begin=begin, | |
end=end, | |
last_step=last_step, | |
by_epoch=by_epoch, | |
verbose=verbose) | |
    @classmethod
    def build_iter_from_epoch(cls,
*args, | |
begin=0, | |
end=INF, | |
by_epoch=True, | |
epoch_length=None, | |
**kwargs): | |
"""Build an iter-based instance of this scheduler from an epoch-based | |
config.""" | |
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ | |
'be converted to iter-based.' | |
assert epoch_length is not None and epoch_length > 0, \ | |
f'`epoch_length` must be a positive integer, ' \ | |
f'but got {epoch_length}.' | |
by_epoch = False | |
begin = int(begin * epoch_length) | |
if end != INF: | |
end = int(end * epoch_length) | |
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) | |
def _get_value(self): | |
"""Compute value using chainable form of the scheduler.""" | |
if self.last_step == 0: | |
return [ | |
group[self.param_name] for group in self.optimizer.param_groups | |
] | |
return [ | |
group[self.param_name] * self.gamma | |
for group in self.optimizer.param_groups | |
] | |
class CosineAnnealingParamScheduler(_ParamScheduler): | |
r"""Set the parameter value of each parameter group using a cosine | |
annealing schedule, where :math:`\eta_{max}` is set to the initial value | |
and :math:`T_{cur}` is the number of epochs since the last restart in SGDR: | |
.. math:: | |
\begin{aligned} | |
\eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 | |
+ \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), | |
& T_{cur} \neq (2k+1)T_{max}; \\ | |
\eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) | |
\left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), | |
& T_{cur} = (2k+1)T_{max}. | |
\end{aligned} | |
Notice that because the schedule | |
is defined recursively, the parameter value can be simultaneously modified | |
outside this scheduler by other operators. If the parameter value is set | |
solely by this scheduler, the parameter value at each step becomes: | |
.. math:: | |
\eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + | |
\cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) | |
It has been proposed in | |
`SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this | |
only implements the cosine annealing part of SGDR, and not the restarts. | |
Args: | |
optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped | |
optimizer. | |
param_name (str): Name of the parameter to be adjusted, such as | |
``lr``, ``momentum``. | |
T_max (int, optional): Maximum number of iterations. If not specified, | |
use ``end - begin``. Defaults to None. | |
eta_min (float, optional): Minimum parameter value. Defaults to None. | |
begin (int): Step at which to start updating the parameters. | |
Defaults to 0. | |
end (int): Step at which to stop updating the parameters. | |
Defaults to INF. | |
last_step (int): The index of last step. Used for resume without | |
state dict. Defaults to -1. | |
by_epoch (bool): Whether the scheduled parameters are updated by | |
epochs. Defaults to True. | |
verbose (bool): Whether to print the value for each update. | |
Defaults to False. | |
eta_min_ratio (float, optional): The ratio of the minimum parameter | |
value to the base parameter value. Either `eta_min` or | |
`eta_min_ratio` should be specified. Defaults to None. | |
New in version 0.3.2. | |
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts: | |
https://arxiv.org/abs/1608.03983 | |
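
    Examples:
        An illustrative sketch (the toy optimizer is an assumption, not part
        of this module):

        >>> import torch
        >>> from torch.optim import SGD
        >>> optimizer = SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
        >>> scheduler = CosineAnnealingParamScheduler(
        ...     optimizer, param_name='lr', T_max=8, eta_min=0.001,
        ...     by_epoch=False)
        >>> for _ in range(8):
        ...     optimizer.step()
        ...     scheduler.step()
        >>> # ``lr`` is annealed from 0.1 towards ``eta_min`` following a
        >>> # half cosine over ``T_max`` steps.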
""" | |
def __init__(self, | |
optimizer: Union[Optimizer, BaseOptimWrapper], | |
param_name: str, | |
T_max: Optional[int] = None, | |
eta_min: Optional[float] = None, | |
begin: int = 0, | |
end: int = INF, | |
last_step: int = -1, | |
by_epoch: bool = True, | |
verbose: bool = False, | |
eta_min_ratio: Optional[float] = None): | |
# To preserve backwards compatibility | |
if eta_min is None and eta_min_ratio is None: | |
eta_min = 0. | |
assert (eta_min is None) ^ (eta_min_ratio is None), \ | |
            'Either `eta_min` or `eta_min_ratio` should be specified'
self.T_max = T_max or (end - begin) | |
self.eta_min = eta_min | |
self.eta_min_ratio = eta_min_ratio | |
super().__init__( | |
optimizer, | |
param_name=param_name, | |
begin=begin, | |
end=end, | |
last_step=last_step, | |
by_epoch=by_epoch, | |
verbose=verbose) | |
    @classmethod
    def build_iter_from_epoch(cls,
*args, | |
T_max=None, | |
begin=0, | |
end=INF, | |
by_epoch=True, | |
epoch_length=None, | |
**kwargs): | |
"""Build an iter-based instance of this scheduler from an epoch-based | |
config.""" | |
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ | |
'be converted to iter-based.' | |
assert epoch_length is not None and epoch_length > 0, \ | |
f'`epoch_length` must be a positive integer, ' \ | |
f'but got {epoch_length}.' | |
by_epoch = False | |
if T_max is not None: | |
T_max = T_max * epoch_length | |
begin = int(begin * epoch_length) | |
if end != INF: | |
end = int(end * epoch_length) | |
return cls( | |
*args, | |
T_max=T_max, | |
begin=begin, | |
end=end, | |
by_epoch=by_epoch, | |
**kwargs) | |
def _get_value(self) -> list: | |
"""Compute value using chainable form of the scheduler.""" | |
def _get_eta_min(base_value): | |
if self.eta_min_ratio is None: | |
return self.eta_min | |
return base_value * self.eta_min_ratio | |
if self.last_step == 0: | |
return [ | |
group[self.param_name] for group in self.optimizer.param_groups | |
] | |
elif (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0: | |
return [ | |
group[self.param_name] + | |
(base_value - _get_eta_min(base_value)) * | |
(1 - math.cos(math.pi / self.T_max)) / 2 | |
for base_value, group in zip(self.base_values, | |
self.optimizer.param_groups) | |
] | |
return [(1 + math.cos(math.pi * self.last_step / self.T_max)) / | |
(1 + math.cos(math.pi * (self.last_step - 1) / self.T_max)) * | |
(group[self.param_name] - _get_eta_min(base_value)) + | |
_get_eta_min(base_value) for base_value, group in zip( | |
self.base_values, self.optimizer.param_groups)] | |
class LinearParamScheduler(_ParamScheduler): | |
"""Decays the parameter value of each parameter group by linearly changing | |
small multiplicative factor until the number of epoch reaches a pre-defined | |
milestone: ``end``. | |
Notice that such decay can happen simultaneously with other changes to the | |
parameter value from outside this scheduler. | |
Args: | |
optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped | |
optimizer. | |
param_name (str): Name of the parameter to be adjusted, such as | |
``lr``, ``momentum``. | |
start_factor (float): The number we multiply parameter value in the | |
first epoch. The multiplication factor changes towards end_factor | |
in the following epochs. Defaults to 1./3. | |
end_factor (float): The number we multiply parameter value at the end | |
of linear changing process. Defaults to 1.0. | |
begin (int): Step at which to start updating the parameters. | |
Defaults to 0. | |
end (int): Step at which to stop updating the parameters. | |
Defaults to INF. | |
last_step (int): The index of last step. Used for resume without | |
state dict. Defaults to -1. | |
by_epoch (bool): Whether the scheduled parameters are updated by | |
epochs. Defaults to True. | |
verbose (bool): Whether to print the value for each update. | |
Defaults to False. | |
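
    Examples:
        An illustrative warmup-style sketch (the toy optimizer is an
        assumption, not part of this module):

        >>> import torch
        >>> from torch.optim import SGD
        >>> optimizer = SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
        >>> scheduler = LinearParamScheduler(
        ...     optimizer, param_name='lr', start_factor=0.25, end_factor=1.0,
        ...     begin=0, end=5, by_epoch=False)
        >>> for _ in range(6):
        ...     optimizer.step()
        ...     scheduler.step()
        >>> # ``lr`` ramps linearly from 0.1 * start_factor up to 0.1.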
""" | |
def __init__(self, | |
optimizer: Union[Optimizer, BaseOptimWrapper], | |
param_name: str, | |
start_factor: float = 1.0 / 3, | |
end_factor: float = 1.0, | |
begin: int = 0, | |
end: int = INF, | |
last_step: int = -1, | |
by_epoch: bool = True, | |
verbose: bool = False): | |
if start_factor > 1.0 or start_factor < 0: | |
raise ValueError( | |
                'Starting multiplicative factor should be between 0 and 1.')
if end_factor > 1.0 or end_factor < 0: | |
raise ValueError( | |
                'Ending multiplicative factor should be between 0 and 1.')
self.start_factor = start_factor | |
self.end_factor = end_factor | |
self.total_iters = end - begin - 1 | |
super().__init__( | |
optimizer, | |
param_name=param_name, | |
begin=begin, | |
end=end, | |
last_step=last_step, | |
by_epoch=by_epoch, | |
verbose=verbose) | |
    @classmethod
    def build_iter_from_epoch(cls,
*args, | |
begin=0, | |
end=INF, | |
by_epoch=True, | |
epoch_length=None, | |
**kwargs): | |
"""Build an iter-based instance of this scheduler from an epoch-based | |
config.""" | |
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ | |
'be converted to iter-based.' | |
assert epoch_length is not None and epoch_length > 0, \ | |
f'`epoch_length` must be a positive integer, ' \ | |
f'but got {epoch_length}.' | |
by_epoch = False | |
begin = int(begin * epoch_length) | |
if end != INF: | |
end = int(end * epoch_length) | |
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) | |
def _get_value(self): | |
"""Compute value using chainable form of the scheduler.""" | |
if self.last_step == 0: | |
return [ | |
group[self.param_name] * self.start_factor | |
for group in self.optimizer.param_groups | |
] | |
return [ | |
group[self.param_name] * | |
(1. + (self.end_factor - self.start_factor) / | |
(self.total_iters * self.start_factor + (self.last_step - 1) * | |
(self.end_factor - self.start_factor))) | |
for group in self.optimizer.param_groups | |
] | |
class PolyParamScheduler(_ParamScheduler): | |
"""Decays the parameter value of each parameter group in a polynomial decay | |
scheme. | |
Notice that such decay can happen simultaneously with other changes to the | |
parameter value from outside this scheduler. | |
Args: | |
optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped | |
optimizer. | |
param_name (str): Name of the parameter to be adjusted, such as | |
``lr``, ``momentum``. | |
eta_min (float): Minimum parameter value at the end of scheduling. | |
Defaults to 0. | |
power (float): The power of the polynomial. Defaults to 1.0. | |
begin (int): Step at which to start updating the parameters. | |
Defaults to 0. | |
end (int): Step at which to stop updating the parameters. | |
Defaults to INF. | |
last_step (int): The index of last step. Used for resume without | |
state dict. Defaults to -1. | |
by_epoch (bool): Whether the scheduled parameters are updated by | |
epochs. Defaults to True. | |
verbose (bool): Whether to print the value for each update. | |
Defaults to False. | |
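
    Examples:
        An illustrative sketch (the toy optimizer is an assumption, not part
        of this module):

        >>> import torch
        >>> from torch.optim import SGD
        >>> optimizer = SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
        >>> scheduler = PolyParamScheduler(
        ...     optimizer, param_name='lr', eta_min=0.001, power=2.0,
        ...     begin=0, end=10, by_epoch=False)
        >>> for _ in range(10):
        ...     optimizer.step()
        ...     scheduler.step()
        >>> # ``lr`` decays polynomially from 0.1 towards ``eta_min`` over the
        >>> # scheduling interval [begin, end).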
""" | |
def __init__(self, | |
optimizer: Union[Optimizer, BaseOptimWrapper], | |
param_name: str, | |
eta_min: float = 0, | |
power: float = 1.0, | |
begin: int = 0, | |
end: int = INF, | |
last_step: int = -1, | |
by_epoch: bool = True, | |
verbose: bool = False): | |
self.eta_min = eta_min | |
self.power = power | |
self.total_iters = end - begin - 1 | |
super().__init__( | |
optimizer, | |
param_name=param_name, | |
begin=begin, | |
end=end, | |
last_step=last_step, | |
by_epoch=by_epoch, | |
verbose=verbose) | |
    @classmethod
    def build_iter_from_epoch(cls,
*args, | |
begin=0, | |
end=INF, | |
by_epoch=True, | |
epoch_length=None, | |
**kwargs): | |
"""Build an iter-based instance of this scheduler from an epoch-based | |
config.""" | |
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ | |
'be converted to iter-based.' | |
assert epoch_length is not None and epoch_length > 0, \ | |
f'`epoch_length` must be a positive integer, ' \ | |
f'but got {epoch_length}.' | |
by_epoch = False | |
begin = int(begin * epoch_length) | |
if end != INF: | |
end = int(end * epoch_length) | |
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) | |
def _get_value(self): | |
"""Compute value using chainable form of the scheduler.""" | |
if self.last_step == 0: | |
return [ | |
group[self.param_name] for group in self.optimizer.param_groups | |
] | |
return [(group[self.param_name] - self.eta_min) * | |
(1 - 1 / (self.total_iters - self.last_step + 1))**self.power + | |
self.eta_min for group in self.optimizer.param_groups] | |
class OneCycleParamScheduler(_ParamScheduler): | |
r"""Sets the parameters of each parameter group according to the | |
1cycle learning rate policy. The 1cycle policy anneals the learning | |
rate from an initial learning rate to some maximum learning rate and then | |
from that maximum learning rate to some minimum learning rate much lower | |
than the initial learning rate. | |
This policy was initially described in the paper `Super-Convergence: | |
Very Fast Training of Neural Networks Using Large Learning Rates`_. | |
The 1cycle learning rate policy changes the learning rate after every | |
batch. `step` should be called after a batch has been used for training. | |
This scheduler is not chainable. | |
Note also that the total number of steps in the cycle can be determined in | |
one of two ways (listed in order of precedence): | |
#. A value for total_steps is explicitly provided. | |
    #. If total_steps is not defined, the ``begin`` and ``end`` of this
       scheduler are used instead. In this case, the number of total steps is
       inferred by ``total_steps = end - begin``.
The default behaviour of this scheduler follows the fastai implementation | |
of 1cycle, which claims that "unpublished work has shown even better | |
results by using only two phases". To mimic the behaviour of the original | |
paper instead, set ``three_phase=True``. | |
Args: | |
optimizer (Optimizer): Wrapped optimizer. | |
param_name (str): Name of the parameter to be adjusted, such as | |
``lr``, ``momentum``. | |
eta_max (float or list): Upper parameter value boundaries in the cycle | |
for each parameter group. | |
total_steps (int): The total number of steps in the cycle. Note that | |
if a value is not provided here, then it will be equal to | |
``end - begin``. Defaults to None | |
pct_start (float): The percentage of the cycle (in number of steps) | |
spent increasing the learning rate. | |
Defaults to 0.3 | |
anneal_strategy (str): {'cos', 'linear'} | |
Specifies the annealing strategy: "cos" for cosine annealing, | |
"linear" for linear annealing. | |
Defaults to 'cos' | |
div_factor (float): Determines the initial learning rate via | |
initial_param = eta_max/div_factor | |
Defaults to 25 | |
final_div_factor (float): Determines the minimum learning rate via | |
eta_min = initial_param/final_div_factor | |
Defaults to 1e4 | |
three_phase (bool): If ``True``, use a third phase of the schedule to | |
annihilate the learning rate according to 'final_div_factor' | |
instead of modifying the second phase (the first two phases will be | |
symmetrical about the step indicated by 'pct_start'). | |
last_step (int): The index of last step. Used for resume without | |
state dict. Defaults to -1. | |
by_epoch (bool): Whether the scheduled parameters are updated by | |
epochs. Defaults to True. | |
verbose (bool): Whether to print the value for each update. | |
Defaults to False. | |
.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: | |
https://arxiv.org/abs/1708.07120 | |
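
    Examples:
        An illustrative sketch (the toy optimizer is an assumption, not part
        of this module):

        >>> import torch
        >>> from torch.optim import SGD
        >>> optimizer = SGD(torch.nn.Linear(2, 2).parameters(), lr=0.01)
        >>> scheduler = OneCycleParamScheduler(
        ...     optimizer, param_name='lr', eta_max=0.1, total_steps=100,
        ...     pct_start=0.3, by_epoch=False)
        >>> for _ in range(100):
        ...     optimizer.step()
        ...     scheduler.step()
        >>> # ``lr`` starts at eta_max / div_factor, rises to ``eta_max`` over
        >>> # the first 30% of the cycle, then anneals down to
        >>> # eta_max / div_factor / final_div_factor.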
""" # noqa E501 | |
def __init__(self, | |
optimizer: Union[Optimizer, BaseOptimWrapper], | |
param_name: str, | |
eta_max: float = 0, | |
total_steps: Optional[int] = None, | |
pct_start: float = 0.3, | |
anneal_strategy: str = 'cos', | |
div_factor: float = 25., | |
final_div_factor: float = 1e4, | |
three_phase: bool = False, | |
begin: int = 0, | |
end: int = INF, | |
last_step: int = -1, | |
by_epoch: bool = True, | |
verbose: bool = False): | |
assert param_name == 'lr', ('OneCycle only works for learning rate ' | |
                                    'updating, but got param_name as '
f'{param_name}') | |
self.eta_max = eta_max | |
self.div_factor = div_factor | |
self.final_div_factor = final_div_factor | |
# Validate total_steps | |
if total_steps is not None: | |
if total_steps <= 0 or not isinstance(total_steps, int): | |
raise ValueError('Expected positive integer total_steps, ' | |
f'but got {total_steps}') | |
self.total_steps = total_steps | |
else: | |
            self.total_steps = end - begin
# Validate pct_start | |
if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): | |
raise ValueError('Expected float between 0 and 1 pct_start, ' | |
f'but got {pct_start}') | |
# Validate anneal_strategy | |
if anneal_strategy not in ['cos', 'linear']: | |
raise ValueError( | |
'anneal_strategy must by one of "cos" or "linear", ' | |
f'instead got {anneal_strategy}') | |
elif anneal_strategy == 'cos': | |
self.anneal_func = self._annealing_cos | |
elif anneal_strategy == 'linear': | |
self.anneal_func = self._annealing_linear | |
if three_phase: | |
self._schedule_phases = [ | |
{ | |
'end_step': float(pct_start * self.total_steps) - 1, | |
f'start_{param_name}': f'initial_{param_name}', | |
f'end_{param_name}': f'max_{param_name}' | |
}, | |
{ | |
'end_step': float(2 * pct_start * self.total_steps) - 2, | |
f'start_{param_name}': f'max_{param_name}', | |
f'end_{param_name}': f'initial_{param_name}' | |
}, | |
{ | |
'end_step': self.total_steps - 1, | |
f'start_{param_name}': f'initial_{param_name}', | |
f'end_{param_name}': f'min_{param_name}' | |
}, | |
] | |
else: | |
self._schedule_phases = [ | |
{ | |
'end_step': float(pct_start * self.total_steps) - 1, | |
f'start_{param_name}': f'initial_{param_name}', | |
f'end_{param_name}': f'max_{param_name}' | |
}, | |
{ | |
'end_step': self.total_steps - 1, | |
f'start_{param_name}': f'max_{param_name}', | |
f'end_{param_name}': f'min_{param_name}' | |
}, | |
] | |
# Initialize parameters | |
max_values = self._format_param(f'max_{param_name}', optimizer, | |
eta_max) | |
if last_step == -1: | |
for idx, group in enumerate(optimizer.param_groups): | |
group[f'initial_{param_name}'] = max_values[idx] / div_factor | |
group[f'max_{param_name}'] = max_values[idx] | |
group[f'min_{param_name}'] = \ | |
group[f'initial_{param_name}'] / final_div_factor | |
super().__init__( | |
optimizer=optimizer, | |
param_name=param_name, | |
begin=begin, | |
end=end, | |
last_step=last_step, | |
by_epoch=by_epoch, | |
verbose=verbose) | |
def _format_param(self, name, optimizer, param): | |
"""Return correctly formatted lr/momentum for each param group.""" | |
if isinstance(param, (list, tuple)): | |
if len(param) != len(optimizer.param_groups): | |
raise ValueError( | |
f'expected {len(optimizer.param_groups)} values ' | |
f'for {name}, got {len(param)}') | |
return param | |
else: | |
return [param] * len(optimizer.param_groups) | |
    def _annealing_cos(self, start, end, pct):
"""Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0.""" | |
cos_out = math.cos(math.pi * pct) + 1 | |
return end + (start - end) / 2.0 * cos_out | |
    def _annealing_linear(self, start, end, pct):
"""Linearly anneal from `start` to `end` as pct goes from 0.0 to | |
1.0.""" | |
return (end - start) * pct + start | |
    @classmethod
    def build_iter_from_epoch(cls,
*args, | |
begin=0, | |
end=INF, | |
total_steps=None, | |
by_epoch=True, | |
epoch_length=None, | |
**kwargs): | |
"""Build an iter-based instance of this scheduler from an epoch-based | |
config.""" | |
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ | |
'be converted to iter-based.' | |
assert epoch_length is not None and epoch_length > 0, \ | |
f'`epoch_length` must be a positive integer, ' \ | |
f'but got {epoch_length}.' | |
by_epoch = False | |
begin = int(begin * epoch_length) | |
if end != INF: | |
end = int(end * epoch_length) | |
if total_steps is not None: | |
total_steps = total_steps * epoch_length | |
return cls( | |
*args, | |
begin=begin, | |
end=end, | |
total_steps=total_steps, | |
by_epoch=by_epoch, | |
**kwargs) | |
def _get_value(self): | |
"""Compute value using chainable form of the scheduler.""" | |
params = [] | |
step_num = self.last_step | |
if step_num > self.total_steps: | |
raise ValueError( | |
f'Tried to step {step_num + 1} times. ' | |
f'The specified number of total steps is {self.total_steps}') | |
for group in self.optimizer.param_groups: | |
start_step = 0 | |
for i, phase in enumerate(self._schedule_phases): | |
end_step = phase['end_step'] | |
if step_num <= end_step or i == len(self._schedule_phases) - 1: | |
pct = (step_num - start_step) / (end_step - start_step) | |
computed_param = self.anneal_func( | |
group[phase['start_' + self.param_name]], | |
group[phase['end_' + self.param_name]], pct) | |
break | |
start_step = phase['end_step'] | |
params.append(computed_param) | |
return params | |
class CosineRestartParamScheduler(_ParamScheduler): | |
"""Sets the parameters of each parameter group according to the cosine | |
annealing with restarts scheme. The cosine restart policy anneals the | |
parameter from the initial value to `eta_min` with a cosine annealing | |
schedule and then restarts another period from the maximum value multiplied | |
with `restart_weight`. | |
Args: | |
optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped | |
optimizer. | |
param_name (str): Name of the parameter to be adjusted, such as | |
``lr``, ``momentum``. | |
        periods (list[int]): Periods for each cosine annealing cycle.
restart_weights (list[float]): Restart weights at each | |
restart iteration. Defaults to [1]. | |
eta_min (float, optional): Minimum parameter value at the end of | |
scheduling. Defaults to None. | |
eta_min_ratio (float, optional): The ratio of minimum parameter value | |
to the base parameter value. Either `eta_min` or `eta_min_ratio` | |
should be specified. Defaults to None. | |
begin (int): Step at which to start updating the parameters. | |
Defaults to 0. | |
end (int): Step at which to stop updating the parameters. | |
Defaults to INF. | |
last_step (int): The index of last step. Used for resume without | |
state dict. Defaults to -1. | |
by_epoch (bool): Whether the scheduled parameters are updated by | |
epochs. Defaults to True. | |
verbose (bool): Whether to print the value for each update. | |
Defaults to False. | |
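
    Examples:
        An illustrative sketch (the toy optimizer is an assumption, not part
        of this module):

        >>> import torch
        >>> from torch.optim import SGD
        >>> optimizer = SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
        >>> scheduler = CosineRestartParamScheduler(
        ...     optimizer, param_name='lr', periods=[4, 8],
        ...     restart_weights=[1.0, 0.5], eta_min=0.0, by_epoch=False)
        >>> for _ in range(12):
        ...     optimizer.step()
        ...     scheduler.step()
        >>> # ``lr`` anneals from 0.1 towards ``eta_min`` during the first
        >>> # period of 4 steps, then restarts at 0.1 * 0.5 for the second
        >>> # period of 8 steps.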
""" | |
def __init__(self, | |
optimizer: Union[Optimizer, BaseOptimWrapper], | |
param_name: str, | |
periods: List[int], | |
restart_weights: Sequence[float] = (1, ), | |
eta_min: Optional[float] = None, | |
eta_min_ratio: Optional[float] = None, | |
begin: int = 0, | |
end: int = INF, | |
last_step: int = -1, | |
by_epoch: bool = True, | |
verbose: bool = False): | |
assert (eta_min is None) ^ (eta_min_ratio is None) | |
self.periods = periods | |
self.eta_min = eta_min | |
self.eta_min_ratio = eta_min_ratio | |
self.restart_weights = restart_weights | |
assert (len(self.periods) == len(self.restart_weights) | |
), 'periods and restart_weights should have the same length.' | |
self.cumulative_periods = [ | |
sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) | |
] | |
super().__init__( | |
optimizer, | |
param_name=param_name, | |
begin=begin, | |
end=end, | |
last_step=last_step, | |
by_epoch=by_epoch, | |
verbose=verbose) | |
    @classmethod
    def build_iter_from_epoch(cls,
*args, | |
periods, | |
begin=0, | |
end=INF, | |
by_epoch=True, | |
epoch_length=None, | |
**kwargs): | |
"""Build an iter-based instance of this scheduler from an epoch-based | |
config.""" | |
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ | |
'be converted to iter-based.' | |
assert epoch_length is not None and epoch_length > 0, \ | |
f'`epoch_length` must be a positive integer, ' \ | |
f'but got {epoch_length}.' | |
periods = [p * epoch_length for p in periods] | |
by_epoch = False | |
begin = int(begin * epoch_length) | |
if end != INF: | |
end = int(end * epoch_length) | |
return cls( | |
*args, | |
periods=periods, | |
begin=begin, | |
end=end, | |
by_epoch=by_epoch, | |
**kwargs) | |
def _get_value(self): | |
"""Compute value using chainable form of the scheduler.""" | |
idx = self.get_position_from_periods(self.last_step, | |
self.cumulative_periods) | |
# if current step is not in the periods, return origin parameters | |
if idx is None: | |
return [ | |
group[self.param_name] for group in self.optimizer.param_groups | |
] | |
current_weight = self.restart_weights[idx] | |
nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1] | |
current_periods = self.periods[idx] | |
step = self.last_step - nearest_restart | |
values = [] | |
for base_value, group in zip(self.base_values, | |
self.optimizer.param_groups): | |
eta_max = base_value * current_weight | |
if self.eta_min_ratio is None: | |
eta_min = self.eta_min | |
else: | |
eta_min = base_value * self.eta_min_ratio | |
if step == 0: | |
values.append(eta_max) | |
else: | |
values.append( | |
(1 + math.cos(math.pi * step / current_periods)) / | |
(1 + math.cos(math.pi * (step - 1) / current_periods)) * | |
(group[self.param_name] - eta_min) + eta_min) | |
return values | |
    @staticmethod
    def get_position_from_periods(
iteration: int, cumulative_periods: List[int]) -> Optional[int]: | |
"""Get the position from a period list. | |
It will return the index of the right-closest number in the period | |
list. | |
For example, the cumulative_periods = [100, 200, 300, 400], | |
if iteration == 50, return 0; | |
if iteration == 210, return 2; | |
if iteration == 300, return 3. | |
Args: | |
iteration (int): Current iteration. | |
cumulative_periods (list[int]): Cumulative period list. | |
Returns: | |
Optional[int]: The position of the right-closest number in the | |
period list. If not in the period, return None. | |
""" | |
for i, period in enumerate(cumulative_periods): | |
if iteration < period: | |
return i | |
return None | |
class ReduceOnPlateauParamScheduler(_ParamScheduler): | |
"""Reduce the parameters of each parameter group when a metric has stopped | |
improving. Models often benefit from reducing the parameters by a factor of | |
2-10 once learning stagnates. This scheduler reads a metrics quantity and | |
if no improvement is seen for a ``patience`` number of epochs, the | |
parameters are reduced. | |
The implementation is motivated by `PyTorch ReduceLROnPlateau`_. | |
Args: | |
optimizer (Optimizer or BaseOptimWrapper): optimizer or Wrapped | |
optimizer. | |
param_name (str): Name of the parameter to be adjusted, such as | |
``lr``, ``momentum``. | |
monitor (str): The name of the metric to measure whether | |
the performance of the model is improved. | |
rule (str): One of `less`, `greater`. In `less` rule, parameters will | |
be reduced when the quantity monitored has stopped | |
decreasing; in `greater` rule it will be reduced when the | |
quantity monitored has stopped increasing. Defaults to 'less'. | |
            ``rule`` here corresponds to ``mode`` in PyTorch's
            ``ReduceLROnPlateau``.
factor (float): Factor by which the parameters will be | |
reduced. new_param = param * factor. Defaults to 0.1. | |
patience (int): Number of epochs with no improvement after | |
which parameters will be reduced. For example, if | |
``patience = 2``, then we will ignore the first 2 epochs | |
with no improvement, and will only decrease the parameters after | |
the 3rd epoch if the monitor value still hasn't improved then. | |
Defaults to 10. | |
threshold (float): Threshold for measuring the new optimum, | |
to only focus on significant changes. Defaults to 1e-4. | |
threshold_rule (str): One of `rel`, `abs`. In `rel` rule, | |
dynamic_threshold = best * ( 1 + threshold ) in 'greater' | |
rule or best * ( 1 - threshold ) in `less` rule. | |
In `abs` rule, dynamic_threshold = best + threshold in | |
`greater` rule or best - threshold in `less` rule. | |
Defaults to 'rel'. | |
cooldown (int): Number of epochs to wait before resuming | |
normal operation after parameters have been reduced. Defaults to 0. | |
min_value (float or list[float]): A scalar or a sequence of scalars. | |
A lower bound on the parameters of each parameter group | |
            respectively. Defaults to 0.
eps (float): Minimal decay applied to parameters. If the difference | |
between new and old parameters are smaller than eps, the update is | |
ignored. Defaults to 1e-8. | |
begin (int): Step at which to start triggering the scheduler | |
to monitor in val within the interval calculated | |
according to epoch of training. Defaults to 0. | |
end (int): Step at which to stop triggering the scheduler | |
to monitor in val within the interval calculated | |
according to epoch of training. Defaults to INF. | |
last_step (int): The index of last step. Used for resume without | |
state dict. Defaults to -1. | |
by_epoch (bool): Whether the scheduled parameters are updated by | |
epochs. Defaults to True. | |
verbose (bool): Whether to print the value for each update. | |
Defaults to False. | |
.. _PyTorch ReduceLROnPlateau: | |
https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py | |
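
    Examples:
        An illustrative sketch (the toy optimizer and the constant validation
        loss are assumptions, not part of this module):

        >>> import torch
        >>> from torch.optim import SGD
        >>> optimizer = SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
        >>> scheduler = ReduceOnPlateauParamScheduler(
        ...     optimizer, param_name='lr', monitor='loss', rule='less',
        ...     factor=0.5, patience=2)
        >>> for epoch in range(6):
        ...     optimizer.step()
        ...     scheduler.step(metrics={'loss': 1.0})
        >>> # The monitored loss never improves, so after ``patience`` epochs
        >>> # with no improvement ``lr`` is reduced from 0.1 to 0.05.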
""" | |
need_val_args = True | |
def __init__(self, | |
optimizer: OptimizerType, | |
param_name: str, | |
monitor: str = 'loss', | |
rule: str = 'less', | |
factor: float = 0.1, | |
patience: int = 10, | |
threshold: float = 1e-4, | |
threshold_rule: str = 'rel', | |
cooldown: int = 0, | |
min_value: Union[float, Sequence[float]] = 0., | |
eps: float = 1e-8, | |
begin: int = 0, | |
end: int = INF, | |
last_step: int = -1, | |
by_epoch: bool = True, | |
verbose: bool = False): | |
# Attach optimizer | |
if not isinstance(optimizer, (Optimizer, BaseOptimWrapper)): | |
            raise TypeError('``optimizer`` should be an Optimizer, '
                            'but got {}'.format(type(optimizer).__name__))
self.optimizer = optimizer | |
self.param_name = param_name | |
if end <= begin: | |
raise ValueError('end should be larger than begin, but got' | |
' begin={}, end={}'.format(begin, end)) | |
self.begin = begin | |
self.end = end | |
assert by_epoch, \ | |
f'Now {type(self).__name__} only support by_epoch=True' | |
self.by_epoch = by_epoch | |
assert isinstance(last_step, int) and last_step >= -1 | |
# Initialize valid step count and base values | |
if last_step == -1: | |
for group in optimizer.param_groups: | |
                # If the param has never been scheduled, record the current
                # value as the initial value.
group.setdefault(f'initial_{param_name}', group[param_name]) | |
else: | |
for i, group in enumerate(optimizer.param_groups): | |
if f'initial_{param_name}' not in group: | |
raise KeyError( | |
f"param 'initial_{param_name}' is not specified " | |
'in param_groups[{}] when resuming an optimizer'. | |
format(i)) | |
self.last_step = last_step | |
self._global_step = 0 | |
self.verbose = verbose | |
if factor >= 1.0: | |
raise ValueError('Factor should be < 1.0.') | |
self.factor = factor | |
# This code snippet handles compatibility with the optimizer wrapper. | |
# The optimizer wrapper includes an additional parameter to record the | |
# base learning rate (lr) which is not affected by the paramwise_cfg. | |
# By retrieving the base lr, we can obtain the actual base lr that | |
# reflects the learning progress. | |
if isinstance(optimizer, BaseOptimWrapper): | |
raw_optimizer = optimizer.optimizer | |
else: | |
raw_optimizer = optimizer | |
if isinstance(min_value, (list, tuple)): | |
if len(min_value) != len(raw_optimizer.param_groups): | |
                raise ValueError(
                    'expected {} values for `min_value`, but got {}'.format(
                        len(raw_optimizer.param_groups), len(min_value)))
self.min_values = list(min_value) | |
# Consider the `min_value` of the last param_groups | |
# as the base setting. And we only add this value when | |
# the optimizer is OptimWrapper. | |
if isinstance(optimizer, BaseOptimWrapper) and \ | |
optimizer.base_param_settings is not None: # type: ignore | |
self.min_values.append(self.min_values[-1]) | |
else: | |
self.min_values = [min_value] * len( # type: ignore | |
optimizer.param_groups) | |
self.patience = patience | |
self.cooldown = cooldown | |
self.cooldown_counter = 0 | |
        self.rule_worse = None  # the worst value for the chosen rule
self.best = None | |
self.num_bad_epochs = 0 | |
self.eps = eps | |
self.monitor = monitor | |
self._init_is_better( | |
rule=rule, threshold=threshold, threshold_rule=threshold_rule) | |
self._reset() | |
# remove call self.step() and init self._global_step = 0 | |
self._last_value = [ | |
group[self.param_name] for group in self.optimizer.param_groups | |
] | |
def step(self, metrics=None): | |
"""Adjusts the parameter value of each parameter group based on the | |
specified schedule. | |
Args: | |
metrics (Dict[str, float], optional): Evaluation results of all | |
metrics on validation dataset. The keys are the names of the | |
metrics, and the values are corresponding results. | |
Defaults to None. | |
""" | |
if metrics is None: | |
# only to count self._global_step | |
self._global_step += 1 | |
return | |
if not isinstance(metrics, dict): | |
raise TypeError('metrics type should be dict,' | |
f' but got type {type(metrics)}') | |
# Compute parameter value per param group in the effective range | |
if self.begin <= self._global_step < self.end: | |
self.last_step += 1 | |
# convert `metric` to float, in case it's a zero-dim Tensor | |
metric = metrics.get(self.monitor, None) | |
if metric is not None: | |
if self._is_better(metric, self.best): | |
self.best = metric | |
self.num_bad_epochs = 0 | |
else: | |
self.num_bad_epochs += 1 | |
if self._in_cooldown(): | |
self.cooldown_counter -= 1 | |
self.num_bad_epochs = 0 # ignore bad epochs in cooldown | |
if self.num_bad_epochs > self.patience: | |
values = self._get_value() | |
for i, data in enumerate( | |
zip(self.optimizer.param_groups, values)): | |
param_group, value = data | |
if param_group[self.param_name] - value > self.eps: | |
param_group[self.param_name] = value | |
self.print_value(self.verbose, i, value) | |
self.cooldown_counter = self.cooldown | |
self.num_bad_epochs = 0 | |
else: | |
                raise KeyError(
                    f'Expected the monitored metric "{self.monitor}" to be '
                    f'one of {list(metrics.keys())}, but it was not found.')
self._last_value = [ | |
group[self.param_name] for group in self.optimizer.param_groups | |
] | |
def print_value(self, is_verbose: bool, group: int, value: float) -> None: | |
"""Display the current parameter value. | |
Args: | |
is_verbose (bool): Whether to print the value. | |
group (int): The index of the current ``param_group``. | |
value (float): The parameter value. | |
""" | |
if is_verbose: | |
step_name = 'epoch' if self.by_epoch else 'iter' | |
print_log( | |
f'Adjusting parameter value of group {group} to {value:.4e} ' | |
f'in {step_name} {self.last_step}.', | |
logger='current') | |
def _get_value(self): | |
"""Compute value using chainable form of the scheduler.""" | |
values = [ | |
float(group[self.param_name]) * self.factor | |
for group in self.optimizer.param_groups | |
] | |
return [max(v, min_v) for v, min_v in zip(values, self.min_values)] | |
def _in_cooldown(self): | |
"""Judge whether it is in cooldown.""" | |
return self.cooldown_counter > 0 | |
def _is_better(self, a, best): | |
"""Judge whether the monitor value is better.""" | |
if self.rule == 'less' and self.threshold_rule == 'rel': | |
rel_epsilon = 1. - self.threshold | |
return a < best * rel_epsilon | |
elif self.rule == 'less' and self.threshold_rule == 'abs': | |
return a < best - self.threshold | |
elif self.rule == 'greater' and self.threshold_rule == 'rel': | |
rel_epsilon = self.threshold + 1. | |
return a > best * rel_epsilon | |
else: # rule == 'greater' and epsilon_mode == 'abs': | |
return a > best + self.threshold | |
def _init_is_better(self, rule, threshold, threshold_rule): | |
"""Initialize rule and its associated values.""" | |
if threshold < 0: | |
raise ValueError(f'threshold {threshold} should be >= 0.') | |
if rule not in {'less', 'greater'}: | |
raise ValueError(f'mode {rule} is unknown!') | |
if threshold_rule not in {'rel', 'abs'}: | |
raise ValueError(f'threshold mode {threshold_rule}' | |
' is unknown!') | |
if rule == 'less': | |
self.rule_worse = INF | |
else: # rule == 'greater': | |
self.rule_worse = -INF | |
self.rule = rule | |
self.threshold = threshold | |
self.threshold_rule = threshold_rule | |
def _reset(self): | |
"""Resets num_bad_epochs counter and cooldown counter.""" | |
self.best = self.rule_worse | |
self.cooldown_counter = 0 | |
self.num_bad_epochs = 0 | |