File size: 2,921 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import easydict
import pytest
from ding.worker.learner import register_learner_hook, build_learner_hook_by_cfg, LearnerHook
from ding.worker.learner.learner_hook import SaveCkptHook, LoadCkptHook, LogShowHook, LogReduceHook
from ding.worker.learner.learner_hook import show_hooks, add_learner_hook, merge_hooks
from easydict import EasyDict
@pytest.fixture(scope='function')
def setup_simplified_hook_cfg():
    """Provide a simplified learner-hook config that enables checkpoint saving.

    The config turns on ``save_ckpt_after_iter`` (value 20) and
    ``save_ckpt_after_run``. A fresh dict is produced for every test
    (function scope), so tests may mutate it freely.
    """
    cfg = {
        'save_ckpt_after_iter': 20,
        'save_ckpt_after_run': True,
    }
    return cfg
@pytest.fixture(scope='function')
def fake_setup_simplified_hook_cfg():
    """Provide a second simplified hook config used as the other operand of ``merge_hooks``.

    It enables ``log_show_after_iter`` (value 20, presumably the show
    frequency) and ``log_reduce_after_iter``. A fresh dict is produced for
    every test (function scope).
    """
    cfg = {
        'log_show_after_iter': 20,
        'log_reduce_after_iter': True,
    }
    return cfg
@pytest.mark.unittest
class TestLearnerHook:
    """Unit tests for learner-hook registration, config-driven construction, insertion and merging."""

    def test_register(self):
        """A ``LearnerHook`` subclass can be registered; a non-subclass is rejected."""
        class FakeHook(LearnerHook):
            pass

        register_learner_hook('fake', FakeHook)
        # ``type`` is not a LearnerHook subclass, so registration must fail with AssertionError.
        with pytest.raises(AssertionError):
            register_learner_hook('placeholder', type)

    def test_build_learner_hook_by_cfg(self, setup_simplified_hook_cfg):
        """The save-ckpt config yields exactly one ``SaveCkptHook`` after iter and one after run."""
        hooks = build_learner_hook_by_cfg(setup_simplified_hook_cfg)
        show_hooks(hooks)
        assert len(hooks['before_run']) == 0
        assert len(hooks['before_iter']) == 0
        assert len(hooks['after_iter']) == 1
        assert isinstance(hooks['after_iter'][0], SaveCkptHook)
        assert len(hooks['after_run']) == 1
        assert isinstance(hooks['after_run'][0], SaveCkptHook)

    def test_add_learner_hook(self, setup_simplified_hook_cfg):
        """Hooks added with ``add_learner_hook`` land in the position given at construction."""
        hooks = build_learner_hook_by_cfg(setup_simplified_hook_cfg)
        hook_1 = LogShowHook('log_show', 20, position='after_iter', ext_args=EasyDict({'freq': 100}))
        add_learner_hook(hooks, hook_1)
        hook_2 = LoadCkptHook('load_ckpt', 20, position='before_run', ext_args=EasyDict({'load_path': './model.pth'}))
        add_learner_hook(hooks, hook_2)
        hook_3 = LogReduceHook('log_reduce', 10, position='after_iter')
        add_learner_hook(hooks, hook_3)
        show_hooks(hooks)
        # One SaveCkptHook from the config plus the two after-iter hooks added above.
        assert len(hooks['after_iter']) == 3
        assert any(isinstance(h, LogShowHook) for h in hooks['after_iter'])
        assert any(isinstance(h, LogReduceHook) for h in hooks['after_iter'])
        assert len(hooks['after_run']) == 1
        assert len(hooks['before_run']) == 1
        assert len(hooks['before_iter']) == 0
        assert isinstance(hooks['after_run'][0], SaveCkptHook)
        assert isinstance(hooks['before_run'][0], LoadCkptHook)

    def test_merge_hooks(self, setup_simplified_hook_cfg, fake_setup_simplified_hook_cfg):
        """Merging two hook dicts combines the hooks of both configs per position."""
        hooks = build_learner_hook_by_cfg(setup_simplified_hook_cfg)
        show_hooks(hooks)
        fake_hooks = build_learner_hook_by_cfg(fake_setup_simplified_hook_cfg)
        show_hooks(fake_hooks)
        hooks_ = merge_hooks(hooks, fake_hooks)
        show_hooks(hooks_)
        # 1 after-iter hook from the save-ckpt config + 2 from the log config.
        assert len(hooks_['after_iter']) == 3
        assert len(hooks_['after_run']) == 1
        assert len(hooks_['before_run']) == 0
        assert len(hooks_['before_iter']) == 0
        # Check the merged result (not the input ``hooks``): its only after-run
        # hook is the SaveCkptHook contributed by the save-ckpt config.
        assert isinstance(hooks_['after_run'][0], SaveCkptHook)
|