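"""
Smoke tests for DI-engine's reward-model serial pipelines on CartPole:
IRL reward models (pdeil, gail, pwil, red) combined with DQN, plus the
RND and ICM exploration reward models with PPO. Each pipeline runs for
only a couple of training iterations to keep the tests fast.
"""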
import pytest
import os
from easydict import EasyDict
from copy import deepcopy

from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_rnd_onppo_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config, cartpole_ppo_icm_create_config  # noqa
from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy, \
    serial_pipeline_reward_model_onpolicy

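# One configuration per IRL reward model under test; each entry is passed to
# test_irl via pytest parametrization and converted to an EasyDict there.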
cfg = [
    {
        'type': 'pdeil',
        'alpha': 0.5,
        'discrete_action': False,
    },
    {
        'type': 'gail',
        'input_size': 5,
        'hidden_size': 64,
        'batch_size': 64,
    },
    {
        'type': 'pwil',
        's_size': 4,
        'a_size': 2,
        'sample_size': 500,
    },
    {
        'type': 'red',
        'sample_size': 5000,
        'input_size': 5,
        'hidden_size': 64,
        'update_per_collect': 200,
        'batch_size': 128,
    },
]


@pytest.mark.unittest
@pytest.mark.parametrize('reward_model_config', cfg)
def test_irl(reward_model_config):
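    """End-to-end IRL smoke test: train a PPO expert, collect demo data, then train DQN with the given reward model."""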
    reward_model_config = EasyDict(reward_model_config)
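    # train an expert PPO (off-policy) policy for a few iterations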
    config = (deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config))
    expert_policy = serial_pipeline(config, seed=0, max_train_iter=2)
    # collect expert demo data
    collect_count = 10000
    expert_data_path = 'expert_data.pkl'
    state_dict = expert_policy.collect_mode.state_dict()
    config = (deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config))
    collect_demo_data(
        config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
    )
    # irl + rl training
    cp_cartpole_dqn_config = deepcopy(cartpole_dqn_config)
    cp_cartpole_dqn_create_config = deepcopy(cartpole_dqn_create_config)
    cp_cartpole_dqn_create_config.reward_model = dict(type=reward_model_config.type)
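    # gail loads demonstrations from a directory (data_path); the other reward models take the pickle path directly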
    if reward_model_config.type == 'gail':
        reward_model_config['data_path'] = '.'
    else:
        reward_model_config['expert_data_path'] = expert_data_path
    cp_cartpole_dqn_config.reward_model = reward_model_config
    cp_cartpole_dqn_config.policy.collect.n_sample = 128
    serial_pipeline_reward_model_offpolicy(
        (cp_cartpole_dqn_config, cp_cartpole_dqn_create_config), seed=0, max_train_iter=2
    )

    os.system("rm -rf ckpt_* log expert_data.pkl")  # clean up checkpoints, logs and demo data


@pytest.mark.unittest
def test_rnd():
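    """Smoke test: RND intrinsic reward with on-policy PPO on CartPole."""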
    config = [deepcopy(cartpole_ppo_rnd_config), deepcopy(cartpole_ppo_rnd_create_config)]
    try:
        serial_pipeline_reward_model_onpolicy(config, seed=0, max_train_iter=2)
    except Exception:
        pytest.fail("RND reward model pipeline failed")


@pytest.mark.unittest
def test_icm():
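    """Smoke test: ICM intrinsic reward with off-policy PPO on CartPole."""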
    config = [deepcopy(cartpole_ppo_icm_config), deepcopy(cartpole_ppo_icm_create_config)]
    try:
        serial_pipeline_reward_model_offpolicy(config, seed=0, max_train_iter=2)
    except Exception:
        pytest.fail("ICM reward model pipeline failed")
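

# A minimal sketch of running one of these pipelines by hand (outside pytest),
# assuming the same entry points and configs imported above. max_train_iter is
# kept tiny here for a quick smoke run; a real experiment would raise it.
if __name__ == "__main__":
    config = [deepcopy(cartpole_ppo_rnd_config), deepcopy(cartpole_ppo_rnd_create_config)]
    serial_pipeline_reward_model_onpolicy(config, seed=0, max_train_iter=2)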