|
from collections import defaultdict |
|
import math |
|
import queue |
|
from time import sleep, time |
|
import gym |
|
from ding.framework import Supervisor |
|
from typing import TYPE_CHECKING, Any, List, Union, Dict, Optional, Callable |
|
from ding.framework.supervisor import ChildType, RecvPayload, SendPayload |
|
from ding.utils import make_key_as_identifier |
|
from ditk import logging |
|
from ding.data import ShmBufferContainer |
|
import enum |
|
import treetensor.numpy as tnp |
|
import numbers |
|
if TYPE_CHECKING: |
|
from gym.spaces import Space |
|
|
|
|
|
class EnvState(enum.IntEnum):
    """
    Overview:
        State of a single environment as tracked by the env supervisor.
        A normal episode goes through VOID -> RUN -> DONE.
    """
    # Fallback state reported for an env that has no recorded state yet.
    VOID = 0
    INIT = 1
    # Assigned after a reset request returned without error.
    RUN = 2
    RESET = 3
    # Assigned after a step request reported the episode as done.
    DONE = 4
    # Assigned whenever a received payload carries an error.
    ERROR = 5
    # NOTE(review): INIT, RESET and NEED_RESET are not assigned anywhere in the
    # supervisor code in this file -- confirm whether other modules rely on them.
    NEED_RESET = 6
|
|
|
|
|
class EnvRetryType(str, enum.Enum):
    """
    Overview:
        How a failed env reset is retried by the env supervisor.
    """
    # Send the reset request again to the same child.
    RESET = "reset"
    # Restart the child first, then send the reset request again.
    RENEW = "renew"
|
|
|
|
|
class EnvSupervisor(Supervisor): |
|
""" |
|
Manage multiple envs with supervisor. |
|
|
|
New features (compared to env manager): |
|
- Consistent interface in multi-process and multi-threaded mode. |
|
- Add asynchronous features and recommend using asynchronous methods. |
|
- Reset is performed after an error is encountered in the step method. |
|
|
|
Breaking changes (compared to env manager): |
|
- Without some states. |
|
""" |
|
|
|
def __init__(
        self,
        type_: ChildType = ChildType.PROCESS,
        env_fn: List[Callable] = None,
        retry_type: EnvRetryType = EnvRetryType.RESET,
        max_try: Optional[int] = None,
        max_retry: Optional[int] = None,
        auto_reset: bool = True,
        reset_timeout: Optional[int] = None,
        step_timeout: Optional[int] = None,
        retry_waiting_time: Optional[int] = None,
        episode_num: int = float("inf"),
        shared_memory: bool = True,
        copy_on_get: bool = True,
        **kwargs
) -> None:
    """
    Overview:
        Supervisor that manages a group of envs.
    Arguments:
        - type_ (:obj:`ChildType`): Type of child (process or thread).
        - env_fn (:obj:`List[Callable]`): Functions that create the environments, one per env.
        - retry_type (:obj:`EnvRetryType`): Retry a failed reset on the same child (reset) or \
            restart the child first (renew).
        - max_try (:obj:`Optional[int]`): Max try times for a reset, falls back to ``max_retry`` and then to 1.
        - max_retry (:obj:`Optional[int]`): Deprecated alias of ``max_try``.
        - auto_reset (:obj:`bool`): Reset an env automatically when its episode is done.
        - reset_timeout (:obj:`Optional[int]`): Timeout in seconds for reset.
        - step_timeout (:obj:`Optional[int]`): Timeout in seconds for step.
        - retry_waiting_time (:obj:`Optional[float]`): Seconds to sleep before each reset retry.
        - episode_num (:obj:`int`): An env is only auto reset while its finished episode count \
            is below this number.
        - shared_memory (:obj:`bool`): Transfer observations through shared memory (process mode only).
        - copy_on_get (:obj:`bool`): Copy data when reading it from shared memory (process mode only).
    """
    if kwargs:
        logging.warning("Unknown parameters on env supervisor: {}".format(kwargs))
    super().__init__(type_=type_)

    # Both shared memory options are silently disabled outside of process mode.
    in_process_mode = type_ is ChildType.PROCESS
    if not in_process_mode and (shared_memory or copy_on_get):
        logging.warning("shared_memory and copy_on_get only works in process mode.")
    self._shared_memory = in_process_mode and shared_memory
    self._copy_on_get = in_process_mode and copy_on_get

    self._env_fn = env_fn
    # Fills in the observation / action / reward spaces used below.
    self._create_env_ref()
    self._obs_buffers = None
    if env_fn:
        register_kwargs = {}
        if self._shared_memory:
            space = self._observation_space
            if isinstance(space, gym.spaces.Dict):
                # Dict spaces get one shape / dtype entry per sub space.
                shape = {name: sub.shape for name, sub in space.spaces.items()}
                dtype = {name: sub.dtype for name, sub in space.spaces.items()}
            else:
                shape, dtype = space.shape, space.dtype
            # One observation buffer per env, keyed by env id.
            self._obs_buffers = {
                env_id: ShmBufferContainer(dtype, shape, copy_on_get=self._copy_on_get)
                for env_id, _ in enumerate(self._env_fn)
            }
            register_kwargs = dict(shm_buffer=self._obs_buffers, shm_callback=self._shm_callback)
        for env_init in env_fn:
            self.register(env_init, **register_kwargs)

    self._retry_type = retry_type
    self._auto_reset = auto_reset
    if max_retry:
        logging.warning("The `max_retry` is going to be deprecated, use `max_try` instead!")
    self._max_try = max_try or max_retry or 1
    self._reset_timeout = reset_timeout
    self._step_timeout = step_timeout
    self._retry_waiting_time = retry_waiting_time
    self._env_replay_path = None
    self._episode_num = episode_num
    self._init_states()
