|
import glob
import os
import shutil
from copy import deepcopy

import pytest
from ditk import logging
from easydict import EasyDict

from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config
from dizoo.classic_control.cartpole.config.cartpole_rnd_onppo_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config, cartpole_ppo_icm_create_config
from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy, \
    serial_pipeline_reward_model_onpolicy
|
|
|
# Reward-model settings consumed by ``test_irl`` through ``pytest.mark.parametrize``:
# one entry per reward model, selected by its ``type`` key.
cfg = [
    dict(type='pdeil', alpha=0.5, discrete_action=False),
    dict(type='gail', input_size=5, hidden_size=64, batch_size=64),
    dict(type='pwil', s_size=4, a_size=2, sample_size=500),
    dict(
        type='red',
        sample_size=5000,
        input_size=5,
        hidden_size=64,
        update_per_collect=200,
        batch_size=128,
    ),
]
|
|
|
|
|
def _remove_irl_artifacts(expert_data_path):
    """Synchronously delete the artifacts ``test_irl`` leaves in the working directory.

    Removes every ``ckpt_*`` entry, the ``log`` entry and ``expert_data_path``.
    Missing paths and removal errors are ignored, mirroring the best-effort
    ``rm -rf`` this helper replaces.
    """
    for path in glob.glob('ckpt_*') + ['log', expert_data_path]:
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path, ignore_errors=True)
        elif os.path.lexists(path):
            try:
                os.remove(path)
            except OSError:
                # Best effort, like ``rm -f``: a leftover file must not fail the test.
                pass


@pytest.mark.unittest
@pytest.mark.parametrize('reward_model_config', cfg)
def test_irl(reward_model_config):
    """Run the off-policy reward-model pipeline once for the given reward model.

    Steps:
      1. Run ``serial_pipeline`` on the CartPole off-policy PPO config to get a policy.
      2. Use that policy's collect-mode state dict to write demonstration data
         to ``expert_data.pkl`` via ``collect_demo_data``.
      3. Run ``serial_pipeline_reward_model_offpolicy`` on the CartPole DQN config
         with ``reward_model_config`` attached.

    The test passes if none of these steps raises.

    Args:
        reward_model_config: one entry of ``cfg``; its ``type`` key selects the
            reward model.
    """
    expert_data_path = 'expert_data.pkl'
    try:
        reward_model_config = EasyDict(reward_model_config)
        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        expert_policy = serial_pipeline(config, seed=0, max_train_iter=2)

        collect_count = 10000
        state_dict = expert_policy.collect_mode.state_dict()
        # Fresh deep copies: the pipeline call above received the previous pair.
        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        collect_demo_data(
            config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
        )

        cp_cartpole_dqn_config = deepcopy(cartpole_dqn_config)
        cp_cartpole_dqn_create_config = deepcopy(cartpole_dqn_create_config)
        cp_cartpole_dqn_create_config.reward_model = dict(type=reward_model_config.type)
        if reward_model_config.type == 'gail':
            # NOTE(review): GAIL is given a directory under ``data_path`` instead of a
            # file path; presumably it looks for the expert data inside it -- confirm.
            reward_model_config['data_path'] = '.'
        else:
            reward_model_config['expert_data_path'] = expert_data_path
        cp_cartpole_dqn_config.reward_model = reward_model_config
        cp_cartpole_dqn_config.policy.collect.n_sample = 128
        serial_pipeline_reward_model_offpolicy(
            (cp_cartpole_dqn_config, cp_cartpole_dqn_create_config), seed=0, max_train_iter=2
        )
    finally:
        # Clean up synchronously and also on failure. The previous
        # ``os.popen("rm -rf ...")`` returned without waiting for the shell, so the
        # deletion could race with the next parametrized case recreating
        # ``expert_data.pkl``, and it was skipped whenever the pipeline raised.
        _remove_irl_artifacts(expert_data_path)
|
|
|
|
|
@pytest.mark.unittest
def test_rnd():
    """Smoke-test the on-policy reward-model pipeline with the CartPole PPO + RND config.

    Runs ``serial_pipeline_reward_model_onpolicy`` for at most two training
    iterations on deep copies of the config pair.

    Raises:
        AssertionError: if the pipeline raises any ``Exception``; the original
            error is chained as the cause.
    """
    config = [deepcopy(cartpole_ppo_rnd_config), deepcopy(cartpole_ppo_rnd_create_config)]
    try:
        serial_pipeline_reward_model_onpolicy(config, seed=0, max_train_iter=2)
    except Exception as err:
        # Raise explicitly rather than ``assert False``: assert statements are
        # stripped under ``python -O``, which would let a broken pipeline pass.
        raise AssertionError("pipeline fail") from err
|
|
|
|
|
@pytest.mark.unittest
def test_icm():
    """Smoke-test the off-policy reward-model pipeline with the CartPole PPO + ICM config.

    Runs ``serial_pipeline_reward_model_offpolicy`` for at most two training
    iterations on deep copies of the config pair.

    Raises:
        AssertionError: if the pipeline raises any ``Exception``; the original
            error is chained as the cause.
    """
    config = [deepcopy(cartpole_ppo_icm_config), deepcopy(cartpole_ppo_icm_create_config)]
    try:
        serial_pipeline_reward_model_offpolicy(config, seed=0, max_train_iter=2)
    except Exception as err:
        # Raise explicitly rather than ``assert False``: assert statements are
        # stripped under ``python -O``, which would let a broken pipeline pass.
        raise AssertionError("pipeline fail") from err
|
|