# Spaces: Running on Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Optional, Union

from mmengine.optim import _ParamScheduler
from mmengine.registry import HOOKS
from mmengine.utils import is_list_of

from .hook import Hook
# Annotation used for the ``data_batch`` argument of hook callbacks: a dict,
# tuple or list, or ``None`` when no batch is passed.
DATA_BATCH = Optional[Union[dict, tuple, list]]
class ParamSchedulerHook(Hook):
    """A hook to update some hyper-parameters in optimizer, e.g., learning rate
    and momentum.

    ``runner.param_schedulers`` may be ``None`` (every callback is then a
    no-op), a list of schedulers, or a dict whose values are lists of
    schedulers. Any other type raises ``TypeError``.
    """

    priority = 'LOW'

    @staticmethod
    def _scheduler_groups(runner) -> list:
        """Normalize ``runner.param_schedulers`` into a list of groups.

        Args:
            runner (Runner): The runner holding ``param_schedulers``.

        Returns:
            list: ``[param_schedulers]`` when it is a list, or the dict
            values (in dict order) when it is a dict.

        Raises:
            TypeError: If ``runner.param_schedulers`` is neither a list nor
                a dict.
        """
        param_schedulers = runner.param_schedulers
        if isinstance(param_schedulers, list):
            return [param_schedulers]
        if isinstance(param_schedulers, dict):
            return list(param_schedulers.values())
        raise TypeError(
            'runner.param_schedulers should be list of ParamScheduler or '
            'a dict containing list of ParamScheduler, '
            f'but got {runner.param_schedulers}')

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Call step function for each scheduler after each training iteration.

        Only schedulers whose ``by_epoch`` is false are stepped here.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
                In order to keep this interface consistent with other hooks,
                we keep ``data_batch`` here.
            outputs (dict, optional): Outputs from model.
                In order to keep this interface consistent with other hooks, we
                keep ``outputs`` here.

        Raises:
            TypeError: If ``runner.param_schedulers`` is neither ``None``,
                a list, nor a dict.
        """
        if runner.param_schedulers is None:
            return

        for group in self._scheduler_groups(runner):
            # Each group (the list itself or a dict value) must be a list.
            assert isinstance(group, list)
            for scheduler in group:
                if not scheduler.by_epoch:
                    scheduler.step()

    def after_train_epoch(self, runner) -> None:
        """Call step function for each scheduler after each training epoch.

        Only schedulers whose ``by_epoch`` is true are stepped here.

        Args:
            runner (Runner): The runner of the training process.

        Raises:
            TypeError: If ``runner.param_schedulers`` is neither ``None``,
                a list, nor a dict.
        """
        if runner.param_schedulers is None:
            return

        for group in self._scheduler_groups(runner):
            # Each group (the list itself or a dict value) must be a list.
            assert isinstance(group, list)
            for scheduler in group:
                if scheduler.by_epoch:
                    scheduler.step()

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """Call step function for each scheduler which has attribute
        ``need_val_args`` after each validation epoch.

        A scheduler is stepped with ``metrics`` only when its ``by_epoch`` is
        true and its ``need_val_args`` attribute exists and is truthy.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. The keys are the names of the
                metrics, and the values are corresponding results.

        Raises:
            TypeError: If ``metrics`` is given and
                ``runner.param_schedulers`` is neither ``None``, a list,
                nor a dict.

        Note:
            if ``runner.param_schedulers`` is not built before,
            the hook ``after_val_epoch`` will be skipped.
        """
        if runner.param_schedulers is None:
            return

        # avoid counting scheduler._global_step
        # it has counted in after_train_* hook
        if metrics is None:
            return

        for group in self._scheduler_groups(runner):
            # Skip groups that are not a list of built ``_ParamScheduler``
            # objects instead of failing.
            if not is_list_of(group, _ParamScheduler):
                continue
            for scheduler in group:
                if (scheduler.by_epoch
                        and getattr(scheduler, 'need_val_args', False)):
                    scheduler.step(metrics)