File size: 1,124 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import pytest
import torch
from copy import deepcopy
from ding.entry import serial_pipeline
from ding.entry.serial_entry_bco import serial_pipeline_bco
from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
from dizoo.classic_control.cartpole.config.cartpole_bco_config import cartpole_bco_config, cartpole_bco_create_config


@pytest.mark.unittest
def test_bco():
    """End-to-end smoke test for ``serial_pipeline_bco`` on CartPole.

    Steps:
      1. Train an expert with ``serial_pipeline`` using the CartPole DQN config.
      2. Save the expert's collect-mode ``state_dict`` to a checkpoint file.
      3. Point the BCO config's ``policy.collect.model_path`` at that checkpoint
         and run ``serial_pipeline_bco`` for ``max_train_iter=3`` iterations.

    The checkpoint lives in a temporary directory that is deleted when the
    test finishes, so nothing is left behind in the current working directory.
    Any exception from the BCO pipeline fails the test via ``pytest.fail``,
    which (unlike ``assert False``) is not stripped under ``python -O``.
    """
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        expert_policy_state_dict_path = os.path.join(tmp_dir, 'expert_policy.pth')

        # Configs are deep-copied so the shared module-level config objects
        # are not mutated by the pipelines (or by the model_path assignment).
        expert_config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
        expert_policy = serial_pipeline(expert_config, seed=0)
        torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)

        config = [deepcopy(cartpole_bco_config), deepcopy(cartpole_bco_create_config)]
        config[0].policy.collect.model_path = expert_policy_state_dict_path
        try:
            serial_pipeline_bco(
                config, [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)], seed=0, max_train_iter=3
            )
        except Exception as e:
            # pytest.fail is raised inside the except block, so the original
            # traceback is kept as context in the test report.
            pytest.fail("pipeline fail: {}".format(e))