|
|
|
def _init_states(self): |
|
self._env_seed = {} |
|
self._env_dynamic_seed = None |
|
self._env_replay_path = None |
|
self._env_states = {} |
|
self._reset_param = {} |
|
self._ready_obs = {} |
|
self._env_episode_count = {i: 0 for i in range(self.env_num)} |
|
self._retry_times = defaultdict(lambda: 0) |
|
self._last_called = defaultdict(lambda: {"step": math.inf, "reset": math.inf}) |
|
|
|
def _shm_callback(self, payload: RecvPayload, obs_buffers: Any):
    """
    Overview:
        This method will be called in child worker, so we can put large data into shared memory
        and replace the original payload data to none, then reduce the serialization/deserialization cost.
    Arguments:
        - payload (:obj:`RecvPayload`): The payload about to be sent back by the child.
        - obs_buffers (:obj:`Any`): Observation buffers, indexed by ``payload.proc_id``.
    """
    if payload.data is None:
        return
    method = payload.method
    if method == "reset":
        # The reset result is the observation itself.
        obs_buffers[payload.proc_id].fill(payload.data)
        payload.data = None
    elif method == "step":
        obs_buffers[payload.proc_id].fill(payload.data.obs)
        # NOTE(review): `_replace` returns a new object and the result is dropped here, so
        # the observation still travels inside the payload. Assigning the result back is only
        # safe once the receiving side restores `obs` from the buffer -- change both together.
        payload.data._replace(obs=None)
|
|
|
def _create_env_ref(self): |
|
|
|
self._env_ref = self._env_fn[0]() |
|
self._env_ref.reset() |
|
self._observation_space = self._env_ref.observation_space |
|
self._action_space = self._env_ref.action_space |
|
self._reward_space = self._env_ref.reward_space |
|
self._env_ref.close() |
|
|
|
def step(self, actions: Union[Dict[int, List[Any]], List[Any]], block: bool = True) -> Optional[List[tnp.ndarray]]:
    """
    Overview:
        Send one step request per given action. In blocking mode, wait for all results
        (envs may be reset by the receive callback, e.g. when an episode is done).
    Arguments:
        - actions (:obj:`Union[Dict[int, List[Any]], List[Any]]`): Actions from the outer caller like policy, \
            in structure of {env_id: action}. A list is interpreted as {index: action}.
        - block (:obj:`bool`): If block, return timesteps, else return none.
    Returns:
        - timesteps (:obj:`Optional[List[tnp.ndarray]]`): The data of every received payload, \
            or None in non-blocking mode.
    """
    assert not self.closed, "Env supervisor has closed."
    if isinstance(actions, list):
        actions = dict(enumerate(actions))
    assert actions, "Action is empty!"

    requests = []
    for env_id, action in actions.items():
        request = SendPayload(proc_id=env_id, method="step", args=[action])
        requests.append(request)
        self.send(request)

    if not block:
        # The caller is expected to collect the results through `recv`.
        return None

    replies = self.recv_all(requests, ignore_err=True, callback=self._recv_callback, timeout=self._step_timeout)
    return [reply.data for reply in replies]
