from typing import TYPE_CHECKING
from easydict import EasyDict
import treetensor.torch as ttorch
from ding.policy import get_random_policy
from ding.envs import BaseEnvManager
from ding.framework import task
from .functional import inferencer, rolloutor, TransitionList
if TYPE_CHECKING:
    from ding.framework import OnlineRLContext
class StepCollector:
    """
    Overview:
        The class of the step-based collector, which includes model inference and transition \
            processing. Use the `__call__` method to execute the whole collection process.
    """

    def __new__(cls, *args, **kwargs):
        if task.router.is_active and not task.has_role(task.role.COLLECTOR):
            return task.void()
        return super(StepCollector, cls).__new__(cls)

    def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config.
            - policy (:obj:`Policy`): The policy to be collected.
            - env (:obj:`BaseEnvManager`): The env used for collection; a BaseEnvManager object or \
                its derivatives are supported.
            - random_collect_size (:obj:`int`): The number of env steps that will be collected with a \
                random policy, typically used in initial runs.
        """
        self.cfg = cfg
        self.env = env
        self.policy = policy
        self.random_collect_size = random_collect_size
        self._transitions = TransitionList(self.env.env_num)
        self._inferencer = task.wrap(inferencer(cfg.seed, policy, env))
        self._rolloutor = task.wrap(rolloutor(policy, env, self._transitions))

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            An encapsulation of inference and rollout middleware. Stop when completing \
                the target number of steps.
        Input of ctx:
            - env_step (:obj:`int`): The count of env steps, which will increase during collection.
        """
        old = ctx.env_step
        if self.random_collect_size > 0 and old < self.random_collect_size:
            target_size = self.random_collect_size - old
            random_policy = get_random_policy(self.cfg, self.policy, self.env)
            current_inferencer = task.wrap(inferencer(self.cfg.seed, random_policy, self.env))
        else:
            # Compatible with the old config: one train sample = unroll_len env steps.
            target_size = self.cfg.policy.collect.n_sample * self.cfg.policy.collect.unroll_len
            current_inferencer = self._inferencer

        while True:
            current_inferencer(ctx)
            self._rolloutor(ctx)
            if ctx.env_step - old >= target_size:
                ctx.trajectories, ctx.trajectory_end_idx = self._transitions.to_trajectories()
                self._transitions.clear()
                break
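
# Usage sketch (illustrative only, not part of this module): StepCollector is designed to be
# plugged into a task pipeline as middleware. The names `cfg`, `policy` and `env` are assumed
# to be built by the surrounding entry script, and the learner line is only a placeholder for
# whatever training middleware the pipeline uses.
#
#   with task.start(ctx=OnlineRLContext()):
#       task.use(StepCollector(cfg, policy.collect_mode, env, random_collect_size=2000))
#       task.use(...)  # e.g. an off-policy learner middleware
#       task.run()
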
class PPOFStepCollector:
    """
    Overview:
        The class of the step-based collector used by the PPOF pipeline, which includes model \
            inference and transition processing. Use the `__call__` method to execute the whole \
            collection process.
    """

    def __new__(cls, *args, **kwargs):
        if task.router.is_active and not task.has_role(task.role.COLLECTOR):
            return task.void()
        return super(PPOFStepCollector, cls).__new__(cls)

    def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll_len: int = 1) -> None:
        """
        Arguments:
            - seed (:obj:`int`): Random seed.
            - policy (:obj:`Policy`): The policy to be collected.
            - env (:obj:`BaseEnvManager`): The env used for collection; a BaseEnvManager object or \
                its derivatives are supported.
            - n_sample (:obj:`int`): The number of train samples collected per call.
            - unroll_len (:obj:`int`): The number of env steps that make up one train sample. Defaults to 1.
        """
        self.env = env
        self.env.seed(seed)
        self.policy = policy
        self.n_sample = n_sample
        self.unroll_len = unroll_len
        self._transitions = TransitionList(self.env.env_num)
        # Each env starts with a unique episode id; new ids are allocated when an episode finishes.
        self._env_episode_id = [_ for _ in range(env.env_num)]
        self._current_id = env.env_num

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            An encapsulation of inference and rollout middleware. Stop when completing \
                the target number of steps.
        Input of ctx:
            - env_step (:obj:`int`): The count of env steps, which will increase during collection.
        """
        device = self.policy._device
        old = ctx.env_step
        target_size = self.n_sample * self.unroll_len

        if self.env.closed:
            self.env.launch()

        while True:
            # Model inference on the batched ready observations.
            obs = ttorch.as_tensor(self.env.ready_obs).to(dtype=ttorch.float32)
            obs = obs.to(device)
            inference_output = self.policy.collect(obs, **ctx.collect_kwargs)
            inference_output = inference_output.cpu()
            action = inference_output.action.numpy()

            # Step the env manager with the inferred actions.
            timesteps = self.env.step(action)
            ctx.env_step += len(timesteps)

            obs = obs.cpu()
            for i, timestep in enumerate(timesteps):
                transition = self.policy.process_transition(obs[i], inference_output[i], timestep)
                transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
                transition.env_data_id = ttorch.as_tensor([self._env_episode_id[timestep.env_id]])
                self._transitions.append(timestep.env_id, transition)
                if timestep.done:
                    # On episode end, reset the policy state for this env and allocate a new episode id.
                    self.policy.reset([timestep.env_id])
                    self._env_episode_id[timestep.env_id] = self._current_id
                    self._current_id += 1
                    ctx.env_episode += 1

            if ctx.env_step - old >= target_size:
                ctx.trajectories, ctx.trajectory_end_idx = self._transitions.to_trajectories()
                self._transitions.clear()
                break
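
# Usage sketch (illustrative only): PPOFStepCollector follows the same middleware pattern but is
# configured directly with a seed and sample sizes instead of a full EasyDict config. The `policy`
# and `env` objects are assumed to come from the PPOF entry script.
#
#   task.use(PPOFStepCollector(seed=0, policy=policy, env=env, n_sample=128, unroll_len=1))
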
class EpisodeCollector:
    """
    Overview:
        The class of the episode-based collector, which includes model inference and transition \
            processing. Use the `__call__` method to execute the whole collection process.
    """

    def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Config.
            - policy (:obj:`Policy`): The policy to be collected.
            - env (:obj:`BaseEnvManager`): The env used for collection; a BaseEnvManager object or \
                its derivatives are supported.
            - random_collect_size (:obj:`int`): The number of episodes that will be collected with a \
                random policy, typically used in initial runs.
        """
        self.cfg = cfg
        self.env = env
        self.policy = policy
        self.random_collect_size = random_collect_size
        self._transitions = TransitionList(self.env.env_num)
        self._inferencer = task.wrap(inferencer(cfg.seed, policy, env))
        self._rolloutor = task.wrap(rolloutor(policy, env, self._transitions))

    def __call__(self, ctx: "OnlineRLContext") -> None:
        """
        Overview:
            An encapsulation of inference and rollout middleware. Stop when completing the \
                target number of episodes.
        Input of ctx:
            - env_episode (:obj:`int`): The count of env episodes, which will increase during collection.
        """
        old = ctx.env_episode
        if self.random_collect_size > 0 and old < self.random_collect_size:
            target_size = self.random_collect_size - old
            random_policy = get_random_policy(self.cfg, self.policy, self.env)
            current_inferencer = task.wrap(inferencer(self.cfg.seed, random_policy, self.env))
        else:
            target_size = self.cfg.policy.collect.n_episode
            current_inferencer = self._inferencer

        while True:
            current_inferencer(ctx)
            self._rolloutor(ctx)
            if ctx.env_episode - old >= target_size:
                ctx.episodes = self._transitions.to_episodes()
                self._transitions.clear()
                break
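
# Usage sketch (illustrative only): EpisodeCollector fills `ctx.episodes` instead of
# `ctx.trajectories`, so it should be paired with episode-based training middleware downstream.
#
#   task.use(EpisodeCollector(cfg, policy.collect_mode, env, random_collect_size=5))
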
# TODO battle collector