File size: 7,582 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
from .default_helper import deep_merge_dicts
from easydict import EasyDict
class Scheduler(object):
    """
    Overview:
        Update learning parameters when the trueskill metrics has stopped improving.
        For example, models often benefit from reducing the entropy weight once the learning
        process stagnates. This scheduler reads a metrics quantity and, if no improvement is
        seen for a 'patience' number of epochs, the corresponding parameter is increased or
        decreased, which is decided by the 'schedule_mode'.
    Arguments:
        - schedule_flag (:obj:`bool`): Indicates whether to use scheduler in training pipeline.
            Default: False
        - schedule_mode (:obj:`str`): One of 'reduce', 'add', 'multi', 'div'. The schedule_mode
            decides the way of updating the parameters. Default: 'reduce'.
        - factor (:obj:`float`) : Amount (greater than 0) by which the parameter will be
            increased/decreased. Default: 0.05
        - change_range (:obj:`list`): Indicates the minimum and maximum value
            the parameter can reach respectively. Default: [-1, 1]
        - threshold (:obj:`float`): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        - optimize_mode (:obj:`str`): One of 'min', 'max', which indicates the sign of
            optimization objective. Dynamic_threshold = last_metrics + threshold in `max`
            mode or last_metrics - threshold in `min` mode. Default: 'min'
        - patience (:obj:`int`): Number of epochs with no improvement after which
            the parameter will be updated. For example, if `patience = 2`, then we
            will ignore the first 2 epochs with no improvement, and will only update
            the parameter after the 3rd epoch if the metrics still hasn't improved then.
            Default: 10.
        - cooldown (:obj:`int`): Number of epochs to wait before resuming
            normal operation after the parameter has been updated. Default: 0.
    Interfaces:
        __init__, update_param, step
    Property:
        in_cooldown, is_better
    """

    # Default config; the caller merges user overrides into this before __init__.
    config = dict(
        schedule_flag=False,
        schedule_mode='reduce',
        factor=0.05,
        change_range=[-1, 1],
        threshold=1e-4,
        optimize_mode='min',
        patience=10,
        cooldown=0,
    )

    def __init__(self, merged_scheduler_config: "EasyDict") -> None:
        """
        Overview:
            Initialize the scheduler and validate every config field.
            The annotation is a forward-reference string (PEP 484) so the class
            definition itself does not evaluate the ``EasyDict`` name at import time.
        Arguments:
            - merged_scheduler_config (:obj:`EasyDict`): the scheduler config, which merges the user
                config and default config
        """
        schedule_mode = merged_scheduler_config.schedule_mode
        factor = merged_scheduler_config.factor
        change_range = merged_scheduler_config.change_range
        threshold = merged_scheduler_config.threshold
        optimize_mode = merged_scheduler_config.optimize_mode
        patience = merged_scheduler_config.patience
        cooldown = merged_scheduler_config.cooldown

        assert schedule_mode in [
            'reduce', 'add', 'multi', 'div'
        ], 'The schedule mode should be one of [\'reduce\', \'add\', \'multi\',\'div\']'
        self.schedule_mode = schedule_mode

        assert isinstance(factor, (float, int)), 'The factor should be a float/int number '
        assert factor > 0, 'The factor should be greater than 0'
        self.factor = float(factor)

        assert isinstance(change_range,
                          list) and len(change_range) == 2, 'The change_range should be a list with 2 float numbers'
        assert (isinstance(change_range[0], (float, int))) and (
            isinstance(change_range[1], (float, int))
        ), 'The change_range should be a list with 2 float/int numbers'
        assert change_range[0] < change_range[1], 'The first num should be smaller than the second num'
        self.change_range = change_range

        assert isinstance(threshold, (float, int)), 'The threshold should be a float/int number'
        self.threshold = threshold

        assert optimize_mode in ['min', 'max'], 'The optimize_mode should be one of [\'min\', \'max\']'
        self.optimize_mode = optimize_mode

        assert isinstance(patience, int), 'The patience should be a integer greater than or equal to 0'
        assert patience >= 0, 'The patience should be a integer greater than or equal to 0'
        self.patience = patience

        assert isinstance(cooldown, int), 'The cooldown_counter should be a integer greater than or equal to 0'
        assert cooldown >= 0, 'The cooldown_counter should be a integer greater than or equal to 0'
        self.cooldown = cooldown
        self.cooldown_counter = cooldown

        # Running state: last observed metrics and consecutive non-improving epochs.
        self.last_metrics = None
        self.bad_epochs_num = 0

    def step(self, metrics: float, param: float) -> float:
        """
        Overview:
            Decides whether to update the scheduled parameter. Accepts int metrics
            as well (converted to float), since all comparisons are numeric.
        Arguments:
            - metrics (:obj:`float`): current input metrics
            - param (:obj:`float`): parameter need to be updated
        Returns:
            - step_param (:obj:`float`): parameter after one step
        """
        assert isinstance(metrics, (float, int)), 'The metrics should be converted to a float number'
        cur_metrics = float(metrics)
        if self.is_better(cur_metrics):
            self.bad_epochs_num = 0
        else:
            self.bad_epochs_num += 1
        self.last_metrics = cur_metrics

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.bad_epochs_num = 0  # ignore any bad epochs in cooldown

        if self.bad_epochs_num > self.patience:
            param = self.update_param(param)
            # Updating the parameter starts a fresh cooldown window.
            self.cooldown_counter = self.cooldown
            self.bad_epochs_num = 0
        return param

    def update_param(self, param: float) -> float:
        """
        Overview:
            Update the scheduling parameter according to ``schedule_mode``,
            clamped to ``change_range``.
        Arguments:
            - param (:obj:`float`): parameter need to be updated
        Returns:
            - updated param (:obj:`float`): parameter after updating
        """
        # x: current param, y: factor, z: [min, max] change_range.
        schedule_fn = {
            'reduce': lambda x, y, z: max(x - y, z[0]),
            'add': lambda x, y, z: min(x + y, z[1]),
            # multi/div clamp at the bound the operation can actually cross,
            # which depends on whether the factor grows or shrinks the param.
            'multi': lambda x, y, z: min(x * y, z[1]) if y >= 1 else max(x * y, z[0]),
            'div': lambda x, y, z: max(x / y, z[0]) if y >= 1 else min(x / y, z[1]),
        }
        try:
            return schedule_fn[self.schedule_mode](param, self.factor, self.change_range)
        except KeyError:
            raise KeyError("invalid schedule_mode({}) in {}".format(self.schedule_mode, list(schedule_fn.keys())))

    @property
    def in_cooldown(self) -> bool:
        """
        Overview:
            Checks whether the scheduler is in cooldown period. If in cooldown, the scheduler
            will ignore any bad epochs.
        """
        return self.cooldown_counter > 0

    def is_better(self, cur: float) -> bool:
        """
        Overview:
            Checks whether the current metrics is better than last metric with respect to threshold.
            The very first metrics observed is always regarded as an improvement.
        Arguments:
            - cur (:obj:`float`): current metrics
        """
        if self.last_metrics is None:
            return True
        if self.optimize_mode == 'min':
            return cur < self.last_metrics - self.threshold
        if self.optimize_mode == 'max':
            return cur > self.last_metrics + self.threshold
        # Unreachable when __init__ validation passed; fail loudly instead of
        # silently returning None if optimize_mode was mutated afterwards.
        raise ValueError("invalid optimize_mode: {}".format(self.optimize_mode))
|