|
|
|
def recv(self, ignore_err: bool = False) -> RecvPayload:
    """
    Overview:
        Wait for the next successful payload, this function will block the thread.
        The parent queue is polled in slices of 0.1 seconds; before every poll the
        timeout detection runs. Every received payload goes through ``_recv_callback``,
        and payloads that carry an error are skipped.
    Arguments:
        - ignore_err (:obj:`bool`): Kept for interface compatibility. Payloads with an error \
            object are always handed to the callback and then discarded, no exception is raised here.
    Returns:
        - recv_payload (:obj:`RecvPayload`): The first payload without error.
    """
    # Poll in a loop instead of recursing: a long wait would otherwise add one stack
    # frame per empty poll / error payload and eventually raise RecursionError.
    while True:
        self._detect_timeout()
        try:
            payload = super().recv(ignore_err=True, timeout=0.1)
        except queue.Empty:
            continue
        payload = self._recv_callback(payload=payload)
        if not payload.err:
            return payload
|
|
|
def _detect_timeout(self):
    """
    Overview:
        Check every env for a pending step / reset request that is older than the configured
        timeout, and put an error payload (``TimeoutError``) for it into ``self._recv_queue``.
        At most one payload is queued per env and call; a step timeout takes precedence.
    """
    now = time()
    for env_id, called_at in self._last_called.items():
        if self._step_timeout and now - called_at["step"] > self._step_timeout:
            method, label = "step", "Step"
        elif self._reset_timeout and now - called_at["reset"] > self._reset_timeout:
            method, label = "reset", "Reset"
        else:
            continue
        # The message names the method that actually timed out.
        err = TimeoutError("{} timeout on env {}".format(label, env_id))
        self._recv_queue.put(RecvPayload(proc_id=env_id, method=method, err=err))
|
|
|
@property
def env_num(self) -> int:
    """
    Overview:
        Number of children currently held in ``self._children``.
    """
    children = self._children
    return len(children)
|
|
|
@property
def observation_space(self) -> 'Space':
    """
    Overview:
        Observation space stored on the supervisor (read from the reference env).
    """
    space = self._observation_space
    return space
|
|
|
@property
def action_space(self) -> 'Space':
    """
    Overview:
        Action space stored on the supervisor (read from the reference env).
    """
    space = self._action_space
    return space
|
|
|
@property
def reward_space(self) -> 'Space':
    """
    Overview:
        Reward space stored on the supervisor (read from the reference env).
    """
    space = self._reward_space
    return space
|
|
|
@property
def ready_obs(self) -> tnp.array:
    """
    Overview:
        Get the ready (next) observation in ``tnp.array`` type, which is uniform for both async/sync scenarios.
        Only envs in RUN state contribute, ordered by ascending env id.
    Return:
        - ready_obs (:obj:`tnp.array`): A stacked treenumpy-type observation data, \
            or an empty array when no env is in RUN state.
    Example:
        >>> obs = env_manager.ready_obs
        >>> action = model(obs)  # model input np obs and output np action
        >>> timesteps = env_manager.step(action)
    """
    running_ids = sorted(env_id for env_id, state in self._env_states.items() if state == EnvState.RUN)
    if not running_ids:
        return tnp.array([])
    return tnp.stack([self._ready_obs.get(env_id) for env_id in running_ids])
