import argparse
import os
from copy import deepcopy
from functools import partial
from typing import Union, Optional, List, Any, Tuple

import torch

from ding.config import compile_config, read_config
from ding.worker import EpisodeSerialCollector
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy
from ding.torch_utils import to_device
from ding.utils import set_pkg_seed
from ding.utils.data import offline_data_save_type
from ding.utils.data import default_collate


def collect_episodic_demo_data_for_trex(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int,
        collect_count: int,
        rank: int,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        state_dict: Optional[dict] = None,
        state_dict_path: Optional[str] = None,
):
    """
    Overview:
        Collect episodic demonstration data with a trained policy, specifically for TREX.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config. ``str`` means a config file path; \
            ``Tuple[dict, dict]`` means ``(user_config, create_cfg)``. A tuple input is deep-copied \
            before being modified.
        - seed (:obj:`int`): Random seed for the env and for ``set_pkg_seed``.
        - collect_count (:obj:`int`): Number of episodes to collect.
        - rank (:obj:`int`): The episode ranking; only used in the progress message.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
        - state_dict_path (:obj:`Optional[str]`): The abs path of the state dict; only read when \
            ``state_dict`` is None.
    Returns:
        - exp_data: The episodes returned by ``EpisodeSerialCollector.collect``, moved to CPU when \
            ``cfg.policy.cuda`` is set.
    Raises:
        - ValueError: If both ``state_dict`` and ``state_dict_path`` are None.
    """
    # Fail fast, before any config file is written or any environment is spawned.
    if state_dict is None and state_dict_path is None:
        raise ValueError('Either `state_dict` or `state_dict_path` must be provided.')

    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    # Switch to the '<policy type>_command' registered variant of the policy.
    create_cfg.policy.type += '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg.env.collector_env_num = 1
    cfg = compile_config(
        cfg,
        collector=EpisodeSerialCollector,
        seed=seed,
        env=env_fn,
        auto=True,
        create_cfg=create_cfg,
        save_cfg=True,
        save_path='collect_demo_data_config.py'
    )

    # Create components: env, policy, collector
    if env_setting is None:
        env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env)
    else:
        env_fn, collector_env_cfg, _ = env_setting
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    collector_env.seed(seed)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval'])
    collect_demo_policy = policy.collect_mode
    if state_dict is None:
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted locations.
        state_dict = torch.load(state_dict_path, map_location='cpu')
    collect_demo_policy.load_state_dict(state_dict)
    collector = EpisodeSerialCollector(
        cfg.policy.collect.collector, collector_env, collect_demo_policy, exp_name=cfg.exp_name
    )

    # Pass the collect-time epsilon only when the policy config defines one.
    if hasattr(cfg.policy.other, 'eps'):
        policy_kwargs = {'eps': cfg.policy.other.eps.get('collect', 0.2)}
    else:
        policy_kwargs = None

    try:
        # Let's collect some sub-optimal demonstrations.
        exp_data = collector.collect(n_episode=collect_count, policy_kwargs=policy_kwargs)
    finally:
        # This function runs once per checkpoint; release the collector (and its env manager)
        # deterministically instead of waiting for garbage collection.
        collector.close()
    if cfg.policy.cuda:
        exp_data = to_device(exp_data, 'cpu')
    print('Collect {}th episodic demo data successfully'.format(rank))
    return exp_data


def trex_get_args():
    """
    Overview:
        Parse command-line arguments for TREX demo-data collection. Unknown arguments are ignored.
    Returns:
        - args (:obj:`argparse.Namespace`): Namespace with ``cfg``, ``seed`` and ``device``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default='abs path for a config', help='Path of the config file.')
    parser.add_argument('--seed', type=int, default=0, help='Base random seed.')
    parser.add_argument(
        '--device',
        type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu',
        help='Device name (not read by trex_collecting_data).'
    )
    args = parser.parse_known_args()[0]
    return args


def trex_collecting_data(args=None):
    """
    Overview:
        Roll out episodes from each expert checkpoint listed by ``cfg.reward_model`` and save the raw
        episodes, observations, per-step rewards, returns and checkpoint ids under ``cfg.exp_name``.
    Arguments:
        - args (:obj:`Optional[argparse.Namespace]`): Object with ``cfg`` (config path or \
            ``(user_config, create_cfg)`` tuple) and ``seed``. Parsed from the command line when None.
    Returns:
        - checkpoints (:obj:`List[str]`): Checkpoint iteration numbers, as strings.
        - episodes_data (:obj:`list`): Per checkpoint, a list of per-episode observation lists.
        - learning_returns (:obj:`list`): Per checkpoint, a list of per-episode returns (sum of rewards).
        - learning_rewards (:obj:`list`): Per checkpoint, a list of per-episode reward lists.
    """
    if args is None:
        args = trex_get_args()  # TODO(nyz) use sub-command in cli
    if isinstance(args.cfg, str):
        cfg, _ = read_config(args.cfg)
    else:
        cfg, _ = deepcopy(args.cfg)
    data_path = cfg.exp_name
    data_type = cfg.policy.collect.get('data_type', 'naive')
    expert_model_path = cfg.reward_model.expert_model_path  # directory path
    checkpoint_min = cfg.reward_model.checkpoint_min
    checkpoint_max = cfg.reward_model.checkpoint_max
    checkpoint_step = cfg.reward_model.checkpoint_step
    # ``checkpoint_max`` is inclusive. Stop one past it (in the direction of the step) rather than at
    # ``checkpoint_max + checkpoint_step``, which overshoots when the span is not a multiple of the step.
    stop = checkpoint_max + (1 if checkpoint_step > 0 else -1)
    checkpoints = [str(i) for i in range(checkpoint_min, stop, checkpoint_step)]

    num_per_ckpt = 1  # episodes collected per checkpoint
    data_for_save = {}
    learning_returns = []
    learning_rewards = []
    episodes_data = []
    # Checkpoints are spaced ``checkpoint_step`` apart, so ``idx`` equals
    # (checkpoint - checkpoint_min) // checkpoint_step.
    for idx, checkpoint in enumerate(checkpoints):
        model_path = os.path.join(expert_model_path, 'ckpt', 'iteration_' + checkpoint + '.pth.tar')
        exp_data = collect_episodic_demo_data_for_trex(
            deepcopy(args.cfg),
            args.seed + idx,
            state_dict_path=model_path,
            collect_count=num_per_ckpt,
            rank=idx + 1
        )
        data_for_save[idx] = exp_data
        # Collate each episode once and reuse the result for observations and rewards.
        collated = [default_collate(episode) for episode in exp_data]
        episodes_data.append([list(c['obs'].numpy()) for c in collated])
        learning_rewards.append([c['reward'].tolist() for c in collated])
        learning_returns.append([torch.sum(c['reward']).item() for c in collated])

    # The experiment directory may not exist yet; create it before writing.
    if data_path:
        os.makedirs(data_path, exist_ok=True)
    # TODO: some of these outputs could be made optional
    # (original note: "if not compiled_cfg.reward_model.auto: more feature").
    outputs = {
        'suboptimal_data.pkl': data_for_save,
        'episodes_data.pkl': episodes_data,
        'learning_returns.pkl': learning_returns,
        'learning_rewards.pkl': learning_rewards,
        'checkpoints.pkl': checkpoints,
    }
    for file_name, data in outputs.items():
        offline_data_save_type(data, os.path.join(data_path, file_name), data_type=data_type)
    return checkpoints, episodes_data, learning_returns, learning_rewards


if __name__ == '__main__':
    trex_collecting_data()