from typing import TYPE_CHECKING
from easydict import EasyDict
import treetensor.torch as ttorch

from ding.policy import get_random_policy
from ding.envs import BaseEnvManager
from ding.framework import task
from .functional import inferencer, rolloutor, TransitionList

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext


class StepCollector:
    """
    Overview:
        The class of the collector running by steps, including model inference and transition \
            process. Use the `__call__` method to execute the whole collection process.
    """

    def __new__(cls, *args, **kwargs):
        """
        Overview:
            Return ``task.void()`` instead of a collector instance when the task router is active \
                and the current task does not hold the collector role.
        """
        if task.router.is_active and not task.has_role(task.role.COLLECTOR):
            return task.void()
        return super(StepCollector, cls).__new__(cls)

    def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config.
            - policy (:obj:`Policy`): The policy to be collected.
            - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
                its derivatives are supported.
            - random_collect_size (:obj:`int`): The count of samples that will be collected randomly, \
                typically used in initial runs.
        """
        self.cfg = cfg
        self.env = env
        self.policy = policy
        self.random_collect_size = random_collect_size
        # The transition buffer is shared with the rollout middleware, which receives it below.
        self._transitions = TransitionList(self.env.env_num)
        self._inferencer = task.wrap(inferencer(cfg.seed, policy, env))
        self._rolloutor = task.wrap(rolloutor(policy, env, self._transitions))

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            An encapsulation of inference and rollout middleware. Stop when completing \
                the target number of steps.
        Input of ctx:
            - env_step (:obj:`int`): The env steps which will increase during collection.
        Output of ctx:
            - trajectories: The first value returned by ``TransitionList.to_trajectories()``.
            - trajectory_end_idx: The second value returned by ``TransitionList.to_trajectories()``.
        """
        old = ctx.env_step
        if self.random_collect_size > 0 and old < self.random_collect_size:
            # Still inside the random-collection budget: collect the remainder with a random policy.
            target_size = self.random_collect_size - old
            random_policy = get_random_policy(self.cfg, self.policy, self.env)
            current_inferencer = task.wrap(inferencer(self.cfg.seed, random_policy, self.env))
        else:
            # compatible with old config, a train sample = unroll_len step
            target_size = self.cfg.policy.collect.n_sample * self.cfg.policy.collect.unroll_len
            current_inferencer = self._inferencer

        while True:
            current_inferencer(ctx)
            # NOTE(review): ctx.env_step is not modified in this method; presumably the
            # inference/rollout middleware advances it - confirm in `.functional`.
            self._rolloutor(ctx)
            if ctx.env_step - old >= target_size:
                ctx.trajectories, ctx.trajectory_end_idx = self._transitions.to_trajectories()
                self._transitions.clear()
                break


