|
import glob
import os
import pickle
import shutil
from copy import deepcopy

import pytest

from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, \
    cartpole_ppo_offpolicy_create_config
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config,\
    cartpole_trex_offppo_create_config
from dizoo.classic_control.cartpole.envs import CartPoleEnv
from ding.entry import serial_pipeline, eval, collect_demo_data
from ding.config import compile_config
from ding.entry.application_entry import collect_episodic_demo_data, episode_to_transitions
|
|
|
|
|
@pytest.fixture(scope='module')
def setup_state_dict():
    """Run ``serial_pipeline`` once per module and expose the policy's state dicts.

    The pipeline is run on deep copies of the CartPole off-policy PPO config
    pair with ``seed=0``, so the module-level config objects are not touched
    by this fixture.

    Returns:
        dict: ``{'eval': ..., 'collect': ...}`` holding
        ``policy.eval_mode.state_dict()`` and
        ``policy.collect_mode.state_dict()`` respectively.

    Raises:
        AssertionError: if ``serial_pipeline`` raises any ``Exception``; the
            original exception is attached as ``__cause__``.
    """
    config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
    try:
        policy = serial_pipeline(config, seed=0)
    except Exception as err:
        # Raise explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would let execution fall through
        # to a confusing NameError on `policy` below.
        raise AssertionError('Serial pipeline failure') from err
    return {
        'eval': policy.eval_mode.state_dict(),
        'collect': policy.collect_mode.state_dict(),
    }
|
|
|
|
|
@pytest.mark.unittest
class TestApplication:
    """Tests for the ``eval`` and demo-data collection entry points on CartPole.

    Every test receives the module-scoped ``setup_state_dict`` fixture, a dict
    with ``'eval'`` and ``'collect'`` state dicts.
    """

    @staticmethod
    def _remove_artifacts(paths):
        """Best-effort removal of each file or directory tree in ``paths``.

        Missing paths are skipped and removal errors are ignored, mirroring
        the ``rm -rf`` semantics this helper replaces.
        """
        for path in paths:
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path, ignore_errors=True)
            elif os.path.lexists(path):
                try:
                    os.remove(path)
                except OSError:
                    # Deliberately ignored: cleanup must never mask the
                    # outcome of the test itself.
                    pass

    def test_eval(self, setup_state_dict):
        """Check that ``eval`` reaches the configured stop value.

        ``eval`` is called twice with the ``'eval'`` state dict: once with the
        default environment setup and once with an explicit ``env_setting``
        built from ``CartPoleEnv`` and five empty env configs.
        """
        # Deep-copy before compiling, like every other call site in this
        # file, so the shared module-level config objects are not handed to
        # compile_config directly.
        cfg_for_stop_value = compile_config(
            deepcopy(cartpole_ppo_offpolicy_config),
            auto=True,
            create_cfg=deepcopy(cartpole_ppo_offpolicy_create_config),
        )
        stop_value = cfg_for_stop_value.env.stop_value

        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        episode_return = eval(config, seed=0, state_dict=setup_state_dict['eval'])
        assert episode_return >= stop_value

        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        episode_return = eval(
            config,
            seed=0,
            env_setting=[CartPoleEnv, None, [{} for _ in range(5)]],
            state_dict=setup_state_dict['eval']
        )
        assert episode_return >= stop_value

    def test_collect_demo_data(self, setup_state_dict):
        """Collect demo data to ``./expert.data`` and check it unpickles to a list of dicts."""
        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        collect_count = 16
        expert_data_path = './expert.data'
        collect_demo_data(
            config,
            seed=0,
            state_dict=setup_state_dict['collect'],
            collect_count=collect_count,
            expert_data_path=expert_data_path
        )
        # The pickle file was written by the call above, so it is trusted.
        with open(expert_data_path, 'rb') as f:
            exp_data = pickle.load(f)
        assert isinstance(exp_data, list)
        assert isinstance(exp_data[0], dict)

    def test_collect_episodic_demo_data(self, setup_state_dict):
        """Collect episodic demo data and check it unpickles to a list whose items index to dicts.

        Output goes to ``./test_episode/expert.data``. The directory is left
        in place because ``test_episode_to_transitions`` reads the file.
        """
        config = deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)
        config[0].exp_name = 'cartpole_trex_offppo_episodic'
        collect_count = 16
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs('./test_episode', exist_ok=True)
        expert_data_path = './test_episode/expert.data'
        collect_episodic_demo_data(
            config,
            seed=0,
            state_dict=setup_state_dict['collect'],
            expert_data_path=expert_data_path,
            collect_count=collect_count,
        )
        with open(expert_data_path, 'rb') as f:
            exp_data = pickle.load(f)
        assert isinstance(exp_data, list)
        assert isinstance(exp_data[0][0], dict)

    def test_episode_to_transitions(self, setup_state_dict):
        """Convert the episodic data in place with ``nstep=3`` and check the result is a list of dicts.

        The episodic data is regenerated first by calling
        ``test_collect_episodic_demo_data``. ``ckpt*``, ``log`` and
        ``./test_episode`` are removed afterwards, whether or not the
        assertions pass.
        """
        try:
            self.test_collect_episodic_demo_data(setup_state_dict)
            expert_data_path = './test_episode/expert.data'
            episode_to_transitions(data_path=expert_data_path, expert_data_path=expert_data_path, nstep=3)
            with open(expert_data_path, 'rb') as f:
                exp_data = pickle.load(f)
            assert isinstance(exp_data, list)
            assert isinstance(exp_data[0], dict)
        finally:
            # Portable, synchronous replacement for
            # `rm -rf ./test_episode/expert.data ckpt* log` followed by
            # `rm -rf ./test_episode`; removing the directory also removes
            # the data file inside it.
            self._remove_artifacts(glob.glob('ckpt*') + ['log', './test_episode'])
|
|