from typing import Union, Optional, List, Any, Tuple
import os
import pickle
import numpy as np
import torch
from functools import partial
from copy import deepcopy

from ding.config import compile_config, read_config
from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, EpisodeSerialCollector
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy
from ding.torch_utils import to_device, to_ndarray
from ding.utils import set_pkg_seed
from ding.utils.data import offline_data_save_type, default_collate
from ding.rl_utils import get_nstep_return_data


def eval(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        state_dict: Optional[dict] = None,
        load_path: Optional[str] = None,
        replay_path: Optional[str] = None,
) -> float:
""" |
|
Overview: |
|
Pure policy evaluation entry. Evaluate mean episode return and save replay videos. |
|
Arguments: |
|
- input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ |
|
``str`` type means config file path. \ |
|
``Tuple[dict, dict]`` type means [user_config, create_cfg]. |
|
- seed (:obj:`int`): Random seed. |
|
- env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \ |
|
``BaseEnv`` subclass, collector env config, and evaluator env config. |
|
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. |
|
- state_dict (:obj:`Optional[dict]`): The state_dict of policy or model. |
|
- load_path (:obj:`Optional[str]`): Path to load ckpt. |
|
- replay_path (:obj:`Optional[str]`): Path to save replay. |
|
""" |
|
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(
        cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, save_path='eval_config.py'
    )

    # Create evaluator env, policy and evaluator.
    if env_setting is None:
        env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env, collect=False)
    else:
        env_fn, _, evaluator_env_cfg = env_setting
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    evaluator_env.seed(seed, dynamic_seed=False)
    if replay_path is None:  # argument takes precedence over config
        replay_path = cfg.env.get('replay_path', None)
    if replay_path:
        evaluator_env.enable_save_replay(replay_path)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['eval'])
    # Load the trained policy weights, either from a given state_dict or from a checkpoint file.
    if state_dict is None:
        if load_path is None:
            load_path = cfg.policy.learn.learner.load_path
        state_dict = torch.load(load_path, map_location='cpu')
    policy.eval_mode.load_state_dict(state_dict)
    evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode)

    # Evaluate and report the mean episode return.
    _, episode_info = evaluator.eval()
    episode_return = np.mean(episode_info['eval_episode_return'])
    print('Eval is over! The performance of your RL policy is {}'.format(episode_return))
    return episode_return
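
# A minimal usage sketch for ``eval`` (comment only; ``main_config``/``create_config`` and the
# file paths below are placeholders for a user-provided config pair and checkpoint):
#
#   eval(
#       (main_config, create_config),
#       seed=0,
#       load_path='./ckpt_best.pth.tar',  # checkpoint saved by a previous training run
#       replay_path='./replay_videos',    # optional, where evaluation replays are saved
#   )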


def collect_demo_data(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int,
        collect_count: int,
        expert_data_path: Optional[str] = None,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        state_dict: Optional[dict] = None,
        state_dict_path: Optional[str] = None,
) -> None:
    r"""
    Overview:
        Collect demonstration data (individual transitions) with a trained policy.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - collect_count (:obj:`int`): The number of transitions to collect.
        - expert_data_path (:obj:`Optional[str]`): File path to which the expert demo data will be written. \
            Defaults to ``cfg.policy.collect.save_path`` if not given.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
        - state_dict_path (:obj:`Optional[str]`): The path of the state_dict of policy or model.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(
        cfg,
        seed=seed,
        env=env_fn,
        auto=True,
        create_cfg=create_cfg,
        save_cfg=True,
        save_path='collect_demo_data_config.py'
    )
    if expert_data_path is None:
        expert_data_path = cfg.policy.collect.save_path

    # Create collector env and policy.
    if env_setting is None:
        env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env, eval_=False)
    else:
        env_fn, collector_env_cfg, _ = env_setting
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    collector_env.seed(seed)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval'])

    # Load the trained policy weights, either from a given state_dict or from a checkpoint file.
    collect_demo_policy = policy.collect_mode
    if state_dict is None:
        assert state_dict_path is not None
        state_dict = torch.load(state_dict_path, map_location='cpu')
    policy.collect_mode.load_state_dict(state_dict)
    collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy)

    # For epsilon-greedy policies, collect greedily (eps=0) so the demo data reflects the trained policy.
    if hasattr(cfg.policy.other, 'eps'):
        policy_kwargs = {'eps': 0.}
    else:
        policy_kwargs = None

    # Collect the demo data, move it to CPU if necessary, and save it to disk.
    exp_data = collector.collect(n_sample=collect_count, policy_kwargs=policy_kwargs)
    if cfg.policy.cuda:
        exp_data = to_device(exp_data, 'cpu')
    offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive'))
    print('Collect demo data successfully')
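
# Usage sketch for ``collect_demo_data`` (comment only; config names and paths are placeholders).
# ``collect_count`` here is a number of transitions, since a SampleSerialCollector is used:
#
#   collect_demo_data(
#       (main_config, create_config),
#       seed=0,
#       collect_count=10000,
#       expert_data_path='./expert_data.pkl',
#       state_dict_path='./ckpt_best.pth.tar',
#   )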


def collect_episodic_demo_data(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int,
        collect_count: int,
        expert_data_path: str,
        env_setting: Optional[List[Any]] = None,
        model: Optional[torch.nn.Module] = None,
        state_dict: Optional[dict] = None,
        state_dict_path: Optional[str] = None,
) -> None:
    r"""
    Overview:
        Collect episodic demonstration data (whole episodes) with a trained policy.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - collect_count (:obj:`int`): The number of episodes to collect.
        - expert_data_path (:obj:`str`): File path to which the expert demo data will be written.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
        - state_dict_path (:obj:`Optional[str]`): The path of the state_dict of policy or model.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    env_fn = None if env_setting is None else env_setting[0]
    cfg = compile_config(
        cfg,
        collector=EpisodeSerialCollector,
        seed=seed,
        env=env_fn,
        auto=True,
        create_cfg=create_cfg,
        save_cfg=True,
        save_path='collect_demo_data_config.py'
    )

    # Create collector env and policy.
    if env_setting is None:
        env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env, eval_=False)
    else:
        env_fn, collector_env_cfg, _ = env_setting
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    collector_env.seed(seed)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval'])

    # Load the trained policy weights, either from a given state_dict or from a checkpoint file.
    collect_demo_policy = policy.collect_mode
    if state_dict is None:
        assert state_dict_path is not None
        state_dict = torch.load(state_dict_path, map_location='cpu')
    policy.collect_mode.load_state_dict(state_dict)
    collector = EpisodeSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy)

    # For epsilon-greedy policies, collect greedily (eps=0) so the demo data reflects the trained policy.
    if hasattr(cfg.policy.other, 'eps'):
        policy_kwargs = {'eps': 0.}
    else:
        policy_kwargs = None

    # Collect whole episodes, move them to CPU if necessary, and save them to disk.
    exp_data = collector.collect(n_episode=collect_count, policy_kwargs=policy_kwargs)
    if cfg.policy.cuda:
        exp_data = to_device(exp_data, 'cpu')
    offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive'))
    print('Collect episodic demo data successfully')
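
# Usage sketch for ``collect_episodic_demo_data`` (comment only; placeholders as above).
# ``collect_count`` here is a number of whole episodes, since an EpisodeSerialCollector is used:
#
#   collect_episodic_demo_data(
#       (main_config, create_config),
#       seed=0,
#       collect_count=50,
#       expert_data_path='./expert_episodes.pkl',
#       state_dict_path='./ckpt_best.pth.tar',
#   )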


def episode_to_transitions(data_path: str, expert_data_path: str, nstep: int) -> None:
    r"""
    Overview:
        Transform episodic data into n-step transitions.
    Arguments:
        - data_path (:obj:`str`): Path of the pkl file that stores the episodic data.
        - expert_data_path (:obj:`str`): File path to which the n-step transition data will be written.
        - nstep (:obj:`int`): Number of steps used to compute the n-step return, i.e. each transition \
            covers (s_t, a_t, ..., s_{t+n}).
    """
    with open(data_path, 'rb') as f:
        _dict = pickle.load(f)  # a list of episodes, each episode is a list of transitions
    post_process_data = []
    for i in range(len(_dict)):
        data = get_nstep_return_data(_dict[i], nstep)
        post_process_data.extend(data)
    offline_data_save_type(
        post_process_data,
        expert_data_path,
    )
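
# Usage sketch for ``episode_to_transitions`` (comment only; paths are placeholders): convert the
# episodic data collected above into 3-step transitions and save them to a new file:
#
#   episode_to_transitions(
#       data_path='./expert_episodes.pkl',
#       expert_data_path='./expert_transitions.pkl',
#       nstep=3,
#   )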


def episode_to_transitions_filter(data_path: str, expert_data_path: str, nstep: int, min_episode_return: int) -> None:
    r"""
    Overview:
        Transform episodic data into n-step transitions, keeping only episodes whose total return is at least \
        ``min_episode_return``.
    Arguments:
        - data_path (:obj:`str`): Path of the pkl file that stores the episodic data.
        - expert_data_path (:obj:`str`): File path to which the n-step transition data will be written.
        - nstep (:obj:`int`): Number of steps used to compute the n-step return, i.e. each transition \
            covers (s_t, a_t, ..., s_{t+n}).
        - min_episode_return (:obj:`int`): Episodes with a total return below this threshold are discarded.
    """
    with open(data_path, 'rb') as f:
        _dict = pickle.load(f)  # a list of episodes, each episode is a list of transitions
    post_process_data = []
    for i in range(len(_dict)):
        # Sum the per-step rewards of the episode and skip it if the return is too low.
        episode_rewards = torch.stack([step['reward'] for step in _dict[i]], dim=0)
        if episode_rewards.sum() < min_episode_return:
            continue
        data = get_nstep_return_data(_dict[i], nstep)
        post_process_data.extend(data)
    offline_data_save_type(
        post_process_data,
        expert_data_path,
    )
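
# Usage sketch for ``episode_to_transitions_filter`` (comment only; paths and threshold are
# placeholders): same conversion as above, but episodes with a total return below 100 are dropped:
#
#   episode_to_transitions_filter(
#       data_path='./expert_episodes.pkl',
#       expert_data_path='./expert_transitions_filtered.pkl',
#       nstep=3,
#       min_episode_return=100,
#   )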