File size: 1,105 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import os
import tempfile
from copy import deepcopy

import pytest
import torch

from ding.entry import serial_pipeline
from ding.entry.serial_entry_sqil import serial_pipeline_sqil
from dizoo.classic_control.cartpole.config.cartpole_sql_config import cartpole_sql_config, cartpole_sql_create_config
from dizoo.classic_control.cartpole.config.cartpole_sqil_config import cartpole_sqil_config, cartpole_sqil_create_config
@pytest.mark.unittest
def test_sqil():
    """Smoke-test the SQIL serial pipeline on CartPole.

    Steps:
      1. Train an SQL policy with ``serial_pipeline`` (seed 0) to act as the expert.
      2. Save the expert's collect-mode ``state_dict`` with ``torch.save``.
      3. Point the SQIL config's ``policy.collect.model_path`` at that checkpoint
         and run ``serial_pipeline_sqil`` with ``max_train_iter=1``.

    The checkpoint is written into a temporary directory that is removed when the
    test ends, rather than being left behind in the current working directory.
    The test fails (via ``pytest.fail``) if the SQIL pipeline raises.
    """
    # Deep-copy every config: the pipelines may mutate the config objects they
    # receive, and the originals are shared module-level objects other tests import.
    expert_train_config = [deepcopy(cartpole_sql_config), deepcopy(cartpole_sql_create_config)]
    expert_policy = serial_pipeline(expert_train_config, seed=0)

    with tempfile.TemporaryDirectory() as tmp_dir:
        expert_policy_state_dict_path = os.path.join(tmp_dir, 'expert_policy.pth')
        torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)

        config = [deepcopy(cartpole_sqil_config), deepcopy(cartpole_sqil_create_config)]
        # Presumably read by serial_pipeline_sqil to load the expert weights, so the
        # temp dir must outlive the pipeline call below (hence it runs inside `with`).
        config[0].policy.collect.model_path = expert_policy_state_dict_path
        # Keep the smoke test cheap.
        config[0].policy.learn.update_per_collect = 1
        expert_config = [deepcopy(cartpole_sql_config), deepcopy(cartpole_sql_create_config)]
        try:
            serial_pipeline_sqil(config, expert_config, seed=0, max_train_iter=1)
        except Exception as err:
            # pytest.fail survives `python -O` (unlike `assert False`) and the
            # message carries the underlying error instead of discarding it.
            pytest.fail(f"pipeline fail: {err!r}")
|