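"""
Unit tests for GailRewardModel.

Two parametrized tests cover 1D vector observations and 3D image observations.
Each test writes a small fake expert dataset to disk, then exercises expert-data
loading, the state-dict round trip, policy-data collection, reward model
training, reward estimation and cleanup.
"""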
import os

import pytest
import torch
from easydict import EasyDict
from tensorboardX import SummaryWriter

from ding.reward_model.gail_irl_model import GailRewardModel
from ding.utils.data import offline_data_save_type

obs_space_1d, obs_space_3d = 4, [4, 84, 84]
expert_data_path_1d, expert_data_path_3d = './expert_data_1d', './expert_data_3d'
if not os.path.exists('./expert_data_1d'):
    try:
        os.mkdir('./expert_data_1d')
    except FileExistsError:
        pass
if not os.path.exists('./expert_data_3d'):
    try:
        os.mkdir('./expert_data_3d')
    except FileExistsError:
        pass
device = 'cpu'
action_space = 3
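
# Reward model configs: cfg1 covers 1D observations (input_size = obs dim + 1 action dim),
# cfg2 covers 3D image observations and passes the action dimension separately via action_size.
# Note: the trailing comma after each dict(...) wraps it in a one-element tuple, which is the
# parameter list expected by pytest.mark.parametrize below.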
cfg1 = dict(
    input_size=obs_space_1d + 1,
    hidden_size=64,
    batch_size=5,
    learning_rate=1e-3,
    update_per_collect=2,
    data_path=expert_data_path_1d,
),
cfg2 = dict(
    input_size=obs_space_3d,
    hidden_size=64,
    batch_size=5,
    learning_rate=1e-3,
    update_per_collect=2,
    data_path=expert_data_path_3d,
    action_size=action_space,
),

# create fake expert dataset
data_1d = []
for i in range(20):
    d = {}
    d['obs'] = torch.zeros(obs_space_1d)
    d['action'] = torch.Tensor([1.])
    data_1d.append(d)

data_3d = []
for i in range(20):
    d = {}
    d['obs'] = torch.zeros(obs_space_3d)
    d['action'] = torch.Tensor([1.])
    data_3d.append(d)


@pytest.mark.parametrize('cfg', cfg1)
@pytest.mark.unittest
def test_dataset_1d(cfg):
    # dump the fake expert data to disk so the reward model can load it back
    offline_data_save_type(
        exp_data=data_1d, expert_data_path=expert_data_path_1d + '/expert_data.pkl', data_type='naive'
    )
    data = data_1d
    cfg = EasyDict(cfg)
    policy = GailRewardModel(cfg, device, tb_logger=SummaryWriter())
    policy.load_expert_data()
    assert len(policy.expert_data) == 20
    # the state dict should survive a save/load round trip
    state = policy.state_dict()
    policy.load_state_dict(state)
    policy.collect_data(data)
    assert len(policy.train_data) == 20
    for _ in range(5):
        policy.train()
    # estimate() should attach a learned 'reward' field to each transition
    train_data_augmented = policy.estimate(data)
    assert 'reward' in train_data_augmented[0].keys()
    policy.clear_data()
    assert len(policy.train_data) == 0
    # remove the temporary expert data directory
    os.popen('rm -rf {}'.format(expert_data_path_1d))


@pytest.mark.parametrize('cfg', cfg2)
@pytest.mark.unittest
def test_dataset_3d(cfg):
    # same workflow as test_dataset_1d, but with image-shaped observations
    offline_data_save_type(
        exp_data=data_3d, expert_data_path=expert_data_path_3d + '/expert_data.pkl', data_type='naive'
    )
    data = data_3d
    cfg = EasyDict(cfg)
    policy = GailRewardModel(cfg, device, tb_logger=SummaryWriter())
    policy.load_expert_data()
    assert len(policy.expert_data) == 20
    state = policy.state_dict()
    policy.load_state_dict(state)
    policy.collect_data(data)
    assert len(policy.train_data) == 20
    for _ in range(5):
        policy.train()
    train_data_augmented = policy.estimate(data)
    assert 'reward' in train_data_augmented[0].keys()
    policy.clear_data()
    assert len(policy.train_data) == 0
    os.popen('rm -rf {}'.format(expert_data_path_3d))
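
# A rough sketch (in comments, so nothing runs at import time) of how the methods
# exercised above are meant to compose outside the tests. Only the GailRewardModel
# calls are taken from this file; `collector`, `policy_trainer` and `num_iterations`
# are hypothetical placeholders, and the exact DI-engine trainer wiring may differ:
#
#   reward_model = GailRewardModel(EasyDict(cfg), device, tb_logger=SummaryWriter())
#   reward_model.load_expert_data()                      # read expert_data.pkl from cfg.data_path
#   for iteration in range(num_iterations):
#       new_data = collector.collect()                   # hypothetical: obs/action transition dicts
#       reward_model.collect_data(new_data)              # buffer policy data for training
#       reward_model.train()                             # update the reward model
#       train_data = reward_model.estimate(new_data)     # attach learned 'reward' to each step
#       policy_trainer.train(train_data)                 # hypothetical policy update on GAIL rewards
#       reward_model.clear_data()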