File size: 1,557 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import pytest
from copy import deepcopy
from ding.entry.serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream
from dizoo.classic_control.pendulum.config.mbrl.pendulum_sac_mbpo_config \
import main_config as pendulum_sac_mbpo_main_config,\
create_config as pendulum_sac_mbpo_create_config
from dizoo.classic_control.pendulum.config.mbrl.pendulum_mbsac_mbpo_config \
import main_config as pendulum_mbsac_mbpo_main_config,\
create_config as pendulum_mbsac_mbpo_create_config
from dizoo.classic_control.pendulum.config.mbrl.pendulum_stevesac_mbpo_config \
import main_config as pendulum_stevesac_mbpo_main_config,\
create_config as pendulum_stevesac_mbpo_create_config
@pytest.mark.unittest
def test_dyna():
    """Smoke-test ``serial_pipeline_dyna`` on Pendulum with the SAC + MBPO config.

    Runs the pipeline for a single training iteration (``max_train_iter=1``,
    ``seed=0``) on deep copies of the shared config objects. Any exception
    escaping the pipeline fails the test with an ``AssertionError`` whose
    cause is the original exception.
    """
    # Deep-copy so the module-level config objects are not mutated across tests.
    config = [deepcopy(pendulum_sac_mbpo_main_config), deepcopy(pendulum_sac_mbpo_create_config)]
    # NOTE(review): presumably shrinks world-model training (early-stopping
    # patience) to keep this smoke test fast — confirm against the world model.
    config[0].world_model.model.max_epochs_since_update = 0
    try:
        serial_pipeline_dyna(config, seed=0, max_train_iter=1)
    except Exception as err:
        # An explicit raise (unlike ``assert False``) is not stripped under
        # ``python -O``; ``from err`` keeps the original traceback as the cause.
        raise AssertionError(f"pipeline fail: {err!r}") from err
@pytest.mark.unittest
def test_dream():
    """Smoke-test ``serial_pipeline_dream`` on Pendulum with MBSAC and STEVE-SAC MBPO configs.

    Each config pair (main, create) is deep-copied, tweaked, and run for a
    single training iteration (``max_train_iter=1``, ``seed=0``), in order.
    The first config whose setup or pipeline raises fails the test with an
    ``AssertionError`` that names the config and chains the original exception.
    """
    # (label, [main_config, create_config]) pairs; labels only identify failures.
    labelled_configs = [
        ('mbsac', [deepcopy(pendulum_mbsac_mbpo_main_config),
                   deepcopy(pendulum_mbsac_mbpo_create_config)]),
        ('stevesac', [deepcopy(pendulum_stevesac_mbpo_main_config),
                      deepcopy(pendulum_stevesac_mbpo_create_config)]),
    ]
    for label, config in labelled_configs:
        try:
            # NOTE(review): presumably shrinks world-model training (early-stopping
            # patience) to keep this smoke test fast — confirm against the world model.
            config[0].world_model.model.max_epochs_since_update = 0
            serial_pipeline_dream(config, seed=0, max_train_iter=1)
        except Exception as err:
            # An explicit raise (unlike ``assert False``) is not stripped under
            # ``python -O``; the label says which config broke, ``from err``
            # keeps the original traceback as the cause.
            raise AssertionError(f"pipeline fail ({label}): {err!r}") from err
|