from copy import deepcopy
import pytest
import os
import pickle

from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import (
    cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config
)
from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import (
    cartpole_trex_offppo_config, cartpole_trex_offppo_create_config
)
from dizoo.classic_control.cartpole.envs import CartPoleEnv
from ding.entry import serial_pipeline, eval, collect_demo_data
from ding.config import compile_config
from ding.entry.application_entry import collect_episodic_demo_data, episode_to_transitions


@pytest.fixture(scope='module')
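def setup_state_dict():
    # Run the serial pipeline once per module and share the resulting policy weights across tests.
    config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
    try:
        policy = serial_pipeline(config, seed=0)
    except Exception as e:
        # pytest.fail is not stripped under `python -O` (unlike `assert False`) and reports the cause.
        pytest.fail('Serial pipeline failure: {}'.format(e))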
    state_dict = {
        'eval': policy.eval_mode.state_dict(),
        'collect': policy.collect_mode.state_dict(),
    }
    return state_dict


@pytest.mark.unittest
class TestApplication:

    def test_eval(self, setup_state_dict):
        cfg_for_stop_value = compile_config(
            cartpole_ppo_offpolicy_config, auto=True, create_cfg=cartpole_ppo_offpolicy_create_config
        )
        stop_value = cfg_for_stop_value.env.stop_value
        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        episode_return = eval(config, seed=0, state_dict=setup_state_dict['eval'])
        assert episode_return >= stop_value
        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        episode_return = eval(
            config,
            seed=0,
            env_setting=[CartPoleEnv, None, [{} for _ in range(5)]],
            state_dict=setup_state_dict['eval']
        )
        assert episode_return >= stop_value

    def test_collect_demo_data(self, setup_state_dict):
        config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
        collect_count = 16
        expert_data_path = './expert.data'
        collect_demo_data(
            config,
            seed=0,
            state_dict=setup_state_dict['collect'],
            collect_count=collect_count,
            expert_data_path=expert_data_path
        )
        with open(expert_data_path, 'rb') as f:
            exp_data = pickle.load(f)
        assert isinstance(exp_data, list)
        assert isinstance(exp_data[0], dict)

    def test_collect_episodic_demo_data(self, setup_state_dict):
        config = deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)
        config[0].exp_name = 'cartpole_trex_offppo_episodic'
        collect_count = 16
        os.makedirs('./test_episode', exist_ok=True)
        expert_data_path = './test_episode/expert.data'
        collect_episodic_demo_data(
            config,
            seed=0,
            state_dict=setup_state_dict['collect'],
            expert_data_path=expert_data_path,
            collect_count=collect_count,
        )
        with open(expert_data_path, 'rb') as f:
            exp_data = pickle.load(f)
        assert isinstance(exp_data, list)
        assert isinstance(exp_data[0][0], dict)

    def test_episode_to_transitions(self, setup_state_dict):
        self.test_collect_episodic_demo_data(setup_state_dict)
        expert_data_path = './test_episode/expert.data'
        episode_to_transitions(data_path=expert_data_path, expert_data_path=expert_data_path, nstep=3)
        with open(expert_data_path, 'rb') as f:
            exp_data = pickle.load(f)
        assert isinstance(exp_data, list)
        assert isinstance(exp_data[0], dict)
        # Clean up generated data, checkpoints and logs; os.system blocks until the command finishes.
        os.system('rm -rf ./test_episode ckpt* log')