|
|
|
@property
def ready_obs_id(self) -> List[int]:
    """
    Overview:
        Ids of all envs that are currently in RUN state, in the order of ``env_states``.
    """
    states = self.env_states
    return [env_id for env_id in states if states[env_id] == EnvState.RUN]
|
|
|
@property
def done(self) -> bool:
    """
    Overview:
        Whether every env is in DONE state (also True when there is no env at all).
    """
    return all(state == EnvState.DONE for state in self.env_states.values())
|
|
|
@property
def method_name_list(self) -> List[str]:
    """
    Overview:
        Names of the env methods listed by this supervisor. A new list is returned on every access.
    """
    names = ('reset', 'step', 'seed', 'close', 'enable_save_replay')
    return list(names)
|
|
|
@property
def env_states(self) -> Dict[int, EnvState]:
    """
    Overview:
        State of every env, keyed by env id. Envs without a recorded state are reported as VOID.
    """
    states = {}
    for env_id in range(self.env_num):
        # VOID is falsy (0), so `or` maps both "missing" and VOID to VOID.
        states[env_id] = self._env_states.get(env_id) or EnvState.VOID
    return states
|
|
|
def env_state_done(self, env_id: int) -> bool:
    """
    Overview:
        Whether the env with the given id is in DONE state.
    Arguments:
        - env_id (:obj:`int`): Environment id.
    """
    state = self.env_states[env_id]
    return state == EnvState.DONE
|
|
|
def launch(self, reset_param: Optional[Dict] = None, block: bool = True) -> None:
    """
    Overview:
        Set up the environments and their parameters: start the children, send the
        configured seeds, reset every env and finally enable replay saving if requested.
    Arguments:
        - reset_param (:obj:`Optional[Dict]`): Dict of reset parameters for each environment, key is the env_id, \
            value is the corresponding reset parameters. If given, it must contain one entry per env.
        - block (:obj:`bool`): Whether to block and wait for the seed and reset results.
    """
    assert self.closed, "Please first close the env supervisor before launch it"
    assert reset_param is None or len(reset_param) == self.env_num
    self.start_link()
    self._send_seed(self._env_seed, self._env_dynamic_seed, block=block)
    self.reset(reset_param, block=block)
    self._enable_env_replay()
|
|
|
def reset(self, reset_param: Optional[Dict[int, List[Any]]] = None, block: bool = True) -> None:
    """
    Overview:
        Reset the given environments (all environments if no parameters are passed).
    Arguments:
        - reset_param (:obj:`Optional[Dict[int, List[Any]]]`): Dict of reset parameters for each environment, \
            key is the env_id, value is the corresponding reset parameters. A list is interpreted \
            as {index: parameters}.
        - block (:obj:`bool`): Whether to block and wait for the reset results.
    """
    if not reset_param:
        reset_param = {env_id: {} for env_id in range(self.env_num)}
    elif isinstance(reset_param, list):
        reset_param = dict(enumerate(reset_param))

    send_payloads = []
    for env_id, kw_param in reset_param.items():
        # Remember the parameters, they are reused for later automatic resets.
        self._reset_param[env_id] = kw_param
        send_payloads += self._reset(env_id, kw_param=kw_param)

    # `_reset` may skip an env and return no request; with nothing in flight there is
    # nothing to wait for (same guard as in `_send_seed`).
    if not block or not send_payloads:
        return

    self.recv_all(send_payloads, ignore_err=True, callback=self._recv_callback, timeout=self._reset_timeout)
|
|
|
def _recv_callback(
        self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None
) -> RecvPayload:
    """
    Overview:
        The callback function for each received payload. It restores the observation from
        shared memory, updates the state of the env, and hands reset / step payloads over to
        the method specific callback, which decides whether follow-up requests are needed.
    Arguments:
        - payload (:obj:`RecvPayload`): The received payload.
        - remain_payloads (:obj:`Optional[Dict[str, SendPayload]]`): The callback may be called many times \
            until remain_payloads be cleared, you can append new payload into remain_payloads to call this \
            callback recursively.
    Returns:
        - payload (:obj:`RecvPayload`): The processed payload.
    """
    self._set_shared_obs(payload=payload)
    self.change_state(payload=payload)
    handlers = {"reset": self._recv_reset_callback, "step": self._recv_step_callback}
    handler = handlers.get(payload.method)
    if handler is None:
        # Other methods (e.g. seed) need no further processing.
        return payload
    return handler(payload=payload, remain_payloads=remain_payloads)
