zjowowen's picture
init space
079c32c
raw
history blame
8.92 kB
from typing import TYPE_CHECKING, Callable, List, Tuple, Any
from functools import reduce
import treetensor.torch as ttorch
import numpy as np
from ditk import logging
from ding.utils import EasyTimer
from ding.envs import BaseEnvManager
from ding.policy import Policy
from ding.torch_utils import to_ndarray, get_shape0
if TYPE_CHECKING:
from ding.framework import OnlineRLContext
class TransitionList:
    """
    Overview:
        Per-environment buffer of collected transitions. Each env owns one list of
        transitions plus a list of episode boundaries recorded whenever an appended
        transition has a truthy ``done`` attribute.
    """

    def __init__(self, env_num: int) -> None:
        """
        Arguments:
            - env_num (:obj:`int`): Number of environments; one buffer is created per env.
        """
        self.env_num = env_num
        self._transitions = [[] for _ in range(env_num)]
        # For each env, the exclusive end offsets (into ``_transitions[env_id]``) of finished episodes.
        self._done_idx = [[] for _ in range(env_num)]

    def append(self, env_id: int, transition: Any) -> None:
        """
        Overview:
            Store ``transition`` in the buffer of ``env_id`` and, if ``transition.done`` is truthy,
            remember the current buffer length as the end of an episode.
        """
        env_buffer = self._transitions[env_id]
        env_buffer.append(transition)
        if transition.done:
            # Length after the append, i.e. an exclusive slice end for ``to_episodes``.
            self._done_idx[env_id].append(len(env_buffer))

    def to_trajectories(self) -> Tuple[List[Any], List[int]]:
        """
        Overview:
            Concatenate the buffers of all envs (in env order) into one flat list.
        Returns:
            - trajectories (:obj:`List[Any]`): A new flat list with every stored transition.
            - trajectory_end_idx (:obj:`List[int]`): For each env, the index in ``trajectories`` of
              its last transition (cumulative length minus one). An env without transitions
              repeats the previous env's value, or yields ``-1`` if nothing precedes it.
        """
        # A single linear pass replaces ``sum(lists, [])`` and the per-prefix ``reduce``,
        # both of which were quadratic in the amount of stored data.
        trajectories = []
        trajectory_end_idx = []
        for env_buffer in self._transitions:
            trajectories.extend(env_buffer)
            trajectory_end_idx.append(len(trajectories) - 1)
        return trajectories, trajectory_end_idx

    def to_episodes(self) -> List[List[Any]]:
        """
        Overview:
            Split every env buffer at the recorded ``done`` offsets.
        Returns:
            - episodes (:obj:`List[List[Any]]`): Finished episodes, grouped by env in env order.
              Transitions stored after the last ``done`` of an env are not included.
        """
        episodes = []
        for env_buffer, done_offsets in zip(self._transitions, self._done_idx):
            start = 0
            for end in done_offsets:
                episodes.append(env_buffer[start:end])
                start = end
        return episodes

    def clear(self) -> None:
        """
        Overview:
            Empty every per-env buffer and every list of episode boundaries in place.
        """
        for item in self._transitions:
            item.clear()
        for item in self._done_idx:
            item.clear()
def inferencer(seed: int, policy: Policy, env: BaseEnvManager) -> Callable:
    """
    Overview:
        Build the middleware that runs policy inference on the observations currently
        ready in ``env``. ``env.seed(seed)`` is called once, when the middleware is built.
    Arguments:
        - seed (:obj:`int`): Random seed passed to ``env.seed``.
        - policy (:obj:`Policy`): The policy whose ``forward`` is called.
        - env (:obj:`BaseEnvManager`): The env manager; its ``ready_obs`` is used as model input.
    Returns:
        - _inference (:obj:`Callable`): The middleware function taking the context.
    """

    env.seed(seed)

    def _inference(ctx: "OnlineRLContext"):
        """
        Output of ctx:
            - obs: ``env.ready_obs`` converted with ``ttorch.as_tensor`` (before the float32 cast).
            - action (:obj:`List[np.ndarray]`): The ``'action'`` entry of every inference result,
              converted with ``to_ndarray``, in the iteration order of the policy output.
            - inference_output: The object returned by ``policy.forward``.
        """
        # Launch lazily so that the env manager is only started when first needed.
        if env.closed:
            env.launch()

        ready_obs = ttorch.as_tensor(env.ready_obs)
        ctx.obs = ready_obs
        model_input = ready_obs.to(dtype=ttorch.float32)
        # TODO mask necessary rollout
        # Key the observations by their position along the first dimension.  # TBD
        obs_by_index = {}
        for index in range(get_shape0(model_input)):
            obs_by_index[index] = model_input[index]

        inference_output = policy.forward(obs_by_index, **ctx.collect_kwargs)
        actions = []
        for single_output in inference_output.values():  # TBD
            actions.append(to_ndarray(single_output['action']))
        ctx.action = actions
        ctx.inference_output = inference_output

    return _inference
