gomoku / DI-engine /ding /entry /tests /test_serial_entry_mbrl.py
zjowowen's picture
init space
079c32c
import pytest
from copy import deepcopy
from ding.entry.serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream
from dizoo.classic_control.pendulum.config.mbrl.pendulum_sac_mbpo_config \
import main_config as pendulum_sac_mbpo_main_config,\
create_config as pendulum_sac_mbpo_create_config
from dizoo.classic_control.pendulum.config.mbrl.pendulum_mbsac_mbpo_config \
import main_config as pendulum_mbsac_mbpo_main_config,\
create_config as pendulum_mbsac_mbpo_create_config
from dizoo.classic_control.pendulum.config.mbrl.pendulum_stevesac_mbpo_config \
import main_config as pendulum_stevesac_mbpo_main_config,\
create_config as pendulum_stevesac_mbpo_create_config
@pytest.mark.unittest
def test_dyna():
    """Smoke-test ``serial_pipeline_dyna`` on the pendulum SAC+MBPO config.

    Runs the pipeline for a single training iteration with ``seed=0`` on
    deep copies of the module-level configs, so the shared config objects
    are not mutated by the test.
    """
    config = [deepcopy(pendulum_sac_mbpo_main_config), deepcopy(pendulum_sac_mbpo_create_config)]
    # Overridden to 0 for this smoke test; presumably this makes world-model
    # training stop as early as possible -- confirm against the world model.
    config[0].world_model.model.max_epochs_since_update = 0
    try:
        serial_pipeline_dyna(config, seed=0, max_train_iter=1)
    except Exception as err:
        # pytest.fail is not stripped under `python -O` (unlike `assert False`),
        # and the original exception stays chained in the reported traceback.
        pytest.fail(f"pipeline fail: {err!r}")
@pytest.mark.unittest
def test_dream():
    """Smoke-test ``serial_pipeline_dream`` on the pendulum MBPO-style configs.

    Each (main, create) config pair is deep-copied and run for a single
    training iteration with ``seed=0``. On failure the report names which
    config broke.
    """
    cases = [
        ("mbsac", pendulum_mbsac_mbpo_main_config, pendulum_mbsac_mbpo_create_config),
        ("stevesac", pendulum_stevesac_mbpo_main_config, pendulum_stevesac_mbpo_create_config),
    ]
    for label, main_cfg, create_cfg in cases:
        config = [deepcopy(main_cfg), deepcopy(create_cfg)]
        # Overridden to 0 for this smoke test; presumably this makes world-model
        # training stop as early as possible -- confirm against the world model.
        config[0].world_model.model.max_epochs_since_update = 0
        try:
            serial_pipeline_dream(config, seed=0, max_train_iter=1)
        except Exception as err:
            # pytest.fail survives `python -O`, and the chained original exception
            # keeps the full traceback. The label identifies the failing config.
            pytest.fail(f"pipeline fail ({label}): {err!r}")