|
|
|
def _set_shared_obs(self, payload: RecvPayload):
    """
    Overview:
        Restore the observation of a successful reset / step payload from the shared memory
        buffer of the env. Does nothing when no observation buffers are in use or when the
        payload carries an error.
    Arguments:
        - payload (:obj:`RecvPayload`): The received payload, modified in place.
    """
    if self._obs_buffers is None or payload.err is not None:
        return
    if payload.method == "reset":
        payload.data = self._obs_buffers[payload.proc_id].get()
    elif payload.method == "step":
        # `_replace` builds a new object instead of mutating in place, so the result has to
        # be assigned back, otherwise the buffered observation never reaches the payload.
        payload.data = payload.data._replace(obs=self._obs_buffers[payload.proc_id].get())
|
|
|
def _recv_reset_callback(
        self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None
) -> RecvPayload:
    """
    Overview:
        Handle the result of a reset request. On success the observation is stored as ready
        observation; on failure the reset is retried until ``max_try`` is exceeded.
    Arguments:
        - payload (:obj:`RecvPayload`): The received reset payload.
        - remain_payloads (:obj:`Optional[Dict[str, SendPayload]]`): Pending requests, retry \
            requests are added to it keyed by their ``req_id``.
    Returns:
        - payload (:obj:`RecvPayload`): The unchanged payload.
    Raises:
        - RuntimeError: If the env failed to reset ``max_try`` times, the supervisor is shut down first.
    """
    assert payload.method == "reset", "Recv error callback({}) in reset callback!".format(payload.method)
    if remain_payloads is None:
        remain_payloads = {}
    env_id = payload.proc_id

    if not payload.err:
        self._retry_times[env_id] = 0
        self._ready_obs[env_id] = payload.data
        return payload

    self._retry_times[env_id] += 1
    if self._retry_times[env_id] > self._max_try - 1:
        self.shutdown(5)
        raise RuntimeError(
            "Env {} reset has exceeded max_try({}), and the latest exception is: {}".format(
                env_id, self._max_try, payload.err
            )
        )
    if self._retry_waiting_time:
        sleep(self._retry_waiting_time)
    if self._retry_type == EnvRetryType.RENEW:
        # Renew means: restart the child before asking it to reset again.
        self._children[env_id].restart()
    remain_payloads.update((request.req_id, request) for request in self._reset(env_id))
    return payload
|
|
|
def _recv_step_callback(
        self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None
) -> RecvPayload:
    """
    Overview:
        Handle the result of a step request and convert it into a ``tnp.array`` timestep with
        the keys obs, reward, done, info and env_id. A failed step triggers a reset of the env
        and yields an abnormal timestep; a finished episode triggers an automatic reset if
        auto reset is enabled and the episode limit of the env is not reached yet.
    Arguments:
        - payload (:obj:`RecvPayload`): The received step payload, its data is replaced.
        - remain_payloads (:obj:`Optional[Dict[str, SendPayload]]`): Pending requests, reset \
            requests are added to it keyed by their ``req_id``.
    Returns:
        - payload (:obj:`RecvPayload`): The payload with the converted timestep as data.
    """
    assert payload.method == "step", "Recv error callback({}) in step callback!".format(payload.method)
    if remain_payloads is None:
        remain_payloads = {}
    env_id = payload.proc_id

    if payload.err:
        remain_payloads.update((request.req_id, request) for request in self._reset(env_id))
        abnormal_step = {
            'obs': None,
            'reward': None,
            'done': None,
            'info': {"abnormal": True, "err": payload.err},
            'env_id': env_id,
        }
        payload.data = tnp.array(abnormal_step)
        return payload

    obs, reward, done, info, *_ = payload.data
    if done:
        self._env_episode_count[env_id] += 1
        if self._env_episode_count[env_id] < self._episode_num and self._auto_reset:
            remain_payloads.update((request.req_id, request) for request in self._reset(env_id))
    timestep = {
        'obs': obs,
        'reward': reward,
        'done': done,
        'info': make_key_as_identifier(info),
        'env_id': env_id,
    }
    payload.data = tnp.array(timestep)
    self._ready_obs[env_id] = obs
    return payload
