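"""Unit tests for BaseLearner, using a fake learner and policy: learner
creation, hook configuration, checkpoint save/load, logging variables,
learn_info fields, and runtime hook registration."""
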
import os
import time
from typing import Any

import pytest
import torch
from easydict import EasyDict

from ding.worker import BaseLearner
from ding.worker.learner import LearnerHook, add_learner_hook, create_learner


class FakeLearner(BaseLearner):

    @staticmethod
    def random_data():
        return {
            'obs': torch.randn(2),
            'replay_buffer_idx': 0,
            'replay_unique_id': 0,
        }

    def get_data(self, batch_size):
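        # Return a list of callables rather than concrete samples: the
        # learner's async dataloader is expected to call each element to
        # materialize the batch lazily.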
        return [self.random_data for _ in range(batch_size)]


class FakePolicy:

    def __init__(self):
        self._model = torch.nn.Identity()

    def forward(self, x):
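        # Keys prefixed with '[histogram]' / '[scalars]' are picked up by the
        # learner's logging hook and written to TensorBoard as histograms /
        # grouped scalars rather than as plain scalar variables.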
        return {
            'total_loss': torch.randn(1).squeeze(),
            'cur_lr': 0.1,
            'priority': [1., 2., 3.],
            '[histogram]h_example': [1.2, 2.3, 3.4],
            '[scalars]s_example': {
                'a': 5.,
                'b': 4.
            },
        }

    def data_preprocess(self, x):
        return x

    def state_dict(self):
        return {'model': self._model}

    def load_state_dict(self, state_dict):
        pass

    def info(self):
        return 'FakePolicy'

    def monitor_vars(self):
        return ['total_loss', 'cur_lr']

    def get_attribute(self, name):
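        # The minimal set of policy attributes the learner queries while
        # setting up its dataloader and training loop.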
        if name == 'cuda':
            return False
        elif name == 'device':
            return 'cpu'
        elif name == 'batch_size':
            return 2
        elif name == 'on_policy':
            return False
        else:
            raise KeyError

    def reset(self):
        pass


@pytest.mark.unittest
class TestBaseLearner:

    def _get_cfg(self, path):
        cfg = BaseLearner.default_config()
        cfg.import_names = []
        cfg.learner_type = 'fake'
        cfg.train_iterations = 10
        cfg.hook.load_ckpt_before_run = path
        cfg.hook.log_show_after_iter = 5
        # Hooks can also be built from a complete config dict
        # (name/type/priority/position/ext_args) instead of the shorthand
        # fields above.
        cfg.hook.save_ckpt_after_iter = dict(
            name='save_ckpt_after_iter', type='save_ckpt', priority=40, position='after_iter', ext_args={'freq': 5}
        )

        return cfg

    def test_naive(self):
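        # Remove any stale checkpoint left over from a previous run.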
        os.popen('rm -rf iteration_5.pth.tar*')
        time.sleep(1.0)
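        # Creating a learner with an unregistered type should raise KeyError.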
        with pytest.raises(KeyError):
            create_learner(EasyDict({'type': 'placeholder', 'import_names': []}))
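        # Pre-save a fake checkpoint at iteration 5; the load_ckpt_before_run
        # hook restores it, so training resumes from last_iter == 5.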
        path = os.path.join(os.path.dirname(__file__), 'iteration_5.pth.tar')
        torch.save({'model': {}, 'last_iter': 5}, path)
        time.sleep(0.5)
        cfg = self._get_cfg(path)
        learner = FakeLearner(cfg, exp_name='exp_test')
        learner.policy = FakePolicy()
        learner.setup_dataloader()
        learner.start()
        time.sleep(2)
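        # 10 train iterations were run on top of the restored iteration 5.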
        assert learner.last_iter.val == 10 + 5

        # The save_ckpt_after_iter hook (freq=5) should have saved checkpoints
        # at iterations 5, 10 and 15 only.
        dir_name = '{}/ckpt'.format(learner.exp_name)
        for n in [5, 10, 15]:
            assert os.path.exists(dir_name + '/iteration_{}.pth.tar'.format(n))
        for n in [0, 4, 7, 12]:
            assert not os.path.exists(dir_name + '/iteration_{}.pth.tar'.format(n))
        learner.debug('iters [5, 10, 15] exist; iters [0, 4, 7, 12] do not exist.')

        learner.save_checkpoint('best')

        info = learner.learn_info
        for info_name in ['learner_step', 'priority_info', 'learner_done']:
            assert info_name in info

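        # Hooks can also be registered at runtime; add_learner_hook appends
        # the new hook to the list for its position.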
        class FakeHook(LearnerHook):

            def __call__(self, engine: Any) -> None:
                pass

        original_hook_num = len(learner._hooks['after_run'])
        add_learner_hook(learner._hooks, FakeHook(name='fake_hook', priority=30, position='after_run'))
        assert len(learner._hooks['after_run']) == original_hook_num + 1

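        # Clean up the checkpoint file and log/ckpt directories produced by
        # the test.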
        os.popen('rm -rf iteration_5.pth.tar*')
        os.popen('rm -rf ' + dir_name)
        os.popen('rm -rf learner')
        os.popen('rm -rf log')
        learner.close()