gomoku / DI-engine /ding /reward_model /tests /test_gail_irl_model.py
zjowowen's picture
init space
079c32c
import pytest
import torch
from easydict import EasyDict
from ding.reward_model.gail_irl_model import GailRewardModel
from ding.utils.data import offline_data_save_type
from tensorboardX import SummaryWriter
import os
obs_space_1d, obs_space_3d = 4, [4, 84, 84]
expert_data_path_1d, expert_data_path_3d = './expert_data_1d', './expert_data_3d'
if not os.path.exists('./expert_data_1d'):
try:
os.mkdir('./expert_data_1d')
except FileExistsError:
pass
if not os.path.exists('./expert_data_3d'):
try:
os.mkdir('./expert_data_3d')
except FileExistsError:
pass
device = 'cpu'
action_space = 3
cfg1 = dict(
input_size=obs_space_1d + 1,
hidden_size=64,
batch_size=5,
learning_rate=1e-3,
update_per_collect=2,
data_path=expert_data_path_1d,
),
cfg2 = dict(
input_size=obs_space_3d,
hidden_size=64,
batch_size=5,
learning_rate=1e-3,
update_per_collect=2,
data_path=expert_data_path_3d,
action_size=action_space,
),
# create fake expert dataset
data_1d = []
for i in range(20):
d = {}
d['obs'] = torch.zeros(obs_space_1d)
d['action'] = torch.Tensor([1.])
data_1d.append(d)
data_3d = []
for i in range(20):
d = {}
d['obs'] = torch.zeros(obs_space_3d)
d['action'] = torch.Tensor([1.])
data_3d.append(d)
@pytest.mark.parametrize('cfg', cfg1)
@pytest.mark.unittest
def test_dataset_1d(cfg):
offline_data_save_type(
exp_data=data_1d, expert_data_path=expert_data_path_1d + '/expert_data.pkl', data_type='naive'
)
data = data_1d
cfg = EasyDict(cfg)
policy = GailRewardModel(cfg, device, tb_logger=SummaryWriter())
policy.load_expert_data()
assert len(policy.expert_data) == 20
state = policy.state_dict()
policy.load_state_dict(state)
policy.collect_data(data)
assert len(policy.train_data) == 20
for _ in range(5):
policy.train()
train_data_augmented = policy.estimate(data)
assert 'reward' in train_data_augmented[0].keys()
policy.clear_data()
assert len(policy.train_data) == 0
os.popen('rm -rf {}'.format(expert_data_path_1d))
@pytest.mark.parametrize('cfg', cfg2)
@pytest.mark.unittest
def test_dataset_3d(cfg):
offline_data_save_type(
exp_data=data_3d, expert_data_path=expert_data_path_3d + '/expert_data.pkl', data_type='naive'
)
data = data_3d
cfg = EasyDict(cfg)
policy = GailRewardModel(cfg, device, tb_logger=SummaryWriter())
policy.load_expert_data()
assert len(policy.expert_data) == 20
state = policy.state_dict()
policy.load_state_dict(state)
policy.collect_data(data)
assert len(policy.train_data) == 20
for _ in range(5):
policy.train()
train_data_augmented = policy.estimate(data)
assert 'reward' in train_data_augmented[0].keys()
policy.clear_data()
assert len(policy.train_data) == 0
os.popen('rm -rf {}'.format(expert_data_path_3d))