|
|
|
def _reset(self, env_id: int, kw_param: Optional[Dict[str, Any]] = None) -> List[SendPayload]:
    """
    Overview:
        Reset an environment. This method does not wait for the result to be returned.
    Arguments:
        - env_id (:obj:`int`): Environment id.
        - kw_param (:obj:`Optional[Dict[str, Any]]`): Reset parameters for the environment, \
            falls back to the parameters remembered for this env.
    Returns:
        - send_payloads (:obj:`List[SendPayload]`): The sent reset request, or an empty list \
            if the reset was skipped.
    """
    assert not self.closed, "Env supervisor has closed."
    kw_param = kw_param or self._reset_param[env_id]

    replay_enabled = self._env_replay_path is not None
    if replay_enabled and self.env_states[env_id] == EnvState.RUN:
        logging.warning("Please don't reset an unfinished env when you enable save replay, we just skip it")
        return []

    request = SendPayload(proc_id=env_id, method="reset", kwargs=kw_param)
    self.send(request)
    return [request]
|
|
|
def _send_seed(self, env_seed: Dict[int, int], env_dynamic_seed: Optional[bool] = None, block: bool = True) -> None: |
|
send_payloads = [] |
|
for env_id, seed in env_seed.items(): |
|
if seed is None: |
|
continue |
|
args = [seed] |
|
if env_dynamic_seed is not None: |
|
args.append(env_dynamic_seed) |
|
payload = SendPayload(proc_id=env_id, method="seed", args=args) |
|
send_payloads.append(payload) |
|
self.send(payload) |
|
if not block or not send_payloads: |
|
return |
|
self.recv_all(send_payloads, ignore_err=True, callback=self._recv_callback, timeout=self._reset_timeout) |
|
|
|
def change_state(self, payload: RecvPayload):
    """
    Overview:
        Update the bookkeeping of an env after a payload was received: clear the pending
        timestamp of the method and derive the new env state from the payload.
    Arguments:
        - payload (:obj:`RecvPayload`): The received payload.
    """
    env_id = payload.proc_id
    # The request is answered, so it can no longer time out.
    self._last_called[env_id][payload.method] = math.inf

    if payload.err:
        new_state = EnvState.ERROR
    elif payload.method == "reset":
        new_state = EnvState.RUN
    elif payload.method == "step" and payload.data[2]:
        # Index 2 of the step result is read as the done flag.
        new_state = EnvState.DONE
    else:
        # Nothing to change, e.g. a step that did not finish the episode.
        return
    self._env_states[env_id] = new_state
|
|
|
def send(self, payload: SendPayload) -> None:
    """
    Overview:
        Record the current time as pending timestamp of the requested method (used by the
        timeout detection), then send the payload through the parent class.
    Arguments:
        - payload (:obj:`SendPayload`): The request payload.
    """
    pending = self._last_called[payload.proc_id]
    pending[payload.method] = time()
    return super().send(payload)