def rolloutor(
        policy: Policy,
        env: BaseEnvManager,
        transitions: TransitionList,
        collect_print_freq=100,
) -> Callable:
    """
    Overview:
        Build the middleware that steps the env with the inferred actions and stores the
        resulting transitions.
    Arguments:
        - policy (:obj:`Policy`): The policy whose ``process_transition`` and ``reset`` are used.
        - env (:obj:`BaseEnvManager`): The env manager that is stepped.
        - transitions (:obj:`TransitionList`): The container every processed transition is appended to.
        - collect_print_freq (:obj:`int`): A collect log is printed once at least this many train
          iterations have passed since the previous log and a finished episode is pending.
    Returns:
        - _rollout (:obj:`Callable`): The middleware function taking the context.
    """

    # Each env starts with its own index as episode id; fresh ids are handed out from ``current_id``.
    env_episode_id = list(range(env.env_num))
    current_id = env.env_num
    timer = EasyTimer()
    last_train_iter = 0
    total_envstep_count = 0
    total_episode_count = 0
    total_train_sample_count = 0
    env_info = {}
    for env_index in range(env.env_num):
        env_info[env_index] = {'time': 0., 'step': 0, 'train_sample': 0}
    episode_info = []

    def _rollout(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - action: The actions passed to ``env.step``.
            - obs: Indexed by position to obtain the observation of each returned timestep.
            - inference_output: Indexed by position to obtain the inference result of each timestep.
            - train_iter (:obj:`int`): Stored in every transition and used for the log frequency.
        Output of ctx:
            - env_step (:obj:`int`): Increased by the number of timesteps returned by ``env.step``.
            - env_episode (:obj:`int`): Increased by 1 for every timestep whose ``done`` is truthy.
        """
        # Only the names that are rebound below need ``nonlocal``; the containers are mutated in place.
        nonlocal current_id, last_train_iter
        nonlocal total_envstep_count, total_episode_count, total_train_sample_count

        timesteps = env.step(ctx.action)
        ctx.env_step += len(timesteps)
        timesteps = [t.tensor() for t in timesteps]

        collected_step, collected_sample, collected_episode = 0, 0, 0
        # NOTE(review): ``timer.value`` here is whatever the timer last measured (``env.step`` itself
        # is not timed), and this divides by zero if ``env.step`` returns nothing -- confirm intent.
        interaction_duration = timer.value / len(timesteps)

        for index, timestep in enumerate(timesteps):
            with timer:
                transition = ttorch.as_tensor(
                    policy.process_transition(ctx.obs[index], ctx.inference_output[index], timestep)
                )
                transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
                transition.env_data_id = ttorch.as_tensor([env_episode_id[timestep.env_id]])
                transitions.append(timestep.env_id, transition)

                env_id = timestep.env_id.item()
                env_stat = env_info[env_id]
                sample_num = len(transition.obs)
                collected_step += 1
                collected_sample += sample_num
                env_stat['step'] += 1
                env_stat['train_sample'] += sample_num

            # ``timer.value`` now holds the duration of the ``with`` block that just finished.
            env_stat['time'] += timer.value + interaction_duration

            if not timestep.done:
                continue

            # NOTE(review): ``env_stat`` is not reset after an episode ends, so these numbers keep
            # accumulating over all episodes of the same env -- confirm this is intended.
            episode_info.append(
                {
                    'reward': timestep.info['eval_episode_return'],
                    'time': env_stat['time'],
                    'step': env_stat['step'],
                    'train_sample': env_stat['train_sample'],
                }
            )
            policy.reset([env_id])
            env_episode_id[env_id] = current_id
            collected_episode += 1
            current_id += 1
            ctx.env_episode += 1

        total_envstep_count += collected_step
        total_episode_count += collected_episode
        total_train_sample_count += collected_sample

        log_is_due = (ctx.train_iter - last_train_iter) >= collect_print_freq
        if log_is_due and len(episode_info) > 0:
            output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count)
            last_train_iter = ctx.train_iter

    return _rollout
def output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count) -> None:
    """
    Overview:
        Log the statistics of the episodes finished since the previous log, then empty
        ``episode_info`` in place. Nothing is logged when ``episode_info`` is empty.
    Arguments:
        - episode_info (:obj:`List[dict]`): One dict per finished episode with the keys ``'reward'``
          (an object providing ``.item()``), ``'time'``, ``'step'`` and ``'train_sample'``.
          The list is cleared by this function.
        - total_episode_count (:obj:`int`): Reported as ``total_episode_count``.
        - total_envstep_count (:obj:`int`): Reported as ``total_envstep_count``.
        - total_train_sample_count (:obj:`int`): Reported as ``total_train_sample_count``.
    """
    # Without any finished episode the averages below would divide by zero and
    # ``np.max``/``np.min`` would fail on an empty sequence, so there is nothing to report.
    if not episode_info:
        return

    episode_count = len(episode_info)
    envstep_count = sum(d['step'] for d in episode_info)
    train_sample_count = sum(d['train_sample'] for d in episode_info)
    duration = sum(d['time'] for d in episode_info)
    episode_return = [d['reward'].item() for d in episode_info]
    info = {
        'episode_count': episode_count,
        'envstep_count': envstep_count,
        'train_sample_count': train_sample_count,
        'avg_envstep_per_episode': envstep_count / episode_count,
        'avg_sample_per_episode': train_sample_count / episode_count,
        'avg_envstep_per_sec': envstep_count / duration,
        'avg_train_sample_per_sec': train_sample_count / duration,
        'avg_episode_per_sec': episode_count / duration,
        'reward_mean': np.mean(episode_return),
        'reward_std': np.std(episode_return),
        'reward_max': np.max(episode_return),
        'reward_min': np.min(episode_return),
        'total_envstep_count': total_envstep_count,
        'total_train_sample_count': total_train_sample_count,
        'total_episode_count': total_episode_count,
    }
    # The caller keeps appending to the same list, so it must be emptied in place.
    episode_info.clear()
    logging.info("collect end:\n{}".format('\n'.join('{}: {}'.format(k, v) for k, v in info.items())))