# gomoku / DI-engine / ding / utils / tests / test_scheduler_helper.py
# Uploaded by zjowowen — commit 079c32c ("init space")
import copy

import pytest
from easydict import EasyDict

from ding.utils import Scheduler
from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config


@pytest.mark.unittest
class TestSchedulerModule():
    """Unit tests for ``ding.utils.Scheduler``.

    The class-level ``test_merged_scheduler_config`` is a read-only default.
    ``setup_method`` gives every test its own deep copy of it. A test that
    writes an invalid value (or a non-default ``cooldown``) therefore cannot
    leak that change into other tests. This holds even if the test fails
    part-way through, or if tests run in a different order.
    """

    # Default scheduler config; never mutated directly (copied per test in setup_method).
    test_merged_scheduler_config = EasyDict(
        dict(
            schedule_flag=False,
            schedule_mode='reduce',
            factor=0.05,
            change_range=[-1, 1],
            threshold=1e-4,
            optimize_mode='min',
            patience=1,
            cooldown=0,
        )
    )
    # Policy config whose ``learn.entropy_weight`` is the parameter driven through Scheduler.step.
    test_policy_config = EasyDict(league_demo_ppo_config.policy)
    test_policy_config_param = test_policy_config.learn.entropy_weight

    def setup_method(self, method) -> None:
        """Give each test a private, mutable copy of the scheduler config."""
        # The instance attribute shadows the class-level default, which stays untouched.
        self.test_merged_scheduler_config = copy.deepcopy(type(self).test_merged_scheduler_config)

    def test_init_factor(self) -> None:
        # A non-numeric factor is rejected.
        self.test_merged_scheduler_config.factor = 'hello_test'
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self.test_merged_scheduler_config)
        assert 'float/int' in str(excinfo.value)
        # A non-positive factor is rejected.
        self.test_merged_scheduler_config.factor = 0
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self.test_merged_scheduler_config)
        assert 'greater than 0' in str(excinfo.value)

    def test_init_change_range(self) -> None:
        # change_range must be a list.
        self.test_merged_scheduler_config.change_range = 0
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self.test_merged_scheduler_config)
        assert 'list' in str(excinfo.value)
        # Both bounds must be numeric.
        self.test_merged_scheduler_config.change_range = [0, 'hello_test']
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self.test_merged_scheduler_config)
        assert 'float' in str(excinfo.value)
        # The lower bound must be smaller than the upper bound.
        self.test_merged_scheduler_config.change_range = [0, -1]
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self.test_merged_scheduler_config)
        assert 'smaller' in str(excinfo.value)

    def test_init_patience(self) -> None:
        # patience must be an integer.
        self.test_merged_scheduler_config.patience = "hello_test"
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self.test_merged_scheduler_config)
        assert 'integer' in str(excinfo.value)
        # A negative patience is rejected.
        self.test_merged_scheduler_config.patience = -1
        with pytest.raises(AssertionError) as excinfo:
            Scheduler(self.test_merged_scheduler_config)
        assert 'greater' in str(excinfo.value)

    def test_is_better(self) -> None:
        test_scheduler = Scheduler(self.test_merged_scheduler_config)
        # With no previous metric recorded, any value counts as an improvement.
        assert test_scheduler.is_better(-1) is True
        # optimize_mode='min': a clearly smaller metric is an improvement.
        test_scheduler.last_metrics = 1
        assert test_scheduler.is_better(0.5) is True

    def test_in_cooldown(self) -> None:
        # ``cooldown`` is the config key the cooldown counter starts from
        # (see test_step); with 0 the scheduler starts outside any cooldown.
        self.test_merged_scheduler_config.cooldown = 0
        test_scheduler = Scheduler(self.test_merged_scheduler_config)
        assert test_scheduler.in_cooldown is False

    def test_step(self) -> None:
        self.test_merged_scheduler_config.cooldown = 1
        test_scheduler = Scheduler(self.test_merged_scheduler_config)
        assert test_scheduler.cooldown_counter == 1
        test_scheduler.last_metrics = 1.0
        old_param = self.test_policy_config.learn.entropy_weight

        # Good epoch (0.9 improves on 1.0): param unchanged, the single cooldown step is consumed.
        self.test_policy_config_param = test_scheduler.step(0.9, self.test_policy_config_param)
        assert self.test_policy_config_param == old_param
        assert test_scheduler.cooldown_counter == 0
        assert test_scheduler.last_metrics == 0.9
        assert test_scheduler.bad_epochs_num == 0

        # First bad epoch after cooldown: the 1e-6 gain is below threshold (1e-4);
        # one bad epoch does not exceed patience (1), so param is unchanged.
        self.test_policy_config_param = test_scheduler.step(0.899999, self.test_policy_config_param)
        assert self.test_policy_config_param == old_param
        assert test_scheduler.cooldown_counter == 0
        assert test_scheduler.last_metrics == 0.899999
        assert test_scheduler.bad_epochs_num == 1

        # Second consecutive bad epoch exceeds patience: param is reduced by factor,
        # cooldown restarts and the bad-epoch count is reset.
        self.test_policy_config_param = test_scheduler.step(0.899998, self.test_policy_config_param)
        assert self.test_policy_config_param == old_param - self.test_merged_scheduler_config.factor
        assert test_scheduler.cooldown_counter == 1
        assert test_scheduler.last_metrics == 0.899998
        assert test_scheduler.bad_epochs_num == 0