from .default_helper import deep_merge_dicts
from easydict import EasyDict


class Scheduler(object):
    """
    Overview:
        Update a learning parameter when the monitored metrics (e.g. trueskill) has stopped improving.
        For example, models often benefit from reducing the entropy weight once the learning process stagnates.
        This scheduler reads a metrics quantity on every ``step`` call and, if no improvement is seen for more
        than ``patience`` consecutive calls, the parameter is increased or decreased according to
        ``schedule_mode``.
    Arguments:
        - schedule_flag (:obj:`bool`): Indicates whether to use the scheduler in the training pipeline. \
            This class itself does not read the flag. Default: False
        - schedule_mode (:obj:`str`): One of 'reduce', 'add', 'multi', 'div'. The schedule_mode decides the way \
            of updating the parameter. Default: 'reduce'.
        - factor (:obj:`float`): Amount (greater than 0) by which the parameter will be increased/decreased. \
            Default: 0.05
        - change_range (:obj:`list`): Indicates the minimum and maximum value the parameter can reach \
            respectively. Default: [-1, 1]
        - threshold (:obj:`float`): Threshold for measuring the new optimum, to only focus on significant \
            changes. Default: 1e-4.
        - optimize_mode (:obj:`str`): One of 'min', 'max', which indicates the sign of the optimization \
            objective. Dynamic_threshold = last_metrics + threshold in `max` mode or \
            last_metrics - threshold in `min` mode. Default: 'min'
        - patience (:obj:`int`): Number of epochs with no improvement after which the parameter will be updated. \
            For example, if `patience = 2`, then we will ignore the first 2 epochs with no improvement, and will \
            only update the parameter after the 3rd epoch if the metrics still hasn't improved then. Default: 10.
        - cooldown (:obj:`int`): Number of epochs to wait before resuming normal operation after the parameter \
            has been updated. The scheduler also starts in cooldown, so the first ``cooldown`` steps never \
            count as bad epochs. Default: 0.
    Interfaces:
        __init__, update_param, step, is_better
    Property:
        in_cooldown
    """

    config = dict(
        schedule_flag=False,
        schedule_mode='reduce',
        factor=0.05,
        change_range=[-1, 1],
        threshold=1e-4,
        optimize_mode='min',
        patience=10,
        cooldown=0,
    )

    def __init__(self, merged_scheduler_config: EasyDict) -> None:
        """
        Overview:
            Initialize the scheduler: validate every config field and reset the internal counters.
        Arguments:
            - merged_scheduler_config (:obj:`EasyDict`): the scheduler config, which merges the user config \
                and the default config
        Raises:
            - AssertionError: if any config field has an unexpected type or value.
        """
        schedule_mode = merged_scheduler_config.schedule_mode
        factor = merged_scheduler_config.factor
        change_range = merged_scheduler_config.change_range
        threshold = merged_scheduler_config.threshold
        optimize_mode = merged_scheduler_config.optimize_mode
        patience = merged_scheduler_config.patience
        cooldown = merged_scheduler_config.cooldown

        assert schedule_mode in [
            'reduce', 'add', 'multi', 'div'
        ], 'The schedule mode should be one of [\'reduce\', \'add\', \'multi\',\'div\']'
        self.schedule_mode = schedule_mode

        assert isinstance(factor, (float, int)), 'The factor should be a float/int number '
        assert factor > 0, 'The factor should be greater than 0'
        self.factor = float(factor)

        assert isinstance(change_range,
                          list) and len(change_range) == 2, 'The change_range should be a list with 2 float numbers'
        assert (isinstance(change_range[0], (float, int))) and (
            isinstance(change_range[1], (float, int))
        ), 'The change_range should be a list with 2 float/int numbers'
        assert change_range[0] < change_range[1], 'The first num should be smaller than the second num'
        # Keep a private copy so later mutation of the (possibly shared default) config list
        # cannot silently change the bounds of this scheduler.
        self.change_range = list(change_range)

        assert isinstance(threshold, (float, int)), 'The threshold should be a float/int number'
        self.threshold = threshold

        assert optimize_mode in ['min', 'max'], 'The optimize_mode should be one of [\'min\', \'max\']'
        self.optimize_mode = optimize_mode

        assert isinstance(patience, int), 'The patience should be a integer greater than or equal to 0'
        assert patience >= 0, 'The patience should be a integer greater than or equal to 0'
        self.patience = patience

        assert isinstance(cooldown, int), 'The cooldown_counter should be a integer greater than or equal to 0'
        assert cooldown >= 0, 'The cooldown_counter should be a integer greater than or equal to 0'
        self.cooldown = cooldown
        # The counter starts at ``cooldown``, i.e. the scheduler begins in its cooldown period.
        self.cooldown_counter = cooldown

        self.last_metrics = None
        self.bad_epochs_num = 0

    def step(self, metrics: float, param: float) -> float:
        """
        Overview:
            Record the current metrics and decide whether to update the scheduled parameter.
        Args:
            - metrics (:obj:`float`): current input metrics
            - param (:obj:`float`): parameter need to be updated
        Returns:
            - step_param (:obj:`float`): parameter after one step (unchanged unless the number of \
                consecutive bad epochs exceeds ``patience``)
        Raises:
            - AssertionError: if ``metrics`` is not a float.
        """
        assert isinstance(metrics, float), 'The metrics should be converted to a float number'
        cur_metrics = metrics

        # The comparison is made against the previous metrics, not against the best one seen so far.
        if self.is_better(cur_metrics):
            self.bad_epochs_num = 0
        else:
            self.bad_epochs_num += 1
        self.last_metrics = cur_metrics

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.bad_epochs_num = 0  # ignore any bad epochs in cooldown

        if self.bad_epochs_num > self.patience:
            param = self.update_param(param)
            self.cooldown_counter = self.cooldown
            self.bad_epochs_num = 0
        return param

    def update_param(self, param: float) -> float:
        """
        Overview:
            Update the scheduling parameter according to ``schedule_mode`` and ``factor``, and keep the
            result from running past ``change_range`` in the direction it moved.
        Args:
            - param (:obj:`float`): parameter need to be updated
        Returns:
            - updated param (:obj:`float`): parameter after updating
        Raises:
            - KeyError: if ``schedule_mode`` is not one of 'reduce', 'add', 'multi', 'div'.
        """
        schedule_fn = {
            'reduce': lambda x, y: x - y,
            'add': lambda x, y: x + y,
            'multi': lambda x, y: x * y,
            'div': lambda x, y: x / y,
        }
        schedule_mode_list = list(schedule_fn.keys())
        if self.schedule_mode not in schedule_fn:
            raise KeyError("invalid schedule_mode({}) in {}".format(self.schedule_mode, schedule_mode_list))

        new_param = schedule_fn[self.schedule_mode](param, self.factor)
        lower, upper = self.change_range
        # Clamp against the bound in the direction the value actually moved. Deciding the bound
        # from the factor alone is wrong for 'multi'/'div' with a negative parameter: e.g.
        # -0.9 * 2 = -1.8 moves downwards and must be limited by the lower bound.
        if new_param > param:
            return min(new_param, upper)
        if new_param < param:
            return max(new_param, lower)
        # The value did not move, so there is nothing to clamp.
        return new_param

    @property
    def in_cooldown(self) -> bool:
        """
        Overview:
            Checks whether the scheduler is in cooldown period. If in cooldown, the scheduler
            will ignore any bad epochs.
        """
        return self.cooldown_counter > 0

    def is_better(self, cur: float) -> bool:
        """
        Overview:
            Checks whether the current metrics is better than the last metrics with respect to threshold.
            The very first metrics (when no previous value exists) is always considered better.
        Args:
            - cur (:obj:`float`): current metrics
        Returns:
            - better (:obj:`bool`): True if ``cur`` improves on the last metrics by more than ``threshold``.
        """
        if self.last_metrics is None:
            return True
        elif self.optimize_mode == 'min':
            return cur < self.last_metrics - self.threshold
        elif self.optimize_mode == 'max':
            return cur > self.last_metrics + self.threshold