# Source: gomoku / DI-engine / ding/entry/tests/test_application_entry.py
# Last commit: 079c32c ("init space") by zjowowen
import glob
import os
import pickle
import shutil
from copy import deepcopy

import pytest

from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, \
    cartpole_ppo_offpolicy_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config, \
    cartpole_trex_offppo_create_config
from dizoo.classic_control.cartpole.envs import CartPoleEnv
from ding.entry import serial_pipeline, eval, collect_demo_data
from ding.config import compile_config
from ding.entry.application_entry import collect_episodic_demo_data, episode_to_transitions
@pytest.fixture(scope='module')
def setup_state_dict():
    """Run ``serial_pipeline`` once per module on the CartPole off-policy PPO config.

    Returns:
        dict: ``{'eval': ..., 'collect': ...}`` holding the ``state_dict()`` of the
        resulting policy's eval mode and collect mode. The fixture is module-scoped,
        so every test in this module shares the same result.
    """
    config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
    try:
        policy = serial_pipeline(config, seed=0)
    except Exception as e:
        # Unlike ``assert False``, pytest.fail is not stripped under ``python -O``,
        # and including the error keeps the root cause visible in the report.
        pytest.fail('Serial pipeline failure: {!r}'.format(e))
    return {
        'eval': policy.eval_mode.state_dict(),
        'collect': policy.collect_mode.state_dict(),
    }
@pytest.mark.unittest
class TestApplication:
    """End-to-end checks for the application entry points exposed by ``ding.entry``.

    Every test reuses the policy state dicts from the module-scoped
    ``setup_state_dict`` fixture. Artifacts are written relative to the current
    working directory.
    """

    @staticmethod
    def _remove_paths(paths):
        """Best-effort, portable replacement for ``rm -rf <paths>``.

        Directories are removed recursively, and files and symlinks are unlinked.
        Missing paths and removal errors are ignored, so cleanup can never mask
        the real test outcome.
        """
        for path in paths:
            try:
                if os.path.isdir(path) and not os.path.islink(path):
                    shutil.rmtree(path)
                else:
                    os.remove(path)
            except OSError:
                # Deliberately best-effort, mirroring ``rm -rf`` semantics.
                pass

    def test_eval(self, setup_state_dict):
        """``eval`` must reach the configured stop value, with and without an explicit env setting."""
        # Pass copies in case compile_config mutates its arguments in place.
        # The module-level configs are shared with the other tests.
        cfg_for_stop_value = compile_config(
            deepcopy(cartpole_ppo_offpolicy_config),
            auto=True,
            create_cfg=deepcopy(cartpole_ppo_offpolicy_create_config),
        )
        stop_value = cfg_for_stop_value.env.stop_value

        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        episode_return = eval(config, seed=0, state_dict=setup_state_dict['eval'])
        assert episode_return >= stop_value

        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        episode_return = eval(
            config,
            seed=0,
            env_setting=[CartPoleEnv, None, [{} for _ in range(5)]],
            state_dict=setup_state_dict['eval']
        )
        assert episode_return >= stop_value

    def test_collect_demo_data(self, setup_state_dict):
        """``collect_demo_data`` must pickle a non-empty list of transition dicts."""
        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        collect_count = 16
        expert_data_path = './expert.data'
        try:
            collect_demo_data(
                config,
                seed=0,
                state_dict=setup_state_dict['collect'],
                collect_count=collect_count,
                expert_data_path=expert_data_path
            )
            with open(expert_data_path, 'rb') as f:
                exp_data = pickle.load(f)
            assert isinstance(exp_data, list)
            assert len(exp_data) > 0
            assert isinstance(exp_data[0], dict)
        finally:
            # Do not leave the data file behind in the working directory.
            self._remove_paths([expert_data_path])

    def test_collect_episodic_demo_data(self, setup_state_dict):
        """``collect_episodic_demo_data`` must pickle a list of episodes (lists of dicts).

        Leaves ``./test_episode/expert.data`` in place on purpose:
        ``test_episode_to_transitions`` calls this method and consumes that file.
        """
        config = deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)
        config[0].exp_name = 'cartpole_trex_offppo_episodic'
        collect_count = 16
        os.makedirs('./test_episode', exist_ok=True)
        expert_data_path = './test_episode/expert.data'
        collect_episodic_demo_data(
            config,
            seed=0,
            state_dict=setup_state_dict['collect'],
            expert_data_path=expert_data_path,
            collect_count=collect_count,
        )
        with open(expert_data_path, 'rb') as f:
            exp_data = pickle.load(f)
        assert isinstance(exp_data, list)
        assert len(exp_data) > 0 and len(exp_data[0]) > 0
        assert isinstance(exp_data[0][0], dict)

    def test_episode_to_transitions(self, setup_state_dict):
        """``episode_to_transitions`` must flatten episodic data into a list of transition dicts in place."""
        expert_data_path = './test_episode/expert.data'
        try:
            self.test_collect_episodic_demo_data(setup_state_dict)
            # Input and output deliberately share one path: the file is rewritten in place.
            episode_to_transitions(data_path=expert_data_path, expert_data_path=expert_data_path, nstep=3)
            with open(expert_data_path, 'rb') as f:
                exp_data = pickle.load(f)
            assert isinstance(exp_data, list)
            assert len(exp_data) > 0
            assert isinstance(exp_data[0], dict)
        finally:
            # Synchronous, portable cleanup that runs even when an assertion fails
            # (the previous ``os.popen('rm -rf ...')`` neither waited nor ran on failure).
            self._remove_paths(['./test_episode', 'log'] + glob.glob('ckpt*'))