File size: 7,582 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
from .default_helper import deep_merge_dicts
from easydict import EasyDict


class Scheduler(object):
    """
    Overview:
        Update a learning parameter when the monitored metrics (e.g. trueskill) has stopped improving.
        For example, models often benefit from reducing entropy weight once the learning process stagnates.
        This scheduler reads a metrics quantity and if no improvement is seen for a 'patience' number of epochs,
        the corresponding parameter is increased or decreased, depending on the 'schedule_mode'.
        Improvement is always measured against the metrics of the previous ``step`` call (not the best so far).
    Arguments:
        - schedule_flag (:obj:`bool`): Indicates whether to use scheduler in training pipeline. It is part of
            the default config but is not read by this class itself. Default: False
        - schedule_mode (:obj:`str`): One of 'reduce', 'add', 'multi', 'div'. The schedule_mode
            decides the way of updating the parameters. Default: 'reduce'.
        - factor (:obj:`float`) : Amount (greater than 0) by which the parameter will be
            increased/decreased. Default: 0.05
        - change_range (:obj:`list`): Indicates the minimum and maximum value
            the parameter can reach respectively. Default: [-1,1]
        - threshold (:obj:`float`): Threshold for measuring the new optimum,
            to only focus on significant changes. Default:  1e-4.
        - optimize_mode (:obj:`str`): One of 'min', 'max', which indicates the sign of
            optimization objective. Dynamic_threshold = last_metrics + threshold in `max`
            mode or last_metrics - threshold in `min` mode. Default: 'min'
        - patience (:obj:`int`): Number of epochs with no improvement after which
            the parameter will be updated. For example, if `patience = 2`, then we
            will ignore the first 2 epochs with no improvement, and will only update
            the parameter after the 3rd epoch if the metrics still hasn't improved then.
            Default: 10.
        - cooldown (:obj:`int`): Number of epochs to wait before resuming
            normal operation after the parameter has been updated. The scheduler also starts
            in cooldown, i.e. the first ``cooldown`` epochs are ignored. Default: 0.
    Interfaces:
        __init__, update_param, step, is_better
    Property:
        in_cooldown
    """

    config = dict(
        schedule_flag=False,
        schedule_mode='reduce',
        factor=0.05,
        change_range=[-1, 1],
        threshold=1e-4,
        optimize_mode='min',
        patience=10,
        cooldown=0,
    )

    def __init__(self, merged_scheduler_config: EasyDict) -> None:
        """
        Overview:
            Initialize the scheduler: validate every field of the config and reset the internal counters.
        Arguments:
            - merged_scheduler_config (:obj:`EasyDict`): the scheduler config, which merges the user
                config and default config. It must expose ``schedule_mode``, ``factor``, ``change_range``,
                ``threshold``, ``optimize_mode``, ``patience`` and ``cooldown`` as attributes.
        Raises:
            - AssertionError: if any of the fields has an invalid type or value.
        """

        schedule_mode = merged_scheduler_config.schedule_mode
        factor = merged_scheduler_config.factor
        change_range = merged_scheduler_config.change_range
        threshold = merged_scheduler_config.threshold
        optimize_mode = merged_scheduler_config.optimize_mode
        patience = merged_scheduler_config.patience
        cooldown = merged_scheduler_config.cooldown

        assert schedule_mode in [
            'reduce', 'add', 'multi', 'div'
        ], 'The schedule mode should be one of [\'reduce\', \'add\', \'multi\',\'div\']'
        self.schedule_mode = schedule_mode

        assert isinstance(factor, (float, int)), 'The factor should be a float/int number '
        assert factor > 0, 'The factor should be greater than 0'
        # A strictly positive factor also guarantees that the 'div' mode never divides by zero.
        self.factor = float(factor)

        assert isinstance(change_range,
                          list) and len(change_range) == 2, 'The change_range should be a list with 2 float numbers'
        assert (isinstance(change_range[0], (float, int))) and (
            isinstance(change_range[1], (float, int))
        ), 'The change_range should be a list with 2 float/int numbers'
        assert change_range[0] < change_range[1], 'The first num should be smaller than the second num'
        self.change_range = change_range

        assert isinstance(threshold, (float, int)), 'The threshold should be a float/int number'
        self.threshold = threshold

        assert optimize_mode in ['min', 'max'], 'The optimize_mode should be one of [\'min\', \'max\']'
        self.optimize_mode = optimize_mode

        assert isinstance(patience, int), 'The patience should be a integer greater than or equal to 0'
        assert patience >= 0, 'The patience should be a integer greater than or equal to 0'
        self.patience = patience

        assert isinstance(cooldown, int), 'The cooldown_counter should be a integer greater than or equal to 0'
        assert cooldown >= 0, 'The cooldown_counter should be a integer greater than or equal to 0'
        self.cooldown = cooldown
        # The counter starts at ``cooldown``, so the scheduler begins its life inside a cooldown period.
        self.cooldown_counter = cooldown

        # No metrics seen yet; the first ``step`` call is always treated as an improvement.
        self.last_metrics = None
        self.bad_epochs_num = 0

    def step(self, metrics: float, param: float) -> float:
        """
        Overview:
            Record the current metrics and decide whether to update the scheduled parameter.
        Arguments:
            - metrics (:obj:`float`): current input metrics
            - param (:obj:`float`): parameter need to be updated
        Returns:
            - step_param (:obj:`float`): parameter after one step; it is ``param`` itself when no \
                update is triggered
        Raises:
            - AssertionError: if ``metrics`` is not a float.
        """
        assert isinstance(metrics, float), 'The metrics should be converted to a float number'
        cur_metrics = metrics

        if self.is_better(cur_metrics):
            self.bad_epochs_num = 0
        else:
            self.bad_epochs_num += 1
        # The next call is compared against this value, not against the best value so far.
        self.last_metrics = cur_metrics

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.bad_epochs_num = 0  # ignore any bad epochs in cooldown

        # Strictly greater: ``patience`` bad epochs are tolerated, the next one triggers the update.
        if self.bad_epochs_num > self.patience:
            param = self.update_param(param)
            self.cooldown_counter = self.cooldown
            self.bad_epochs_num = 0
        return param

    def update_param(self, param: float) -> float:
        """
        Overview:
            Update the scheduling parameter according to ``schedule_mode``:
            'reduce' subtracts ``factor`` and 'add' adds it, bounded by the lower / upper end of
            ``change_range`` respectively; 'multi' multiplies and 'div' divides by ``factor``, and
            the result is clamped to both ends of ``change_range``.
        Arguments:
            - param (:obj:`float`): parameter need to be updated
        Returns:
            - updated param (:obj:`float`): parameter after updating
        Raises:
            - KeyError: if ``schedule_mode`` is not one of the supported modes.
        """
        lower, upper = self.change_range
        mode = self.schedule_mode

        if mode == 'reduce':
            return max(param - self.factor, lower)
        if mode == 'add':
            return min(param + self.factor, upper)
        if mode == 'multi' or mode == 'div':
            scaled = param * self.factor if mode == 'multi' else param / self.factor
            # Whether scaling moves the value up or down depends on the sign of ``param`` as well as
            # on the factor (e.g. -0.8 * 2 = -1.6), so both bounds have to be enforced here.
            return min(max(scaled, lower), upper)

        schedule_mode_list = ['reduce', 'add', 'multi', 'div']
        raise KeyError("invalid schedule_mode({}) in {}".format(self.schedule_mode, schedule_mode_list))

    @property
    def in_cooldown(self) -> bool:
        """
        Overview:
            Checks whether the scheduler is in cooldown period. If in cooldown, the scheduler
            will ignore any bad epochs.
        Returns:
            - in_cooldown (:obj:`bool`): True while ``cooldown_counter`` is greater than 0
        """
        return self.cooldown_counter > 0

    def is_better(self, cur: float) -> bool:
        """
        Overview:
            Checks whether the current metrics is better than last metrics with respect to threshold.
            The very first metrics (when no last metrics is recorded) always counts as better.
        Arguments:
            - cur (:obj:`float`): current metrics
        Returns:
            - better (:obj:`bool`): True if ``cur`` beats the last metrics by more than ``threshold``
        Raises:
            - KeyError: if ``optimize_mode`` is neither 'min' nor 'max'.
        """
        if self.last_metrics is None:
            return True
        if self.optimize_mode == 'min':
            return cur < self.last_metrics - self.threshold
        if self.optimize_mode == 'max':
            return cur > self.last_metrics + self.threshold
        # Fail loudly instead of implicitly returning None, which would silently count as a bad epoch.
        raise KeyError("invalid optimize_mode({}) in {}".format(self.optimize_mode, ['min', 'max']))