|
|
|
def seed(self, seed: Union[Dict[int, int], List[int], int], dynamic_seed: Optional[bool] = None) -> None:
    """
    Overview:
        Set the seed for each environment. The seed function will not be called until supervisor.launch \
        was called.
    Arguments:
        - seed (:obj:`Union[Dict[int, int], List[int], int]`): A dict of {env_id: seed}; \
            or a list with one seed per environment; \
            or one integer, then env ``i`` gets ``seed + i``. \
            Note that in threading mode, no matter how many seeds are given, only the last one will take effect. \
            Because the execution in the thread is asynchronous, the results of each experiment \
            are different even if a fixed seed is used.
        - dynamic_seed (:obj:`Optional[bool]`): Dynamic seed is used in the training environment, \
            trying to make the random seed of each episode different, they are all generated in the reset \
            method by a random generator 100 * np.random.randint(1 , 1000) (but the seed of this random \
            number generator is fixed by the environmental seed method, guaranteeing the reproducibility \
            of the experiment). You need not pass the dynamic_seed parameter in the seed method, or pass \
            the parameter as True.
    Raises:
        - TypeError: If ``seed`` is neither an integer, a list nor a dict.
    """
    # Previously configured seeds are dropped even if the new argument turns out invalid.
    self._env_seed = {}
    if isinstance(seed, numbers.Integral):
        env_seed = {env_id: seed + env_id for env_id in range(self.env_num)}
    elif isinstance(seed, list):
        assert len(seed) == self.env_num, "len(seed) {:d} != env_num {:d}".format(len(seed), self.env_num)
        env_seed = dict(enumerate(seed))
    elif isinstance(seed, dict):
        env_seed = dict(seed)
    else:
        raise TypeError("Invalid seed arguments type: {}".format(type(seed)))
    self._env_seed = env_seed
    self._env_dynamic_seed = dynamic_seed
|
|
|
def enable_save_replay(self, replay_path: Union[List[str], str]) -> None:
    """
    Overview:
        Set each env's replay save path. The paths are only stored here.
    Arguments:
        - replay_path (:obj:`Union[List[str], str]`): List of paths for each environment; \
            Or one path for all environments.
    """
    # A single path is repeated once per env.
    paths = [replay_path] * self.env_num if isinstance(replay_path, str) else replay_path
    self._env_replay_path = paths
|
|
|
def _enable_env_replay(self): |
|
if self._env_replay_path is None: |
|
return |
|
send_payloads = [] |
|
for env_id, s in enumerate(self._env_replay_path): |
|
payload = SendPayload(proc_id=env_id, method="enable_save_replay", args=[s]) |
|
send_payloads.append(payload) |
|
self.send(payload) |
|
self.recv_all(send_payloads=send_payloads) |
|
|
|
def __getattr__(self, key: str) -> List[Any]:
    """
    Overview:
        Fallback for attributes that are not found on the supervisor itself. Attributes that
        the reference env does not have are rejected; everything else is delegated to the
        ``__getattr__`` of the parent class.
    Arguments:
        - key (:obj:`str`): Name of the requested attribute.
    Raises:
        - AttributeError: If the reference env is not available or does not have the attribute.
    """
    # Read the reference env straight from the instance dict: `self._env_ref` would call
    # this method again while the attribute is not set yet (e.g. before the constructor
    # assigned it, or during copy / unpickling) and recurse without bound.
    env_ref = self.__dict__.get("_env_ref")
    if env_ref is None or not hasattr(env_ref, key):
        raise AttributeError("env `{}` doesn't have the attribute `{}`".format(type(env_ref), key))
    return super().__getattr__(key)
|
|
|
def close(self, timeout: Optional[float] = None) -> None:
    """
    Overview:
        Alias of ``shutdown``.
        In order to be compatible with BaseEnvManager, the new version can use `shutdown` directly.
    Arguments:
        - timeout (:obj:`Optional[float]`): Passed through to ``shutdown``.
    """
    self.shutdown(timeout=timeout)
|
|
|
def shutdown(self, timeout: Optional[float] = None) -> None:
    """
    Overview:
        If the supervisor is running: send a close request to every env, wait for the results
        (errors are ignored), shut down the parent class and reinitialize the per-env states.
        Does nothing if the supervisor is not running.
    Arguments:
        - timeout (:obj:`Optional[float]`): Timeout used both for waiting for the close results \
            and for the shutdown of the parent class.
    """
    if not self._running:
        return
    requests = []
    for env_id in range(self.env_num):
        request = SendPayload(proc_id=env_id, method="close")
        requests.append(request)
        self.send(request)
    self.recv_all(send_payloads=requests, ignore_err=True, timeout=timeout)
    super().shutdown(timeout=timeout)
    self._init_states()
|
|
|
@property
def closed(self) -> bool:
    """
    Overview:
        Whether the supervisor is currently not running.
    """
    running = self._running
    return not running
|
|