import os
import time
import pytest
import torch
from easydict import EasyDict
from typing import Any
from functools import partial
from ding.worker import BaseLearner
from ding.worker.learner import LearnerHook, add_learner_hook, create_learner
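

# Learner stub: serves randomly generated transitions so that BaseLearner's
# training loop can run without a real environment or replay buffer.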
class FakeLearner(BaseLearner):

    @staticmethod
    def random_data():
        return {
            'obs': torch.randn(2),
            'replay_buffer_idx': 0,
            'replay_unique_id': 0,
        }

    def get_data(self, batch_size):
        # Return a list of callables (not dicts); the async dataloader is
        # expected to invoke each one to produce a single training sample.
        return [self.random_data for _ in range(batch_size)]
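

# Policy stub implementing just the interface BaseLearner relies on:
# forward/state_dict for training and checkpointing, monitor_vars for
# logging, and get_attribute for device/batch-size queries.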
class FakePolicy:

    def __init__(self):
        self._model = torch.nn.Identity()

    def forward(self, x):
        # Keys prefixed with '[histogram]' / '[scalars]' signal the learner's
        # logger to record them as a histogram / grouped scalars.
        return {
            'total_loss': torch.randn(1).squeeze(),
            'cur_lr': 0.1,
            'priority': [1., 2., 3.],
            '[histogram]h_example': [1.2, 2.3, 3.4],
            '[scalars]s_example': {
                'a': 5.,
                'b': 4.,
            },
        }

    def data_preprocess(self, x):
        return x

    def state_dict(self):
        return {'model': self._model}

    def load_state_dict(self, state_dict):
        pass

    def info(self):
        return 'FakePolicy'

    def monitor_vars(self):
        return ['total_loss', 'cur_lr']

    def get_attribute(self, name):
        if name == 'cuda':
            return False
        elif name == 'device':
            return 'cpu'
        elif name == 'batch_size':
            return 2
        elif name == 'on_policy':
            return False
        else:
            raise KeyError(name)

    def reset(self):
        pass
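

# End-to-end test: builds a learner from config, runs a few train iterations,
# and checks hook execution, checkpointing and the exposed learn_info.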
@pytest.mark.unittest
class TestBaseLearner:

    def _get_cfg(self, path):
        cfg = BaseLearner.default_config()
        cfg.import_names = []
        cfg.learner_type = 'fake'
        cfg.train_iterations = 10
        # Resume from the given checkpoint and log every 5 iterations.
        cfg.hook.load_ckpt_before_run = path
        cfg.hook.log_show_after_iter = 5
        # Another way to build a hook: pass a complete config dict.
        cfg.hook.save_ckpt_after_iter = dict(
            name='save_ckpt_after_iter', type='save_ckpt', priority=40, position='after_iter', ext_args={'freq': 5}
        )
        return cfg
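
    # The flow below: seed a fake checkpoint at iteration 5, let the
    # load_ckpt_before_run hook restore it, then train 10 more iterations.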
    def test_naive(self):
        os.popen('rm -rf iteration_5.pth.tar*')
        time.sleep(1.0)
        # Creating a learner with an unregistered type should raise KeyError.
        with pytest.raises(KeyError):
            create_learner(EasyDict({'type': 'placeholder', 'import_names': []}))
        # Prepare a checkpoint so that last_iter is restored to 5 before running.
        path = os.path.join(os.path.dirname(__file__), './iteration_5.pth.tar')
        torch.save({'model': {}, 'last_iter': 5}, path)
        time.sleep(0.5)
        cfg = self._get_cfg(path)
        learner = FakeLearner(cfg, exp_name='exp_test')
        learner.policy = FakePolicy()
        learner.setup_dataloader()
        learner.start()
        time.sleep(2)
        # 10 train iterations on top of the restored iteration 5.
        assert learner.last_iter.val == 10 + 5
        # Test the save_ckpt hook: a checkpoint every 5 iterations, nothing else.
        dir_name = '{}/ckpt'.format(learner.exp_name)
        for n in [5, 10, 15]:
            assert os.path.exists(dir_name + '/iteration_{}.pth.tar'.format(n))
        for n in [0, 4, 7, 12]:
            assert not os.path.exists(dir_name + '/iteration_{}.pth.tar'.format(n))
        learner.debug('iter [5, 10, 15] exists; iter [0, 4, 7, 12] does not exist.')
        learner.save_checkpoint('best')
        info = learner.learn_info
        for info_name in ['learner_step', 'priority_info', 'learner_done']:
            assert info_name in info

        # Register a custom hook at runtime and verify it is appended.
        class FakeHook(LearnerHook):

            def __call__(self, engine: Any) -> Any:
                pass

        original_hook_num = len(learner._hooks['after_run'])
        add_learner_hook(learner._hooks, FakeHook(name='fake_hook', priority=30, position='after_run'))
        assert len(learner._hooks['after_run']) == original_hook_num + 1

        # Clean up checkpoints and logs produced by the test.
        os.popen('rm -rf iteration_5.pth.tar*')
        os.popen('rm -rf ' + dir_name)
        os.popen('rm -rf learner')
        os.popen('rm -rf log')
        learner.close()