import argparse
import torch
import os
from typing import Union, Optional, List, Any, Tuple
from functools import partial
from copy import deepcopy
from ding.config import compile_config, read_config
from ding.worker import EpisodeSerialCollector
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy
from ding.torch_utils import to_device
from ding.utils import set_pkg_seed
from ding.utils.data import offline_data_save_type, default_collate
def collect_episodic_demo_data_for_trex(
        input_cfg: Union[str, Tuple[dict, dict]],
seed: int,
collect_count: int,
rank: int,
env_setting: Optional[List[Any]] = None,
model: Optional[torch.nn.Module] = None,
state_dict: Optional[dict] = None,
state_dict_path: Optional[str] = None,
):
"""
Overview:
        Collect episodic demonstration data with a trained policy, specifically for TREX reward model training.
Arguments:
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
``str`` type means config file path. \
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- collect_count (:obj:`int`): The count of collected data.
- rank (:obj:`int`): The episode ranking.
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
``BaseEnv`` subclass, collector env config, and evaluator env config.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
        - state_dict_path (:obj:`Optional[str]`): The absolute path of the state dict.
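    Returns:
        - exp_data (:obj:`List[Any]`): The collected episodic demonstration data.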
"""
if isinstance(input_cfg, str):
cfg, create_cfg = read_config(input_cfg)
else:
cfg, create_cfg = deepcopy(input_cfg)
create_cfg.policy.type += '_command'
env_fn = None if env_setting is None else env_setting[0]
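    # Only one collector environment is needed, since demonstrations are collected episode by episode.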
cfg.env.collector_env_num = 1
cfg = compile_config(
cfg,
collector=EpisodeSerialCollector,
seed=seed,
env=env_fn,
auto=True,
create_cfg=create_cfg,
save_cfg=True,
save_path='collect_demo_data_config.py'
)
# Create components: env, policy, collector
if env_setting is None:
env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env)
else:
env_fn, collector_env_cfg, _ = env_setting
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
collector_env.seed(seed)
set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval'])
collect_demo_policy = policy.collect_mode
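    # Load the trained policy weights, either from the provided state_dict or from the checkpoint file on disk.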
if state_dict is None:
assert state_dict_path is not None
state_dict = torch.load(state_dict_path, map_location='cpu')
policy.collect_mode.load_state_dict(state_dict)
collector = EpisodeSerialCollector(
cfg.policy.collect.collector, collector_env, collect_demo_policy, exp_name=cfg.exp_name
)
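    # For epsilon-greedy policies, pass the collect-time epsilon (falling back to 0.2 if not configured).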
policy_kwargs = None if not hasattr(cfg.policy.other, 'eps') \
else {'eps': cfg.policy.other.eps.get('collect', 0.2)}
    # Collect some sub-optimal demonstrations.
exp_data = collector.collect(n_episode=collect_count, policy_kwargs=policy_kwargs)
if cfg.policy.cuda:
exp_data = to_device(exp_data, 'cpu')
    # Return the collected episode transitions; saving to disk is handled by the caller.
    print('Collected episodic demo data (rank {}) successfully'.format(rank))
return exp_data
def trex_get_args():
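    """
    Overview:
        Parse command-line arguments (config path, random seed and device) for TREX demonstration data collection.
    Returns:
        - args (:obj:`argparse.Namespace`): The parsed arguments.
    """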
parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='abs path for a config', help='absolute path of the config file')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args
def trex_collecting_data(args=None):
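    """
    Overview:
        Collect demonstration episodes from a series of expert checkpoints and save the episodes, \
        per-step rewards, episode returns and checkpoint list under ``cfg.exp_name`` for TREX \
        reward model training.
    Arguments:
        - args (:obj:`argparse.Namespace`): Parsed arguments. If ``None``, ``trex_get_args`` is called.
    Returns:
        - checkpoints (:obj:`List[str]`): The checkpoint iterations used for collection.
        - episodes_data (:obj:`List[Any]`): Observations of each collected episode.
        - learning_returns (:obj:`List[Any]`): Episode return for each checkpoint.
        - learning_rewards (:obj:`List[Any]`): Per-step rewards of each collected episode.
    """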
if args is None:
args = trex_get_args() # TODO(nyz) use sub-command in cli
if isinstance(args.cfg, str):
cfg, create_cfg = read_config(args.cfg)
else:
cfg, create_cfg = deepcopy(args.cfg)
data_path = cfg.exp_name
expert_model_path = cfg.reward_model.expert_model_path # directory path
checkpoint_min = cfg.reward_model.checkpoint_min
checkpoint_max = cfg.reward_model.checkpoint_max
checkpoint_step = cfg.reward_model.checkpoint_step
checkpoints = []
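    # Enumerate checkpoint iterations from checkpoint_min to checkpoint_max (inclusive) with stride checkpoint_step.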
for i in range(checkpoint_min, checkpoint_max + checkpoint_step, checkpoint_step):
checkpoints.append(str(i))
data_for_save = {}
learning_returns = []
learning_rewards = []
episodes_data = []
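    # Collect one episode per checkpoint; checkpoints from later iterations are assumed to yield better
    # (higher-return) demonstrations, which provides the ranking signal TREX needs.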
for checkpoint in checkpoints:
num_per_ckpt = 1
model_path = expert_model_path + \
'/ckpt/iteration_' + checkpoint + '.pth.tar'
seed = args.seed + (int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step)
exp_data = collect_episodic_demo_data_for_trex(
deepcopy(args.cfg),
seed,
state_dict_path=model_path,
collect_count=num_per_ckpt,
rank=(int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step) + 1
)
data_for_save[(int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step)] = exp_data
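        # Collate the episode into per-step observation and reward arrays, plus the episode return.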
obs = [list(default_collate(exp_data[i])['obs'].numpy()) for i in range(len(exp_data))]
rewards = [default_collate(exp_data[i])['reward'].tolist() for i in range(len(exp_data))]
sum_rewards = [torch.sum(default_collate(exp_data[i])['reward']).item() for i in range(len(exp_data))]
learning_rewards.append(rewards)
learning_returns.append(sum_rewards)
episodes_data.append(obs)
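    # Persist the collected data: the raw episodes plus the observations, rewards, returns and
    # checkpoint list needed to rank demonstrations for TREX reward learning.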
offline_data_save_type(
data_for_save, data_path + '/suboptimal_data.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
)
    # TODO: support more features for the case when compiled_cfg.reward_model.auto is not set.
offline_data_save_type(
episodes_data, data_path + '/episodes_data.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
)
offline_data_save_type(
learning_returns, data_path + '/learning_returns.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
)
offline_data_save_type(
learning_rewards, data_path + '/learning_rewards.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
)
offline_data_save_type(
checkpoints, data_path + '/checkpoints.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
)
return checkpoints, episodes_data, learning_returns, learning_rewards
if __name__ == '__main__':
trex_collecting_data()