|
import pytest |
|
from copy import deepcopy |
|
from ding.entry.serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream |
|
|
|
from dizoo.classic_control.pendulum.config.mbrl.pendulum_sac_mbpo_config \ |
|
import main_config as pendulum_sac_mbpo_main_config,\ |
|
create_config as pendulum_sac_mbpo_create_config |
|
|
|
from dizoo.classic_control.pendulum.config.mbrl.pendulum_mbsac_mbpo_config \ |
|
import main_config as pendulum_mbsac_mbpo_main_config,\ |
|
create_config as pendulum_mbsac_mbpo_create_config |
|
|
|
from dizoo.classic_control.pendulum.config.mbrl.pendulum_stevesac_mbpo_config \ |
|
import main_config as pendulum_stevesac_mbpo_main_config,\ |
|
create_config as pendulum_stevesac_mbpo_create_config |
|
|
|
|
|
@pytest.mark.unittest |
|
def test_dyna(): |
|
config = [deepcopy(pendulum_sac_mbpo_main_config), deepcopy(pendulum_sac_mbpo_create_config)] |
|
config[0].world_model.model.max_epochs_since_update = 0 |
|
try: |
|
serial_pipeline_dyna(config, seed=0, max_train_iter=1) |
|
except Exception: |
|
assert False, "pipeline fail" |
|
|
|
|
|
@pytest.mark.unittest |
|
def test_dream(): |
|
configs = [ |
|
[deepcopy(pendulum_mbsac_mbpo_main_config), |
|
deepcopy(pendulum_mbsac_mbpo_create_config)], |
|
[deepcopy(pendulum_stevesac_mbpo_main_config), |
|
deepcopy(pendulum_stevesac_mbpo_create_config)] |
|
] |
|
try: |
|
for config in configs: |
|
config[0].world_model.model.max_epochs_since_update = 0 |
|
serial_pipeline_dream(config, seed=0, max_train_iter=1) |
|
except Exception: |
|
assert False, "pipeline fail" |
|
|