File size: 2,921 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import easydict
import pytest
from ding.worker.learner import register_learner_hook, build_learner_hook_by_cfg, LearnerHook
from ding.worker.learner.learner_hook import SaveCkptHook, LoadCkptHook, LogShowHook, LogReduceHook
from ding.worker.learner.learner_hook import show_hooks, add_learner_hook, merge_hooks
from easydict import EasyDict


@pytest.fixture(scope='function')
def setup_simplified_hook_cfg():
    """Simplified hook config holding the ``save_ckpt_after_iter`` and ``save_ckpt_after_run`` entries."""
    simplified_cfg = {
        'save_ckpt_after_iter': 20,
        'save_ckpt_after_run': True,
    }
    return simplified_cfg


@pytest.fixture(scope='function')
def fake_setup_simplified_hook_cfg():
    """Second simplified hook config holding the ``log_show_after_iter`` and ``log_reduce_after_iter`` entries."""
    simplified_cfg = {
        'log_show_after_iter': 20,
        'log_reduce_after_iter': True,
    }
    return simplified_cfg


@pytest.mark.unittest
class TestLearnerHook:
    """Unit tests for registering, building, adding and merging learner hooks."""

    def test_register(self):
        """Register a ``LearnerHook`` subclass, then expect ``AssertionError`` when registering ``type``."""

        class FakeHook(LearnerHook):
            pass

        register_learner_hook('fake', FakeHook)
        # ``type`` does not derive from ``LearnerHook``; the registration is expected to fail an assertion.
        with pytest.raises(AssertionError):
            register_learner_hook('placeholder', type)

    def test_build_learner_hook_by_cfg(self, setup_simplified_hook_cfg):
        """Build hooks from the simplified config and check which positions are populated."""
        hooks = build_learner_hook_by_cfg(setup_simplified_hook_cfg)
        show_hooks(hooks)
        # The config only carries the two ``save_ckpt_*`` entries, so only the ``after_*`` positions hold a hook.
        assert len(hooks['before_run']) == 0
        assert len(hooks['before_iter']) == 0
        assert len(hooks['after_iter']) == 1
        assert isinstance(hooks['after_iter'][0], SaveCkptHook)
        assert len(hooks['after_run']) == 1
        assert isinstance(hooks['after_run'][0], SaveCkptHook)

    def test_add_learner_hook(self, setup_simplified_hook_cfg):
        """Add three hand-built hooks to the config-built hooks and check the per-position counts."""
        hooks = build_learner_hook_by_cfg(setup_simplified_hook_cfg)
        hook_1 = LogShowHook('log_show', 20, position='after_iter', ext_args=EasyDict({'freq': 100}))
        add_learner_hook(hooks, hook_1)
        hook_2 = LoadCkptHook('load_ckpt', 20, position='before_run', ext_args=EasyDict({'load_path': './model.pth'}))
        add_learner_hook(hooks, hook_2)
        hook_3 = LogReduceHook('log_reduce', 10, position='after_iter')
        add_learner_hook(hooks, hook_3)

        show_hooks(hooks)
        # One config-built ``after_iter`` hook plus the two added above.
        assert len(hooks['after_iter']) == 3
        assert len(hooks['after_run']) == 1
        assert len(hooks['before_run']) == 1
        assert len(hooks['before_iter']) == 0
        assert isinstance(hooks['after_run'][0], SaveCkptHook)
        assert isinstance(hooks['before_run'][0], LoadCkptHook)

    def test_merge_hooks(self, setup_simplified_hook_cfg, fake_setup_simplified_hook_cfg):
        """Merge the hooks built from both fixture configs and check the merged mapping."""
        hooks = build_learner_hook_by_cfg(setup_simplified_hook_cfg)
        show_hooks(hooks)
        fake_hooks = build_learner_hook_by_cfg(fake_setup_simplified_hook_cfg)
        show_hooks(fake_hooks)
        hooks_ = merge_hooks(hooks, fake_hooks)
        show_hooks(hooks_)
        assert len(hooks_['after_iter']) == 3
        assert len(hooks_['after_run']) == 1
        assert len(hooks_['before_run']) == 0
        assert len(hooks_['before_iter']) == 0
        # Inspect the merged result (``hooks_``), not the first input mapping, so the merge output is what is verified.
        assert isinstance(hooks_['after_run'][0], SaveCkptHook)