import os
import shutil

import pytest
import torch
from easydict import EasyDict
from tensorboardX import SummaryWriter

from ding.reward_model.gail_irl_model import GailRewardModel
from ding.utils.data import offline_data_save_type

obs_space_1d, obs_space_3d = 4, [4, 84, 84]
expert_data_path_1d, expert_data_path_3d = './expert_data_1d', './expert_data_3d'

# Create the output directories for the generated expert data if they do not already exist.
for path in (expert_data_path_1d, expert_data_path_3d):
    if not os.path.exists(path):
        try:
            os.mkdir(path)
        except FileExistsError:
            pass

device = 'cpu'
action_space = 3

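# Note: the trailing comma after each dict() below wraps the config in a one-element tuple,
# which is the iterable of parameter sets expected by pytest.mark.parametrize.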
cfg1 = dict(
    input_size=obs_space_1d + 1,
    hidden_size=64,
    batch_size=5,
    learning_rate=1e-3,
    update_per_collect=2,
    data_path=expert_data_path_1d,
),

cfg2 = dict(
    input_size=obs_space_3d,
    hidden_size=64,
    batch_size=5,
    learning_rate=1e-3,
    update_per_collect=2,
    data_path=expert_data_path_3d,
    action_size=action_space,
),

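# Build 20 synthetic expert transitions for each observation shape; all-zero observations
# with a constant action are enough to exercise the reward model's data pipeline.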
data_1d = []
for i in range(20):
    d = {}
    d['obs'] = torch.zeros(obs_space_1d)
    d['action'] = torch.Tensor([1.])
    data_1d.append(d)

data_3d = []
for i in range(20):
    d = {}
    d['obs'] = torch.zeros(obs_space_3d)
    d['action'] = torch.Tensor([1.])
    data_3d.append(d)

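# Each test below saves the synthetic data to disk as expert data, then runs the full
# GailRewardModel cycle: load_expert_data -> collect_data -> train -> estimate -> clear_data.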
@pytest.mark.parametrize('cfg', cfg1)
@pytest.mark.unittest
def test_dataset_1d(cfg):
    offline_data_save_type(
        exp_data=data_1d, expert_data_path=expert_data_path_1d + '/expert_data.pkl', data_type='naive'
    )
    data = data_1d
    cfg = EasyDict(cfg)
    policy = GailRewardModel(cfg, device, tb_logger=SummaryWriter())
    policy.load_expert_data()
    assert len(policy.expert_data) == 20
    state = policy.state_dict()
    policy.load_state_dict(state)
    policy.collect_data(data)
    assert len(policy.train_data) == 20
    for _ in range(5):
        policy.train()
    train_data_augmented = policy.estimate(data)
    assert 'reward' in train_data_augmented[0].keys()
    policy.clear_data()
    assert len(policy.train_data) == 0
    # Clean up the expert data directory created for this test.
    shutil.rmtree(expert_data_path_1d, ignore_errors=True)


@pytest.mark.parametrize('cfg', cfg2)
@pytest.mark.unittest
def test_dataset_3d(cfg):
    offline_data_save_type(
        exp_data=data_3d, expert_data_path=expert_data_path_3d + '/expert_data.pkl', data_type='naive'
    )
    data = data_3d
    cfg = EasyDict(cfg)
    policy = GailRewardModel(cfg, device, tb_logger=SummaryWriter())
    policy.load_expert_data()
    assert len(policy.expert_data) == 20
    state = policy.state_dict()
    policy.load_state_dict(state)
    policy.collect_data(data)
    assert len(policy.train_data) == 20
    for _ in range(5):
        policy.train()
    train_data_augmented = policy.estimate(data)
    assert 'reward' in train_data_augmented[0].keys()
    policy.clear_data()
    assert len(policy.train_data) == 0
    # Clean up the expert data directory created for this test.
    shutil.rmtree(expert_data_path_3d, ignore_errors=True)