import copy
from typing import List, Optional, Union, Dict

import gym
import gymnasium
import numpy as np
import treetensor.numpy as tnp
from easydict import EasyDict

from ding.envs.common.common_function import affine_transform
from ding.envs.env_wrappers import create_env_wrapper
from ding.torch_utils import to_ndarray
from ding.utils import CloudPickleWrapper
from .base_env import BaseEnv, BaseEnvTimestep
from .default_wrapper import get_default_wrappers


class DingEnvWrapper(BaseEnv):
    """
    Overview:
        This is a wrapper for the BaseEnv class, used to provide a consistent environment interface.
    Interfaces:
        __init__, reset, step, close, seed, random_action, _wrap_env, __repr__, create_collector_env_cfg, \
        create_evaluator_env_cfg, enable_save_replay, observation_space, action_space, reward_space, clone
    """

    def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True, caller: str = 'collector') -> None:
        """
        Overview:
            Initialize the DingEnvWrapper. Either an environment instance or a config to create the environment \
            instance should be passed in. For the former, the ``env`` argument must be the environment instance \
            itself; this usage does not support the subprocess environment manager, so it is usually reserved for \
            simple environments. For the latter, the ``cfg`` argument must contain ``env_id``.
        Arguments:
            - env (:obj:`gym.Env`): An environment instance to be wrapped.
            - cfg (:obj:`dict`): The configuration dictionary used to create an environment instance.
            - seed_api (:obj:`bool`): Whether to use the seed API. Defaults to True.
            - caller (:obj:`str`): A string representing the caller of this method, either ``collector`` or \
                ``evaluator``. Different callers may need different wrappers. Defaults to 'collector'.
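        Examples:
            >>> # A minimal sketch; assumes gym's classic ``CartPole-v0`` is installed.
            >>> env = DingEnvWrapper(gym.make('CartPole-v0'))
            >>> # Or lazily from a config; the raw env is created on the first ``reset``.
            >>> env = DingEnvWrapper(cfg={'env_id': 'CartPole-v0'})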
        """
        self._env = None
        self._raw_env = env
        self._cfg = cfg
        self._seed_api = seed_api
        self._caller = caller
        if self._cfg is None:
            self._cfg = {}
        self._cfg = EasyDict(self._cfg)
        if 'act_scale' not in self._cfg:
            self._cfg.act_scale = False
        if 'rew_clip' not in self._cfg:
            self._cfg.rew_clip = False
        if 'env_wrapper' not in self._cfg:
            self._cfg.env_wrapper = 'default'
        if 'env_id' not in self._cfg:
            self._cfg.env_id = None
        if env is not None:
            self._env = env
            self._wrap_env(caller)
            self._observation_space = self._env.observation_space
            self._action_space = self._env.action_space
            self._action_space.seed(0)
            self._reward_space = gym.spaces.Box(
                low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
            )
            self._init_flag = True
        else:
            assert 'env_id' in self._cfg
            self._init_flag = False
            self._observation_space = None
            self._action_space = None
            self._reward_space = None
        self._replay_path = None

    def reset(self) -> np.ndarray:
        """
        Overview:
            Resets the state of the environment. If the environment is not initialized, it will be created first.
        Returns:
            - obs (:obj:`np.ndarray`): The new observation after reset.
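        Examples:
            >>> # A hedged sketch; assumes the ``CartPole-v0`` wrapper from ``__init__``.
            >>> obs = env.reset()
            >>> assert isinstance(obs, np.ndarray)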
        """
        if not self._init_flag:
            self._env = gym.make(self._cfg.env_id)
            self._wrap_env(self._caller)
            self._observation_space = self._env.observation_space
            self._action_space = self._env.action_space
            self._reward_space = gym.spaces.Box(
                low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
            )
            self._init_flag = True
        if self._replay_path is not None:
            self._env = gym.wrappers.RecordVideo(
                self._env,
                video_folder=self._replay_path,
                episode_trigger=lambda episode_id: True,
                name_prefix='rl-video-{}'.format(id(self))
            )
            self._replay_path = None
        if isinstance(self._env, gym.Env):
            if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
                np_seed = 100 * np.random.randint(1, 1000)
                if self._seed_api:
                    self._env.seed(self._seed + np_seed)
                self._action_space.seed(self._seed + np_seed)
            elif hasattr(self, '_seed'):
                if self._seed_api:
                    self._env.seed(self._seed)
                self._action_space.seed(self._seed)
            obs = self._env.reset()
        elif isinstance(self._env, gymnasium.Env):
            if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
                np_seed = 100 * np.random.randint(1, 1000)
                self._action_space.seed(self._seed + np_seed)
                obs = self._env.reset(seed=self._seed + np_seed)
            elif hasattr(self, '_seed'):
                self._action_space.seed(self._seed)
                obs = self._env.reset(seed=self._seed)
            else:
                obs = self._env.reset()
        else:
            raise RuntimeError("unsupported env type: {}".format(type(self._env)))
        if self.observation_space.dtype == np.float32:
            obs = to_ndarray(obs, dtype=np.float32)
        else:
            obs = to_ndarray(obs)
        return obs

    def close(self) -> None:
        """
        Overview:
            Clean up the environment by closing and deleting it.
            This method should be called when the environment is no longer needed.
            Failing to call this method can lead to memory leaks.
        """
        try:
            self._env.close()
            del self._env
        except Exception:
            # The env may be uninitialized or already closed; ignore either case.
            pass

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        """
        Overview:
            Set the seed for the environment.
        Arguments:
            - seed (:obj:`int`): The seed to set.
            - dynamic_seed (:obj:`bool`): Whether to use a dynamic seed (a new derived seed on each reset), \
                default is True.
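        Examples:
            >>> # Sketch: seed before ``reset``; with dynamic_seed=True, every episode
            >>> # additionally mixes in a fresh ``np.random`` draw (see ``reset``).
            >>> env.seed(314, dynamic_seed=False)
            >>> obs = env.reset()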
        """
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def step(self, action: Union[np.int64, np.ndarray]) -> BaseEnvTimestep:
        """
        Overview:
            Execute the given action in the environment, and return the timestep (observation, reward, done, info).
        Arguments:
            - action (:obj:`Union[np.int64, np.ndarray]`): The action to execute in the environment.
        Returns:
            - timestep (:obj:`BaseEnvTimestep`): The timestep after the action execution.
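        Examples:
            >>> # A minimal sketch; assumes the wrapper from the ``__init__`` example.
            >>> timestep = env.step(env.random_action())
            >>> assert timestep.rew.shape == (1, )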
        """
        action = self._judge_action_type(action)
        if self._cfg.act_scale:
            action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high)
        obs, rew, done, info = self._env.step(action)
        if self._cfg.rew_clip:
            rew = max(-10, rew)
        rew = np.float32(rew)
        if self.observation_space.dtype == np.float32:
            obs = to_ndarray(obs, dtype=np.float32)
        else:
            obs = to_ndarray(obs)
        rew = to_ndarray([rew], np.float32)
        return BaseEnvTimestep(obs, rew, done, info)

    def _judge_action_type(self, action: Union[np.ndarray, dict]) -> Union[np.ndarray, dict]:
        """
        Overview:
            Ensure the action taken by the agent is of the correct type.
            This method is used to standardize different action types to a common format.
        Arguments:
            - action (:obj:`Union[np.ndarray, dict]`): The action taken by the agent.
        Returns:
            - action (:obj:`Union[np.ndarray, dict]`): The formatted action.
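        Examples:
            >>> # Sketch: scalar-shaped int64 arrays are unwrapped to plain Python ints.
            >>> env._judge_action_type(np.array([2], dtype=np.int64))
            2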
        """
        if isinstance(action, int):
            return action
        elif isinstance(action, np.int64):
            return int(action)
        elif isinstance(action, np.ndarray):
            if action.shape == ():
                action = action.item()
            elif action.shape == (1, ) and action.dtype == np.int64:
                action = action.item()
            return action
        elif isinstance(action, dict):
            for k, v in action.items():
                action[k] = self._judge_action_type(v)
            return action
        elif isinstance(action, tnp.ndarray):
            return self._judge_action_type(action.json())
        else:
            raise TypeError(
                '`action` should be either int/np.ndarray or dict of int/np.ndarray, but got {}: {}'.format(
                    type(action), action
                )
            )

    def random_action(self) -> np.ndarray:
        """
        Overview:
            Return a random action from the action space of the environment.
        Returns:
            - action (:obj:`np.ndarray`): The random action.
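        Examples:
            >>> # Sketch: for a discrete space the sampled int is wrapped into an array.
            >>> action = env.random_action()
            >>> assert isinstance(action, np.ndarray)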
        """
        random_action = self.action_space.sample()
        if isinstance(random_action, np.ndarray):
            pass
        elif isinstance(random_action, int):
            random_action = to_ndarray([random_action], dtype=np.int64)
        elif isinstance(random_action, dict):
            random_action = to_ndarray(random_action)
        else:
            raise TypeError(
                '`random_action` should be either int/np.ndarray or dict of int/np.ndarray, but got {}: {}'.format(
                    type(random_action), random_action
                )
            )
        return random_action

    def _wrap_env(self, caller: str = 'collector') -> None:
        """
        Overview:
            Wrap the environment according to the configuration.
        Arguments:
            - caller (:obj:`str`): The caller of the environment, either ``collector`` or ``evaluator``. \
                Different callers may need different wrappers. Default is 'collector'.
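        Examples:
            >>> # Hypothetical sketch: ``cfg.env_wrapper`` may be the string 'default'
            >>> # (expanded via ``get_default_wrappers``) or an explicit list mixing
            >>> # wrapper config dicts and callables, e.g. (names are assumptions):
            >>> # cfg.env_wrapper = [EasyDict(type='clip_reward_wrapper'), MyWrapper]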
        """
        wrapper_cfgs = self._cfg.env_wrapper
        if isinstance(wrapper_cfgs, str):
            wrapper_cfgs = get_default_wrappers(wrapper_cfgs, self._cfg.env_id, caller)
        self._wrapper_cfgs = wrapper_cfgs
        for wrapper in self._wrapper_cfgs:
            # A config dict is built through the wrapper registry; a callable is applied directly.
            if isinstance(wrapper, Dict):
                self._env = create_env_wrapper(self._env, wrapper)
            else:
                self._env = wrapper(self._env)

    def __repr__(self) -> str:
        """
        Overview:
            Return the string representation of the instance.
        Returns:
            - str (:obj:`str`): The string representation of the instance.
        """
        return "DI-engine Env({}), generated by DingEnvWrapper".format(self._cfg.env_id)

    @staticmethod
    def create_collector_env_cfg(cfg: dict) -> List[dict]:
        """
        Overview:
            Create a list of environment configurations for collectors based on the input configuration.
        Arguments:
            - cfg (:obj:`dict`): The input configuration dictionary.
        Returns:
            - env_cfgs (:obj:`List[dict]`): The list of environment configurations for collectors.
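        Examples:
            >>> # A hedged sketch with 4 parallel collector envs:
            >>> cfg = EasyDict(env_id='CartPole-v0', collector_env_num=4)
            >>> env_cfgs = DingEnvWrapper.create_collector_env_cfg(cfg)
            >>> assert len(env_cfgs) == 4 and env_cfgs[0].is_train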
        """
        collector_env_num = cfg.pop('collector_env_num')
        cfg = copy.deepcopy(cfg)
        cfg.is_train = True
        return [cfg for _ in range(collector_env_num)]

    @staticmethod
    def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
        """
        Overview:
            Create a list of environment configurations for evaluators based on the input configuration.
        Arguments:
            - cfg (:obj:`dict`): The input configuration dictionary.
        Returns:
            - env_cfgs (:obj:`List[dict]`): The list of environment configurations for evaluators.
        """
        evaluator_env_num = cfg.pop('evaluator_env_num')
        cfg = copy.deepcopy(cfg)
        cfg.is_train = False
        return [cfg for _ in range(evaluator_env_num)]

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        """
        Overview:
            Enable the save replay functionality. The replay will be saved at the specified path.
        Arguments:
            - replay_path (:obj:`Optional[str]`): The path to save the replay; if None, defaults to './video'.
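        Examples:
            >>> # Sketch: call before ``reset``; recording starts on the next reset.
            >>> env.enable_save_replay('./video')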
        """
        if replay_path is None:
            replay_path = './video'
        self._replay_path = replay_path

    @property
    def observation_space(self) -> gym.spaces.Space:
        """
        Overview:
            Return the observation space of the wrapped environment.
            The observation space represents the range and shape of possible observations
            that the environment can provide to the agent.
        Note:
            If the data type of the observation space is float64, it is converted to float32
            for better compatibility with most machine learning libraries.
        Returns:
            - observation_space (:obj:`gym.spaces.Space`): The observation space of the environment.
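        Examples:
            >>> # Sketch: float64 spaces are reported (and stored) as float32.
            >>> assert env.observation_space.dtype != np.float64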
        """
        if self._observation_space.dtype == np.float64:
            self._observation_space.dtype = np.float32
        return self._observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """
        Overview:
            Return the action space of the wrapped environment.
            The action space represents the range and shape of possible actions
            that the agent can take in the environment.
        Returns:
            - action_space (:obj:`gym.spaces.Space`): The action space of the environment.
        """
        return self._action_space

    @property
    def reward_space(self) -> gym.spaces.Space:
        """
        Overview:
            Return the reward space of the wrapped environment.
            The reward space represents the range and shape of possible rewards
            that the agent can receive as a result of its actions.
        Returns:
            - reward_space (:obj:`gym.spaces.Space`): The reward space of the environment.
        """
        return self._reward_space

    def clone(self, caller: str = 'collector') -> BaseEnv:
        """
        Overview:
            Clone the current environment wrapper, creating a new environment with the same settings.
        Arguments:
            - caller (:obj:`str`): A string representing the caller of this method, either ``collector`` or \
                ``evaluator``. Different callers may need different wrappers. Default is 'collector'.
        Returns:
            - env (:obj:`DingEnvWrapper`): A new instance of the environment with the same settings.
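        Examples:
            >>> # Sketch: deep-copy the raw env when it is picklable, then rewrap it.
            >>> eval_env = env.clone(caller='evaluator')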
        """
        try:
            spec = copy.deepcopy(self._raw_env.spec)
            raw_env = CloudPickleWrapper(self._raw_env)
            raw_env = copy.deepcopy(raw_env).data
            setattr(raw_env, 'spec', spec)
        except Exception:
            # If the raw env cannot be deep-copied (e.g. it is not picklable), fall back to sharing it.
            raw_env = self._raw_env
        return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller)