import argparse
import torch
import os
from typing import Union, Optional, List, Tuple, Any
from functools import partial
from copy import deepcopy

from ding.config import compile_config, read_config
from ding.worker import EpisodeSerialCollector
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy
from ding.torch_utils import to_device
from ding.utils import set_pkg_seed
from ding.utils.data import offline_data_save_type
from ding.utils.data import default_collate


def collect_episodic_demo_data_for_trex(
    input_cfg: Union[str, Tuple[dict, dict]],
    seed: int,
    collect_count: int,
    rank: int,
    env_setting: Optional[List[Any]] = None,
    model: Optional[torch.nn.Module] = None,
    state_dict: Optional[dict] = None,
    state_dict_path: Optional[str] = None,
):
    """
    Overview:
        Collect episodic demonstration data with a trained policy, specifically for TREX.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): The input config. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - collect_count (:obj:`int`): The number of episodes to collect.
        - rank (:obj:`int`): The rank (1-based index) of the checkpoint whose episodes are collected; \
            only used for logging.
        - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
            ``BaseEnv`` subclass, collector env config, and evaluator env config.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
        - state_dict_path (:obj:`Optional[str]`): The absolute path of the state dict file.
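    Example:
        An illustrative call (the config path and checkpoint path below are placeholders, \
            not files shipped with this repo):

        >>> exp_data = collect_episodic_demo_data_for_trex(
        >>>     'path/to/your_trex_config.py', seed=0, collect_count=1, rank=1,
        >>>     state_dict_path='path/to/ckpt/iteration_1000.pth.tar'
        >>> )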
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    create_cfg.policy.type += '_command'
    env_fn = None if env_setting is None else env_setting[0]
    cfg.env.collector_env_num = 1
    cfg = compile_config(
        cfg,
        collector=EpisodeSerialCollector,
        seed=seed,
        env=env_fn,
        auto=True,
        create_cfg=create_cfg,
        save_cfg=True,
        save_path='collect_demo_data_config.py'
    )

    # Create components: env, policy, collector
    if env_setting is None:
        env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env)
    else:
        env_fn, collector_env_cfg, _ = env_setting
    collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
    collector_env.seed(seed)
    set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval'])
    collect_demo_policy = policy.collect_mode
    if state_dict is None:
        assert state_dict_path is not None
        state_dict = torch.load(state_dict_path, map_location='cpu')
    policy.collect_mode.load_state_dict(state_dict)
    collector = EpisodeSerialCollector(
        cfg.policy.collect.collector, collector_env, collect_demo_policy, exp_name=cfg.exp_name
    )

    # For policies that rely on epsilon-greedy exploration (e.g. DQN), pass an exploration
    # epsilon to the collector; fall back to 0.2 when ``policy.other.eps.collect`` is not set.
    policy_kwargs = None if not hasattr(cfg.policy.other, 'eps') \
        else {'eps': cfg.policy.other.eps.get('collect', 0.2)}

    # Let's collect some sub-optimal demonstrations
    exp_data = collector.collect(n_episode=collect_count, policy_kwargs=policy_kwargs)
    if cfg.policy.cuda:
        exp_data = to_device(exp_data, 'cpu')
    # Return the collected transitions; the caller is responsible for saving them.
    print('Collected the {}-th batch of episodic demo data successfully'.format(rank))
    return exp_data


def trex_get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg', type=str, default=None, help='absolute path of the config file')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def trex_collecting_data(args=None):
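    """
    Overview:
        Collect demonstration episodes for TREX from a sequence of checkpoints of a trained policy
        (``checkpoint_min`` to ``checkpoint_max`` with stride ``checkpoint_step``), then save the
        episodes, per-step rewards, episode returns and checkpoint list under ``cfg.exp_name``.
    Arguments:
        - args (:obj:`Optional[argparse.Namespace]`): Parsed CLI arguments with at least ``cfg`` and \
            ``seed`` fields; if ``None``, they are parsed from the command line.
    Returns:
        - checkpoints, episodes_data, learning_returns, learning_rewards.
    """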
    if args is None:
        args = trex_get_args()  # TODO(nyz) use sub-command in cli
    if isinstance(args.cfg, str):
        cfg, create_cfg = read_config(args.cfg)
    else:
        cfg, create_cfg = deepcopy(args.cfg)
    data_path = cfg.exp_name
    expert_model_path = cfg.reward_model.expert_model_path  # directory path
    checkpoint_min = cfg.reward_model.checkpoint_min
    checkpoint_max = cfg.reward_model.checkpoint_max
    checkpoint_step = cfg.reward_model.checkpoint_step
    checkpoints = [str(i) for i in range(checkpoint_min, checkpoint_max + checkpoint_step, checkpoint_step)]
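    # Illustrative example (hypothetical config values): checkpoint_min=0, checkpoint_max=1000,
    # checkpoint_step=500 yields checkpoints == ['0', '500', '1000'].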
    data_for_save = {}
    learning_returns = []
    learning_rewards = []
    episodes_data = []
    for checkpoint in checkpoints:
        num_per_ckpt = 1
        model_path = expert_model_path + '/ckpt/iteration_' + checkpoint + '.pth.tar'
        # Index of this checkpoint within the evenly spaced checkpoint list.
        ckpt_idx = (int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step)
        seed = args.seed + ckpt_idx
        exp_data = collect_episodic_demo_data_for_trex(
            deepcopy(args.cfg),
            seed,
            state_dict_path=model_path,
            collect_count=num_per_ckpt,
            rank=ckpt_idx + 1
        )
        data_for_save[ckpt_idx] = exp_data
        # Flatten each collected episode into per-step observations, per-step rewards and the episode return.
        obs = [list(default_collate(exp_data[i])['obs'].numpy()) for i in range(len(exp_data))]
        rewards = [default_collate(exp_data[i])['reward'].tolist() for i in range(len(exp_data))]
        sum_rewards = [torch.sum(default_collate(exp_data[i])['reward']).item() for i in range(len(exp_data))]

        learning_rewards.append(rewards)
        learning_returns.append(sum_rewards)
        episodes_data.append(obs)
    offline_data_save_type(
        data_for_save, data_path + '/suboptimal_data.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
    )
    # if not compiled_cfg.reward_model.auto: more feature
    offline_data_save_type(
        episodes_data, data_path + '/episodes_data.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
    )
    offline_data_save_type(
        learning_returns, data_path + '/learning_returns.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
    )
    offline_data_save_type(
        learning_rewards, data_path + '/learning_rewards.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
    )
    offline_data_save_type(
        checkpoints, data_path + '/checkpoints.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
    )
    return checkpoints, episodes_data, learning_returns, learning_rewards


if __name__ == '__main__':
    trex_collecting_data()
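
# A minimal usage sketch (the config path below is a placeholder; adjust it to your experiment):
#     python <path_to_this_file> --cfg /abs/path/to/your_trex_config.py --seed 0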