from ding.entry import serial_pipeline_bc, serial_pipeline, collect_demo_data
from dizoo.mujoco.config.halfcheetah_td3_config import main_config, create_config
from copy import deepcopy
from typing import Union, Optional, List, Any, Tuple
import os
import torch
import logging
from functools import partial
from tensorboardX import SummaryWriter
import torch.nn as nn
from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
    create_serial_collector
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from ding.entry.utils import random_collect
from ding.entry import collect_episodic_demo_data, episode_to_transitions
import pickle
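
# This script distills a pre-trained TD3 expert on HalfCheetah into a behavior
# cloning (BC) policy in three steps:
#   1. load_policy() restores the expert TD3 policy from a checkpoint,
#   2. collect_episodic_demo_data() and episode_to_transitions() dump expert
#      transitions to expert_data.pkl,
#   3. serial_pipeline_bc() trains a continuous BC policy on that data.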


def load_policy(
    input_cfg: Union[str, Tuple[dict, dict]],
    load_path: str,
    seed: int = 0,
    env_setting: Optional[List[Any]] = None,
    model: Optional[torch.nn.Module] = None,
) -> 'Policy':
    """Build the expert policy from a config (file path or (cfg, create_cfg) pair) and restore its weights from ``load_path``."""
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = input_cfg
    create_cfg.policy.type = create_cfg.policy.type + '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
    # Load the checkpoint into collect mode, which is used later to gather expert demonstrations.
    sd = torch.load(load_path, map_location='cpu')
    policy.collect_mode.load_state_dict(sd)
    return policy


def main():
    # 1. Load the expert TD3 policy from a pre-trained checkpoint.
    half_td3_config, half_td3_create_config = main_config, create_config
    train_config = [deepcopy(half_td3_config), deepcopy(half_td3_create_config)]
    exp_path = 'DI-engine/halfcheetah_td3_seed0/ckpt/ckpt_best.pth.tar'
    expert_policy = load_policy(train_config, load_path=exp_path, seed=0)

    # 2. Collect expert demonstration episodes with the loaded policy.
    collect_count = 100
    expert_data_path = 'expert_data.pkl'
    state_dict = expert_policy.collect_mode.state_dict()
    collect_config = [deepcopy(half_td3_config), deepcopy(half_td3_create_config)]
    collect_episodic_demo_data(
        deepcopy(collect_config),
        seed=0,
        state_dict=state_dict,
        expert_data_path=expert_data_path,
        collect_count=collect_count
    )
    # Convert the saved episodes into single-step transitions (nstep=1) for BC training.
    episode_to_transitions(expert_data_path, expert_data_path, nstep=1)

    # 3. Train a continuous behavior cloning (BC) policy on the expert transitions.
    il_config = [deepcopy(half_td3_config), deepcopy(half_td3_create_config)]
    il_config[0].policy.learn.train_epoch = 1000000
    il_config[0].policy.type = 'bc'
    il_config[0].policy.continuous = True
    il_config[0].exp_name = "continuous_bc_seed0"
    il_config[0].env.stop_value = 50000
    il_config[0].multi_agent = False
    bc_policy, converge_stop_flag = serial_pipeline_bc(il_config, seed=314, data_path=expert_data_path, max_iter=4e6)
    return bc_policy
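

# Optional sanity check (a sketch, not part of the original pipeline): after
# episode_to_transitions() has run, expert_data.pkl is assumed to hold a pickled
# list of transition dicts. This hypothetical helper reports how many transitions
# were collected and which keys each one carries.
def inspect_expert_data(path: str = 'expert_data.pkl') -> None:
    with open(path, 'rb') as f:
        transitions = pickle.load(f)
    print('number of transitions:', len(transitions))
    if len(transitions) > 0:
        print('transition keys:', list(transitions[0].keys()))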


if __name__ == '__main__':
    policy = main()