gomoku / DI-engine /ding /worker /learner /tests /test_base_learner.py
zjowowen's picture
init space
079c32c
raw
history blame
3.88 kB
import os
import time
import pytest
import torch
from easydict import EasyDict
from typing import Any
from functools import partial
from ding.worker import BaseLearner
from ding.worker.learner import LearnerHook, add_learner_hook, create_learner
class FakeLearner(BaseLearner):
@staticmethod
def random_data():
return {
'obs': torch.randn(2),
'replay_buffer_idx': 0,
'replay_unique_id': 0,
}
def get_data(self, batch_size):
return [self.random_data for _ in range(batch_size)]
class FakePolicy:
def __init__(self):
self._model = torch.nn.Identity()
def forward(self, x):
return {
'total_loss': torch.randn(1).squeeze(),
'cur_lr': 0.1,
'priority': [1., 2., 3.],
'[histogram]h_example': [1.2, 2.3, 3.4],
'[scalars]s_example': {
'a': 5.,
'b': 4.
},
}
def data_preprocess(self, x):
return x
def state_dict(self):
return {'model': self._model}
def load_state_dict(self, state_dict):
pass
def info(self):
return 'FakePolicy'
def monitor_vars(self):
return ['total_loss', 'cur_lr']
def get_attribute(self, name):
if name == 'cuda':
return False
elif name == 'device':
return 'cpu'
elif name == 'batch_size':
return 2
elif name == 'on_policy':
return False
else:
raise KeyError
def reset(self):
pass
@pytest.mark.unittest
class TestBaseLearner:
def _get_cfg(self, path):
cfg = BaseLearner.default_config()
cfg.import_names = []
cfg.learner_type = 'fake'
cfg.train_iterations = 10
cfg.hook.load_ckpt_before_run = path
cfg.hook.log_show_after_iter = 5
# Another way to build hook: Complete config
cfg.hook.save_ckpt_after_iter = dict(
name='save_ckpt_after_iter', type='save_ckpt', priority=40, position='after_iter', ext_args={'freq': 5}
)
return cfg
def test_naive(self):
os.popen('rm -rf iteration_5.pth.tar*')
time.sleep(1.0)
with pytest.raises(KeyError):
create_learner(EasyDict({'type': 'placeholder', 'import_names': []}))
path = os.path.join(os.path.dirname(__file__), './iteration_5.pth.tar')
torch.save({'model': {}, 'last_iter': 5}, path)
time.sleep(0.5)
cfg = self._get_cfg(path)
learner = FakeLearner(cfg, exp_name='exp_test')
learner.policy = FakePolicy()
learner.setup_dataloader()
learner.start()
time.sleep(2)
assert learner.last_iter.val == 10 + 5
# test hook
dir_name = '{}/ckpt'.format(learner.exp_name)
for n in [5, 10, 15]:
assert os.path.exists(dir_name + '/iteration_{}.pth.tar'.format(n))
for n in [0, 4, 7, 12]:
assert not os.path.exists(dir_name + '/iteration_{}.pth.tar'.format(n))
learner.debug('iter [5, 10, 15] exists; iter [0, 4, 7, 12] does not exist.')
learner.save_checkpoint('best')
info = learner.learn_info
for info_name in ['learner_step', 'priority_info', 'learner_done']:
assert info_name in info
class FakeHook(LearnerHook):
def __call__(self, engine: Any) -> Any:
pass
original_hook_num = len(learner._hooks['after_run'])
add_learner_hook(learner._hooks, FakeHook(name='fake_hook', priority=30, position='after_run'))
assert len(learner._hooks['after_run']) == original_hook_num + 1
os.popen('rm -rf iteration_5.pth.tar*')
os.popen('rm -rf ' + dir_name)
os.popen('rm -rf learner')
os.popen('rm -rf log')
learner.close()