File size: 4,714 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import copy

import pytest
from easydict import EasyDict

from ding.utils import Scheduler
from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config
@pytest.mark.unittest
class TestSchedulerModule():
    """Unit tests for ``ding.utils.Scheduler``.

    Every test works on its own deep copy of the scheduler config, rebuilt in
    ``setup_method``. A test that mutates the config, or fails before it could
    restore it, therefore cannot change what later tests observe.
    """

    # Baseline scheduler settings. Tests never mutate this dict directly.
    DEFAULT_SCHEDULER_CONFIG = dict(
        schedule_flag=False,
        schedule_mode='reduce',
        factor=0.05,
        change_range=[-1, 1],
        threshold=1e-4,
        optimize_mode='min',
        patience=1,
        cooldown=0,
    )
    # Kept at class level for backward compatibility; each test shadows it with
    # a fresh per-instance copy in ``setup_method``.
    test_merged_scheduler_config = EasyDict(copy.deepcopy(DEFAULT_SCHEDULER_CONFIG))
    test_policy_config = EasyDict(league_demo_ppo_config.policy)
    test_policy_config_param = test_policy_config.learn.entropy_weight

    def setup_method(self, method):
        """Give each test an unshared scheduler config and the initial entropy weight."""
        self.test_merged_scheduler_config = EasyDict(copy.deepcopy(self.DEFAULT_SCHEDULER_CONFIG))
        self.test_policy_config_param = self.test_policy_config.learn.entropy_weight

    def test_init_factor(self):
        """A non-numeric or zero ``factor`` is rejected with an AssertionError."""
        self.test_merged_scheduler_config.factor = 'hello_test'
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self.test_merged_scheduler_config)
        assert 'float/int' in str(excinfo.value)

        self.test_merged_scheduler_config.factor = 0
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self.test_merged_scheduler_config)
        assert 'greater than 0' in str(excinfo.value)

    def test_init_change_range(self):
        """``change_range`` must be a list of two numbers with the first below the second."""
        self.test_merged_scheduler_config.change_range = 0
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self.test_merged_scheduler_config)
        assert 'list' in str(excinfo.value)

        self.test_merged_scheduler_config.change_range = [0, 'hello_test']
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self.test_merged_scheduler_config)
        assert 'float' in str(excinfo.value)

        # The lower bound is larger than the upper bound.
        self.test_merged_scheduler_config.change_range = [0, -1]
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self.test_merged_scheduler_config)
        assert 'smaller' in str(excinfo.value)

    def test_init_patience(self):
        """A non-integer or negative ``patience`` is rejected with an AssertionError."""
        self.test_merged_scheduler_config.patience = "hello_test"
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self.test_merged_scheduler_config)
        assert 'integer' in str(excinfo.value)

        self.test_merged_scheduler_config.patience = -1
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self.test_merged_scheduler_config)
        assert 'greater' in str(excinfo.value)

    def test_is_better(self):
        """With ``optimize_mode='min'``, lower metrics are reported as better."""
        test_scheduler = Scheduler(self.test_merged_scheduler_config)
        assert test_scheduler.is_better(-1) is True
        test_scheduler.last_metrics = 1
        assert test_scheduler.is_better(0.5) is True

    def test_in_cooldown(self):
        """A scheduler built with ``cooldown=0`` does not start in cooldown."""
        # The scheduler's cooldown_counter is derived from ``cooldown`` (see
        # test_step). Writing a ``cooldown_counter`` key into the config has
        # no effect, so set the real knob explicitly.
        self.test_merged_scheduler_config.cooldown = 0
        test_scheduler = Scheduler(self.test_merged_scheduler_config)
        assert test_scheduler.in_cooldown is False

    def test_step(self):
        """Drive ``step`` through one improving epoch and two consecutive bad epochs."""
        self.test_merged_scheduler_config.cooldown = 1
        test_scheduler = Scheduler(self.test_merged_scheduler_config)
        assert test_scheduler.cooldown_counter == 1
        test_scheduler.last_metrics = 1.0
        old_param = self.test_policy_config.learn.entropy_weight
        factor = self.test_merged_scheduler_config.factor

        # Improving epoch (1.0 -> 0.9): param is unchanged and the cooldown
        # counter drops from 1 to 0.
        self.test_policy_config_param = test_scheduler.step(0.9, self.test_policy_config_param)
        assert self.test_policy_config_param == old_param
        assert test_scheduler.cooldown_counter == 0
        assert test_scheduler.last_metrics == 0.9
        assert test_scheduler.bad_epochs_num == 0

        # Only 1e-6 better, i.e. below threshold=1e-4. The asserts show it is
        # counted as the first bad epoch, and the param is still unchanged.
        self.test_policy_config_param = test_scheduler.step(0.899999, self.test_policy_config_param)
        assert self.test_policy_config_param == old_param
        assert test_scheduler.cooldown_counter == 0
        assert test_scheduler.last_metrics == 0.899999
        assert test_scheduler.bad_epochs_num == 1

        # A second consecutive bad epoch exceeds patience=1. In 'reduce' mode the
        # param drops by ``factor``, the cooldown counter goes back to the
        # configured cooldown (1), and the bad-epoch count restarts.
        self.test_policy_config_param = test_scheduler.step(0.899998, self.test_policy_config_param)
        # approx: float subtraction need not round-trip exactly.
        assert self.test_policy_config_param == pytest.approx(old_param - factor)
        assert test_scheduler.cooldown_counter == 1
        assert test_scheduler.last_metrics == 0.899998
        assert test_scheduler.bad_epochs_num == 0
|