File size: 3,484 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import pytest
import time
import os
from copy import deepcopy

from ding.entry import serial_pipeline_onpolicy
from dizoo.classic_control.cartpole.config.cartpole_pg_config import cartpole_pg_config, cartpole_pg_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppopg_config import cartpole_ppopg_config, cartpole_ppopg_create_config  # noqa
from dizoo.classic_control.cartpole.config.cartpole_a2c_config import cartpole_a2c_config, cartpole_a2c_create_config
from dizoo.petting_zoo.config import ptz_simple_spread_mappo_config, ptz_simple_spread_mappo_create_config
from dizoo.classic_control.pendulum.config.pendulum_ppo_config import pendulum_ppo_config, pendulum_ppo_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_stdim_config import cartpole_ppo_stdim_config, cartpole_ppo_stdim_create_config  # noqa


@pytest.mark.platformtest
@pytest.mark.unittest
def test_pg():
    """Smoke-test ``serial_pipeline_onpolicy`` with the CartPole PG config for one train iteration."""
    # Work on copies so the imported module-level config objects stay untouched.
    config = [deepcopy(cartpole_pg_config), deepcopy(cartpole_pg_create_config)]
    try:
        serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
    except Exception as err:
        # ``pytest.fail`` is not stripped under ``python -O`` (unlike ``assert False``) and,
        # being raised inside the handler, keeps the original exception chained in the report.
        pytest.fail(f"pipeline fail: {err!r}")


@pytest.mark.platformtest
@pytest.mark.unittest
def test_a2c():
    """Smoke-test ``serial_pipeline_onpolicy`` with the CartPole A2C config for one train iteration."""
    # Work on copies so the imported module-level config objects stay untouched.
    config = [deepcopy(cartpole_a2c_config), deepcopy(cartpole_a2c_create_config)]
    try:
        serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
    except Exception as err:
        # ``pytest.fail`` is not stripped under ``python -O`` (unlike ``assert False``) and,
        # being raised inside the handler, keeps the original exception chained in the report.
        pytest.fail(f"pipeline fail: {err!r}")


@pytest.mark.platformtest
@pytest.mark.unittest
def test_onpolicy_ppo():
    """Smoke-test ``serial_pipeline_onpolicy`` with the CartPole PPO config for two train iterations."""
    # Work on copies so the in-test overrides below do not leak into the shared configs.
    config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
    config[0].policy.learn.epoch_per_collect = 2
    # NOTE(review): presumably eval_freq=1 makes the evaluator run within this short budget.
    config[0].policy.eval.evaluator.eval_freq = 1
    try:
        serial_pipeline_onpolicy(config, seed=0, max_train_iter=2)
    except Exception as err:
        # ``pytest.fail`` is not stripped under ``python -O`` (unlike ``assert False``) and,
        # being raised inside the handler, keeps the original exception chained in the report.
        pytest.fail(f"pipeline fail: {err!r}")


@pytest.mark.platformtest
@pytest.mark.unittest
def test_onpolicy_ppopg():
    """Smoke-test ``serial_pipeline_onpolicy`` with the CartPole PPOPG config for two train iterations."""
    # Work on copies so the in-test overrides below do not leak into the shared configs.
    config = [deepcopy(cartpole_ppopg_config), deepcopy(cartpole_ppopg_create_config)]
    config[0].policy.learn.epoch_per_collect = 1
    # NOTE(review): presumably eval_freq=1 makes the evaluator run within this short budget.
    config[0].policy.eval.evaluator.eval_freq = 1
    try:
        serial_pipeline_onpolicy(config, seed=0, max_train_iter=2)
    except Exception as err:
        # ``pytest.fail`` is not stripped under ``python -O`` (unlike ``assert False``) and,
        # being raised inside the handler, keeps the original exception chained in the report.
        pytest.fail(f"pipeline fail: {err!r}")


@pytest.mark.platformtest
@pytest.mark.unittest
def test_mappo():
    """Smoke-test ``serial_pipeline_onpolicy`` with the PettingZoo simple-spread MAPPO config for one train iteration."""
    # Work on copies so the in-test overrides below do not leak into the shared configs.
    config = [deepcopy(ptz_simple_spread_mappo_config), deepcopy(ptz_simple_spread_mappo_create_config)]
    config[0].policy.learn.epoch_per_collect = 1
    # NOTE(review): presumably 'base' selects an in-process env manager to keep the test light.
    config[1].env_manager.type = 'base'
    try:
        serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
    except Exception as err:
        # ``pytest.fail`` is not stripped under ``python -O`` (unlike ``assert False``) and,
        # being raised inside the handler, keeps the original exception chained in the report.
        pytest.fail(f"pipeline fail: {err!r}")


@pytest.mark.platformtest
@pytest.mark.unittest
def test_onpolicy_ppo_continuous():
    """Smoke-test ``serial_pipeline_onpolicy`` with the Pendulum PPO config for one train iteration."""
    # Work on copies so the in-test overrides below do not leak into the shared configs.
    config = [deepcopy(pendulum_ppo_config), deepcopy(pendulum_ppo_create_config)]
    config[0].policy.learn.epoch_per_collect = 1
    try:
        serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
    except Exception as err:
        # ``pytest.fail`` is not stripped under ``python -O`` (unlike ``assert False``) and,
        # being raised inside the handler, keeps the original exception chained in the report.
        pytest.fail(f"pipeline fail: {err!r}")


@pytest.mark.platformtest
@pytest.mark.unittest
def test_onppo_stdim():
    """Smoke-test ``serial_pipeline_onpolicy`` with the CartPole PPO+ST-DIM config for one train iteration."""
    # Work on copies so the in-test overrides below do not leak into the shared configs.
    config = [deepcopy(cartpole_ppo_stdim_config), deepcopy(cartpole_ppo_stdim_create_config)]
    # Shorten training with ``epoch_per_collect``, the key every other on-policy PPO test in
    # this module overrides (previously ``update_per_collect`` was set here).
    # NOTE(review): confirm cartpole_ppo_stdim_config.policy.learn defines epoch_per_collect.
    config[0].policy.learn.epoch_per_collect = 1
    config[0].exp_name = 'cartpole_ppo_stdim_unittest'
    try:
        serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
    except Exception as err:
        # ``pytest.fail`` is not stripped under ``python -O`` (unlike ``assert False``) and,
        # being raised inside the handler, keeps the original exception chained in the report.
        pytest.fail(f"pipeline fail: {err!r}")