from typing import TYPE_CHECKING, Callable, List, Tuple, Any
from functools import reduce
import treetensor.torch as ttorch
import numpy as np
from ditk import logging
from ding.utils import EasyTimer
from ding.envs import BaseEnvManager
from ding.policy import Policy
from ding.torch_utils import to_ndarray, get_shape0

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext


class TransitionList:

    def __init__(self, env_num: int) -> None:
        self.env_num = env_num
        self._transitions = [[] for _ in range(env_num)]
        self._done_idx = [[] for _ in range(env_num)]

    def append(self, env_id: int, transition: Any) -> None:
        self._transitions[env_id].append(transition)
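        # When an episode finishes, record the exclusive end index of that episode,
        # which ``to_episodes`` later uses as a slice boundary.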
        if transition.done:
            self._done_idx[env_id].append(len(self._transitions[env_id]))

    def to_trajectories(self) -> Tuple[List[Any], List[int]]:
        trajectories = sum(self._transitions, [])
        lengths = [len(t) for t in self._transitions]
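        # The running sum of per-env lengths gives the exclusive end offset of each env's
        # segment in the flattened list; subtracting 1 below turns it into the index of
        # each env's last transition, e.g. lengths [3, 2] -> offsets [3, 5] -> indices [2, 4].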
        trajectory_end_idx = [reduce(lambda x, y: x + y, lengths[:i + 1]) for i in range(len(lengths))]
        trajectory_end_idx = [t - 1 for t in trajectory_end_idx]
        return trajectories, trajectory_end_idx

    def to_episodes(self) -> List[List[Any]]:
        episodes = []
        for env_id in range(self.env_num):
            last_idx = 0
            for done_idx in self._done_idx[env_id]:
                episodes.append(self._transitions[env_id][last_idx:done_idx])
                last_idx = done_idx
        return episodes

    def clear(self):
        for item in self._transitions:
            item.clear()
        for item in self._done_idx:
            item.clear()
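
# A minimal usage sketch of ``TransitionList`` (illustrative values only, not executed):
#
#     transitions = TransitionList(env_num=2)
#     transitions.append(0, transition_a)   # transition_a.done == False
#     transitions.append(0, transition_b)   # transition_b.done == True
#     transitions.append(1, transition_c)   # transition_c.done == True
#     trajectories, end_idx = transitions.to_trajectories()   # flat list [a, b, c], end_idx == [1, 2]
#     episodes = transitions.to_episodes()                    # [[a, b], [c]]
#     transitions.clear()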


def inferencer(seed: int, policy: Policy, env: BaseEnvManager) -> Callable:
    """
    Overview:
        The middleware that executes the inference process.
    Arguments:
        - seed (:obj:`int`): Random seed.
        - policy (:obj:`Policy`): The policy used for inference.
        - env (:obj:`BaseEnvManager`): The env where the inference process is performed. \
            The env.ready_obs (:obj:`tnp.array`) will be used as model input.
    """

    env.seed(seed)

    def _inference(ctx: "OnlineRLContext"):
        """
        Output of ctx:
            - obs (:obj:`Union[torch.Tensor, Dict[str, torch.Tensor]]`): The input observations collected \
                from all collector environments.
            - action (:obj:`List[np.ndarray]`): The inferred actions listed by env_id.
            - inference_output (:obj:`Dict[int, Dict]`): The dict of which the key is env_id (int), \
                and the value is inference result (Dict).
        """

        if env.closed:
            env.launch()

        obs = ttorch.as_tensor(env.ready_obs)
        ctx.obs = obs
        obs = obs.to(dtype=ttorch.float32)
        # TODO mask necessary rollout

        obs = {i: obs[i] for i in range(get_shape0(obs))}  # TBD
        inference_output = policy.forward(obs, **ctx.collect_kwargs)
        ctx.action = [to_ndarray(v['action']) for v in inference_output.values()]  # TBD
        ctx.inference_output = inference_output

    return _inference
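
# Note: ``ctx.collect_kwargs`` is passed through unchanged to the collect-mode
# ``policy.forward``; it typically carries policy-specific keyword arguments such
# as an exploration epsilon set by an upstream middleware.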


def rolloutor(
        policy: Policy,
        env: BaseEnvManager,
        transitions: TransitionList,
        collect_print_freq: int = 100,
) -> Callable:
    """
    Overview:
        The middleware that executes the transition process in the env.
    Arguments:
        - policy (:obj:`Policy`): The policy used to process transitions.
        - env (:obj:`BaseEnvManager`): The env used for collection; a BaseEnvManager object or \
            its subclasses are supported.
        - transitions (:obj:`TransitionList`): The transition information which will be filled \
            in this process, including `obs`, `next_obs`, `action`, `logit`, `value`, `reward` \
            and `done`.
        - collect_print_freq (:obj:`int`): The frequency, in train iterations, at which collection \
            statistics are logged.
    """

    env_episode_id = list(range(env.env_num))
    current_id = env.env_num
    timer = EasyTimer()
    last_train_iter = 0
    total_envstep_count = 0
    total_episode_count = 0
    total_train_sample_count = 0
    env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(env.env_num)}
    episode_info = []

    def _rollout(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - action (:obj:`List[np.ndarray]`): The inferred actions from the previous inference process.
            - obs (:obj:`Union[torch.Tensor, Dict[str, torch.Tensor]]`): The observations fed into \
                the transition dict.
            - inference_output (:obj:`Dict[int, Dict]`): The inference results to be fed into the \
                transition dict.
            - train_iter (:obj:`int`): The train iteration count to be fed into the transition dict.
            - env_step (:obj:`int`): The count of env step, which will increase by 1 for a single \
                transition call.
            - env_episode (:obj:`int`): The count of env episode, which will increase by 1 if the \
                trajectory stops.
        """

        nonlocal current_id, env_info, episode_info, timer, \
            total_episode_count, total_envstep_count, total_train_sample_count, last_train_iter
        timesteps = env.step(ctx.action)
        ctx.env_step += len(timesteps)
        timesteps = [t.tensor() for t in timesteps]

        collected_sample = 0
        collected_step = 0
        collected_episode = 0
        interaction_duration = timer.value / len(timesteps)
        for i, timestep in enumerate(timesteps):
            with timer:
                transition = policy.process_transition(ctx.obs[i], ctx.inference_output[i], timestep)
                transition = ttorch.as_tensor(transition)
                transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
                transition.env_data_id = ttorch.as_tensor([env_episode_id[timestep.env_id]])
                transitions.append(timestep.env_id, transition)

                collected_step += 1
                collected_sample += len(transition.obs)
                env_info[timestep.env_id.item()]['step'] += 1
                env_info[timestep.env_id.item()]['train_sample'] += len(transition.obs)

            env_info[timestep.env_id.item()]['time'] += timer.value + interaction_duration
            if timestep.done:
                info = {
                    'reward': timestep.info['eval_episode_return'],
                    'time': env_info[timestep.env_id.item()]['time'],
                    'step': env_info[timestep.env_id.item()]['step'],
                    'train_sample': env_info[timestep.env_id.item()]['train_sample'],
                }

                episode_info.append(info)
                policy.reset([timestep.env_id.item()])
                env_episode_id[timestep.env_id.item()] = current_id
                collected_episode += 1
                current_id += 1
                ctx.env_episode += 1

        total_envstep_count += collected_step
        total_episode_count += collected_episode
        total_train_sample_count += collected_sample

        if (ctx.train_iter - last_train_iter) >= collect_print_freq and len(episode_info) > 0:
            output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count)
            last_train_iter = ctx.train_iter

    return _rollout
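
# A minimal wiring sketch of the two middlewares above (a sketch only, assuming the
# standard ``ding.framework`` task API and an already constructed policy/env pair):
#
#     from ding.framework import task, OnlineRLContext
#
#     transitions = TransitionList(env.env_num)
#     with task.start(ctx=OnlineRLContext()):
#         task.use(inferencer(seed=0, policy=policy, env=env))
#         task.use(rolloutor(policy, env, transitions))
#         task.run(max_step=100)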


def output_log(
        episode_info: List[dict], total_episode_count: int, total_envstep_count: int, total_train_sample_count: int
) -> None:
    """
    Overview:
        Print the collection log information. You can refer to the docs of `Best Practice` to understand \
        the training generated logs and tensorboards.
    Arguments:
        - episode_info (:obj:`List[dict]`): Per-episode statistics collected since the last log; \
            the list is cleared after printing.
        - total_episode_count (:obj:`int`): The total number of collected episodes.
        - total_envstep_count (:obj:`int`): The total number of environment steps.
        - total_train_sample_count (:obj:`int`): The total number of collected train samples.
    """
    episode_count = len(episode_info)
    envstep_count = sum([d['step'] for d in episode_info])
    train_sample_count = sum([d['train_sample'] for d in episode_info])
    duration = sum([d['time'] for d in episode_info])
    episode_return = [d['reward'].item() for d in episode_info]
    info = {
        'episode_count': episode_count,
        'envstep_count': envstep_count,
        'train_sample_count': train_sample_count,
        'avg_envstep_per_episode': envstep_count / episode_count,
        'avg_sample_per_episode': train_sample_count / episode_count,
        'avg_envstep_per_sec': envstep_count / duration,
        'avg_train_sample_per_sec': train_sample_count / duration,
        'avg_episode_per_sec': episode_count / duration,
        'reward_mean': np.mean(episode_return),
        'reward_std': np.std(episode_return),
        'reward_max': np.max(episode_return),
        'reward_min': np.min(episode_return),
        'total_envstep_count': total_envstep_count,
        'total_train_sample_count': total_train_sample_count,
        'total_episode_count': total_episode_count,
        # 'each_reward': episode_return,
    }
    episode_info.clear()
    logging.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))