class PPOFStepCollector:
    """
    Overview:
        The class of the collector running by steps, including model inference and transition \
            process. Use the `__call__` method to execute the whole collection process.
    """

    def __new__(cls, *args, **kwargs):
        """
        Overview:
            Return ``task.void()`` instead of a collector instance when the task router is active \
                and the current task does not hold the collector role.
        """
        if task.router.is_active and not task.has_role(task.role.COLLECTOR):
            return task.void()
        return super(PPOFStepCollector, cls).__new__(cls)

    def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll_len: int = 1) -> None:
        """
        Arguments:
            - seed (:obj:`int`): Random seed, passed to ``env.seed``.
            - policy (:obj:`Policy`): The policy to be collected.
            - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
                its derivatives are supported.
            - n_sample (:obj:`int`): Multiplied by ``unroll_len`` to get the number of env steps \
                collected per call.
            - unroll_len (:obj:`int`): Multiplier applied to ``n_sample``, defaults to 1.
        """
        self.env = env
        self.env.seed(seed)
        self.policy = policy
        self.n_sample = n_sample
        self.unroll_len = unroll_len
        self._transitions = TransitionList(self.env.env_num)
        # Each env slot starts with the episode id equal to its index; ids handed out
        # afterwards continue from env_num, so they never collide with the initial ones.
        self._env_episode_id = list(range(self.env.env_num))
        self._current_id = self.env.env_num

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            An encapsulation of inference and rollout middleware. Stop when completing \
                the target number of steps.
        Input of ctx:
            - env_step (:obj:`int`): The env steps which will increase during collection.
            - env_episode (:obj:`int`): Increased by one for every timestep whose ``done`` is true.
            - train_iter: Stored on each transition as ``collect_train_iter``.
            - collect_kwargs: Keyword arguments forwarded to ``policy.collect``.
        Output of ctx:
            - trajectories: The first value returned by ``TransitionList.to_trajectories()``.
            - trajectory_end_idx: The second value returned by ``TransitionList.to_trajectories()``.
        """
        device = self.policy._device
        old = ctx.env_step
        target_size = self.n_sample * self.unroll_len

        if self.env.closed:
            self.env.launch()

        while True:
            # Inference runs on the policy device; results are moved back to CPU before env.step.
            obs = ttorch.as_tensor(self.env.ready_obs).to(dtype=ttorch.float32)
            obs = obs.to(device)
            inference_output = self.policy.collect(obs, **ctx.collect_kwargs)
            inference_output = inference_output.cpu()
            action = inference_output.action.numpy()
            timesteps = self.env.step(action)
            ctx.env_step += len(timesteps)

            obs = obs.cpu()
            # NOTE(review): obs and inference_output are indexed by position in `timesteps`;
            # this assumes env.step returns timesteps in the same order as ready_obs - confirm.
            for i, timestep in enumerate(timesteps):
                transition = self.policy.process_transition(obs[i], inference_output[i], timestep)
                transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
                transition.env_data_id = ttorch.as_tensor([self._env_episode_id[timestep.env_id]])
                self._transitions.append(timestep.env_id, transition)
                if timestep.done:
                    # Episode finished: reset the policy for this env and give the slot a fresh id.
                    self.policy.reset([timestep.env_id])
                    self._env_episode_id[timestep.env_id] = self._current_id
                    self._current_id += 1
                    ctx.env_episode += 1

            if ctx.env_step - old >= target_size:
                ctx.trajectories, ctx.trajectory_end_idx = self._transitions.to_trajectories()
                self._transitions.clear()
                break


class EpisodeCollector:
    """
    Overview:
        The class of the collector running by episodes, including model inference and transition \
            process. Use the `__call__` method to execute the whole collection process.
    """

    def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config.
            - policy (:obj:`Policy`): The policy to be collected.
            - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
                its derivatives are supported.
            - random_collect_size (:obj:`int`): The count of samples that will be collected randomly, \
                typically used in initial runs. It is compared against ``ctx.env_episode`` in `__call__`.
        """
        self.cfg = cfg
        self.env = env
        self.policy = policy
        self.random_collect_size = random_collect_size
        # The transition buffer is shared with the rollout middleware, which receives it below.
        self._transitions = TransitionList(self.env.env_num)
        self._inferencer = task.wrap(inferencer(cfg.seed, policy, env))
        self._rolloutor = task.wrap(rolloutor(policy, env, self._transitions))

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            An encapsulation of inference and rollout middleware. Stop when completing the \
                target number of episodes.
        Input of ctx:
            - env_episode (:obj:`int`): The env env_episode which will increase during collection.
        Output of ctx:
            - episodes: The value returned by ``TransitionList.to_episodes()``.
        """
        old = ctx.env_episode
        if self.random_collect_size > 0 and old < self.random_collect_size:
            # Still inside the random-collection budget: collect the remainder with a random policy.
            target_size = self.random_collect_size - old
            random_policy = get_random_policy(self.cfg, self.policy, self.env)
            # Pass the seed, not the whole config, matching every other inferencer call in this file.
            current_inferencer = task.wrap(inferencer(self.cfg.seed, random_policy, self.env))
        else:
            target_size = self.cfg.policy.collect.n_episode
            current_inferencer = self._inferencer

        while True:
            current_inferencer(ctx)
            # NOTE(review): ctx.env_episode is not modified in this method; presumably the
            # rollout middleware advances it - confirm in `.functional`.
            self._rolloutor(ctx)
            if ctx.env_episode - old >= target_size:
                ctx.episodes = self._transitions.to_episodes()
                self._transitions.clear()
                break


# TODO battle collector