File size: 3,484 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import pytest
import time
import os
from copy import deepcopy
from ding.entry import serial_pipeline_onpolicy
from dizoo.classic_control.cartpole.config.cartpole_pg_config import cartpole_pg_config, cartpole_pg_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppopg_config import cartpole_ppopg_config, cartpole_ppopg_create_config # noqa
from dizoo.classic_control.cartpole.config.cartpole_a2c_config import cartpole_a2c_config, cartpole_a2c_create_config
from dizoo.petting_zoo.config import ptz_simple_spread_mappo_config, ptz_simple_spread_mappo_create_config
from dizoo.classic_control.pendulum.config.pendulum_ppo_config import pendulum_ppo_config, pendulum_ppo_create_config
from dizoo.classic_control.cartpole.config.cartpole_ppo_stdim_config import cartpole_ppo_stdim_config, cartpole_ppo_stdim_create_config # noqa
@pytest.mark.platformtest
@pytest.mark.unittest
def test_pg():
    """Smoke-test ``serial_pipeline_onpolicy`` with the CartPole PG config.

    Runs a single training iteration (``max_train_iter=1``) with a fixed seed.
    The configs are deep-copied so the shared module-level config objects are
    left untouched by this test.
    """
    config = [deepcopy(cartpole_pg_config), deepcopy(cartpole_pg_create_config)]
    try:
        serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
    except Exception as err:
        # pytest.fail is not stripped under ``python -O`` (``assert False`` is,
        # which would silently turn a crash into a pass), and including the
        # error repr keeps the real cause visible in the failure summary.
        pytest.fail(f"pipeline fail: {err!r}")
@pytest.mark.platformtest
@pytest.mark.unittest
def test_a2c():
    """Smoke-test ``serial_pipeline_onpolicy`` with the CartPole A2C config.

    Runs a single training iteration (``max_train_iter=1``) with a fixed seed.
    The configs are deep-copied so the shared module-level config objects are
    left untouched by this test.
    """
    config = [deepcopy(cartpole_a2c_config), deepcopy(cartpole_a2c_create_config)]
    try:
        serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
    except Exception as err:
        # pytest.fail is not stripped under ``python -O`` (``assert False`` is,
        # which would silently turn a crash into a pass), and including the
        # error repr keeps the real cause visible in the failure summary.
        pytest.fail(f"pipeline fail: {err!r}")
@pytest.mark.platformtest
@pytest.mark.unittest
def test_onpolicy_ppo():
    """Smoke-test ``serial_pipeline_onpolicy`` with the CartPole PPO config.

    Overrides ``learn.epoch_per_collect`` to 2 and ``evaluator.eval_freq`` to 1
    (presumably so the evaluator is exercised within the short run), then runs
    two training iterations with a fixed seed. The configs are deep-copied so
    the shared module-level config objects are left untouched by this test.
    """
    config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
    config[0].policy.learn.epoch_per_collect = 2
    config[0].policy.eval.evaluator.eval_freq = 1
    try:
        serial_pipeline_onpolicy(config, seed=0, max_train_iter=2)
    except Exception as err:
        # pytest.fail is not stripped under ``python -O`` (``assert False`` is,
        # which would silently turn a crash into a pass), and including the
        # error repr keeps the real cause visible in the failure summary.
        pytest.fail(f"pipeline fail: {err!r}")
@pytest.mark.platformtest
@pytest.mark.unittest
def test_onpolicy_ppopg():
    """Smoke-test ``serial_pipeline_onpolicy`` with the CartPole PPOPG config.

    Overrides ``learn.epoch_per_collect`` to 1 and ``evaluator.eval_freq`` to 1
    (presumably so the evaluator is exercised within the short run), then runs
    two training iterations with a fixed seed. The configs are deep-copied so
    the shared module-level config objects are left untouched by this test.
    """
    config = [deepcopy(cartpole_ppopg_config), deepcopy(cartpole_ppopg_create_config)]
    config[0].policy.learn.epoch_per_collect = 1
    config[0].policy.eval.evaluator.eval_freq = 1
    try:
        serial_pipeline_onpolicy(config, seed=0, max_train_iter=2)
    except Exception as err:
        # pytest.fail is not stripped under ``python -O`` (``assert False`` is,
        # which would silently turn a crash into a pass), and including the
        # error repr keeps the real cause visible in the failure summary.
        pytest.fail(f"pipeline fail: {err!r}")
@pytest.mark.platformtest
@pytest.mark.unittest
def test_mappo():
    """Smoke-test ``serial_pipeline_onpolicy`` with the PettingZoo simple_spread MAPPO config.

    Overrides ``learn.epoch_per_collect`` to 1 and the env manager type to
    ``'base'``, then runs a single training iteration with a fixed seed. The
    configs are deep-copied so the shared module-level config objects are left
    untouched by this test.
    """
    config = [deepcopy(ptz_simple_spread_mappo_config), deepcopy(ptz_simple_spread_mappo_create_config)]
    config[0].policy.learn.epoch_per_collect = 1
    # NOTE(review): presumably 'base' avoids a subprocess-based env manager in
    # CI; confirm against the create config's default env_manager type.
    config[1].env_manager.type = 'base'
    try:
        serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
    except Exception as err:
        # pytest.fail is not stripped under ``python -O`` (``assert False`` is,
        # which would silently turn a crash into a pass), and including the
        # error repr keeps the real cause visible in the failure summary.
        pytest.fail(f"pipeline fail: {err!r}")
@pytest.mark.platformtest
@pytest.mark.unittest
def test_onpolicy_ppo_continuous():
    """Smoke-test ``serial_pipeline_onpolicy`` with the Pendulum PPO config.

    Overrides ``learn.epoch_per_collect`` to 1 and runs a single training
    iteration with a fixed seed. The configs are deep-copied so the shared
    module-level config objects are left untouched by this test.
    """
    config = [deepcopy(pendulum_ppo_config), deepcopy(pendulum_ppo_create_config)]
    config[0].policy.learn.epoch_per_collect = 1
    try:
        serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
    except Exception as err:
        # pytest.fail is not stripped under ``python -O`` (``assert False`` is,
        # which would silently turn a crash into a pass), and including the
        # error repr keeps the real cause visible in the failure summary.
        pytest.fail(f"pipeline fail: {err!r}")
@pytest.mark.platformtest
@pytest.mark.unittest
def test_onppo_stdim():
    """Smoke-test ``serial_pipeline_onpolicy`` with the CartPole PPO + ST-DIM config.

    Overrides ``learn.update_per_collect`` to 1 and the experiment name, then
    runs a single training iteration with a fixed seed. The configs are
    deep-copied so the shared module-level config objects are left untouched
    by this test.
    """
    config = [deepcopy(cartpole_ppo_stdim_config), deepcopy(cartpole_ppo_stdim_create_config)]
    # NOTE(review): the other PPO tests in this file override
    # ``epoch_per_collect``; confirm this policy actually reads
    # ``update_per_collect``, otherwise this override is a no-op.
    config[0].policy.learn.update_per_collect = 1
    # Presumably keeps this run's outputs separate from the default experiment.
    config[0].exp_name = 'cartpole_ppo_stdim_unittest'
    try:
        serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
    except Exception as err:
        # pytest.fail is not stripped under ``python -O`` (``assert False`` is,
        # which would silently turn a crash into a pass), and including the
        # error repr keeps the real cause visible in the failure summary.
        pytest.fail(f"pipeline fail: {err!r}")
|