from typing import List, Optional, Union, Dict
from easydict import EasyDict
import gym
import gymnasium
import copy
import numpy as np
import treetensor.numpy as tnp
from ding.envs.common.common_function import affine_transform
from ding.envs.env_wrappers import create_env_wrapper
from ding.torch_utils import to_ndarray
from ding.utils import CloudPickleWrapper
from .base_env import BaseEnv, BaseEnvTimestep
from .default_wrapper import get_default_wrappers
class DingEnvWrapper(BaseEnv):
"""
Overview:
This is a wrapper for the BaseEnv class, used to provide a consistent environment interface.
Interfaces:
__init__, reset, step, close, seed, random_action, _wrap_env, __repr__, create_collector_env_cfg,
create_evaluator_env_cfg, enable_save_replay, observation_space, action_space, reward_space, clone
"""
def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True, caller: str = 'collector') -> None:
"""
Overview:
Initialize the DingEnvWrapper. Either an environment instance or a config used to create one should be \
passed in. For the former, pass the instance directly through ``env``; a pre-built instance does not \
support the subprocess environment manager, so this form is usually used for simple environments. For \
the latter, the ``cfg`` dict must contain ``env_id``, and the environment is created lazily at the \
first ``reset``.
Arguments:
- env (:obj:`gym.Env`): An environment instance to be wrapped.
- cfg (:obj:`dict`): The configuration dictionary to create an environment instance.
- seed_api (:obj:`bool`): Whether to use seed API. Defaults to True.
- caller (:obj:`str`): A string representing the caller of this method, either ``collector`` or \
``evaluator``. Different callers may need different wrappers. Default is 'collector'.
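Examples:
The following is an illustrative sketch; ``CartPole-v0`` is only a placeholder env id, not a requirement of this class.
>>> # Pass an existing environment instance directly:
>>> env = DingEnvWrapper(gym.make('CartPole-v0'))
>>> # Or pass a config containing ``env_id``; the env is then created lazily at the first reset:
>>> env = DingEnvWrapper(cfg={'env_id': 'CartPole-v0'})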
"""
self._env = None
self._raw_env = env
self._cfg = cfg
self._seed_api = seed_api  # some envs may disable the `env.seed` API
self._caller = caller
if self._cfg is None:
self._cfg = {}
self._cfg = EasyDict(self._cfg)
if 'act_scale' not in self._cfg:
self._cfg.act_scale = False
if 'rew_clip' not in self._cfg:
self._cfg.rew_clip = False
if 'env_wrapper' not in self._cfg:
self._cfg.env_wrapper = 'default'
if 'env_id' not in self._cfg:
self._cfg.env_id = None
if env is not None:
self._env = env
self._wrap_env(caller)
self._observation_space = self._env.observation_space
self._action_space = self._env.action_space
self._action_space.seed(0) # default seed
self._reward_space = gym.spaces.Box(
low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
)
self._init_flag = True
else:
assert 'env_id' in self._cfg
self._init_flag = False
self._observation_space = None
self._action_space = None
self._reward_space = None
# The video is saved only if the user specifies replay_path, so its initial value is None.
self._replay_path = None
# override
def reset(self) -> np.ndarray:
"""
Overview:
Resets the state of the environment. If the environment is not initialized, it will be created first.
Returns:
- obs (:obj:`np.ndarray`): The new observation after reset.
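Examples:
An illustrative sketch, assuming the wrapper was built around a standard gym env such as ``CartPole-v0``:
>>> obs = env.reset()
>>> assert isinstance(obs, np.ndarray)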
"""
if not self._init_flag:
self._env = gym.make(self._cfg.env_id)
self._wrap_env(self._caller)
self._observation_space = self._env.observation_space
self._action_space = self._env.action_space
self._reward_space = gym.spaces.Box(
low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
)
self._init_flag = True
if self._replay_path is not None:
self._env = gym.wrappers.RecordVideo(
self._env,
video_folder=self._replay_path,
episode_trigger=lambda episode_id: True,
name_prefix='rl-video-{}'.format(id(self))
)
self._replay_path = None
if isinstance(self._env, gym.Env):
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
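# Derive a per-reset random offset so that a dynamically seeded env uses a different (but reproducible) seed each episode.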
np_seed = 100 * np.random.randint(1, 1000)
if self._seed_api:
self._env.seed(self._seed + np_seed)
self._action_space.seed(self._seed + np_seed)
elif hasattr(self, '_seed'):
if self._seed_api:
self._env.seed(self._seed)
self._action_space.seed(self._seed)
obs = self._env.reset()
elif isinstance(self._env, gymnasium.Env):
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
self._action_space.seed(self._seed + np_seed)
obs = self._env.reset(seed=self._seed + np_seed)
elif hasattr(self, '_seed'):
self._action_space.seed(self._seed)
obs = self._env.reset(seed=self._seed)
else:
obs = self._env.reset()
else:
raise RuntimeError("unsupported env type: {}".format(type(self._env)))
if self.observation_space.dtype == np.float32:
obs = to_ndarray(obs, dtype=np.float32)
else:
obs = to_ndarray(obs)
return obs
# override
def close(self) -> None:
"""
Overview:
Clean up the environment by closing and deleting it.
This method should be called when the environment is no longer needed.
Failing to call this method can lead to memory leaks.
"""
try:
self._env.close()
del self._env
except: # noqa
pass
# override
def seed(self, seed: int, dynamic_seed: bool = True) -> None:
"""
Overview:
Set the seed for the environment.
Arguments:
- seed (:obj:`int`): The seed to set.
- dynamic_seed (:obj:`bool`): Whether to use dynamic seed, default is True.
"""
self._seed = seed
self._dynamic_seed = dynamic_seed
np.random.seed(self._seed)
# override
def step(self, action: Union[np.int64, np.ndarray]) -> BaseEnvTimestep:
"""
Overview:
Execute the given action in the environment, and return the timestep (observation, reward, done, info).
Arguments:
- action (:obj:`Union[np.int64, np.ndarray]`): The action to execute in the environment.
Returns:
- timestep (:obj:`BaseEnvTimestep`): The timestep after the action execution.
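Examples:
An illustrative sketch of one interaction step (the action shape depends on the concrete env):
>>> action = env.random_action()
>>> obs, rew, done, info = env.step(action)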
"""
action = self._judge_action_type(action)
if self._cfg.act_scale:
action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high)
obs, rew, done, info = self._env.step(action)
if self._cfg.rew_clip:
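# Clip the reward from below at -10; positive rewards are left untouched.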
rew = max(-10, rew)
rew = np.float32(rew)
if self.observation_space.dtype == np.float32:
obs = to_ndarray(obs, dtype=np.float32)
else:
obs = to_ndarray(obs)
rew = to_ndarray([rew], np.float32)
return BaseEnvTimestep(obs, rew, done, info)
def _judge_action_type(self, action: Union[np.ndarray, dict]) -> Union[np.ndarray, dict]:
"""
Overview:
Ensure the action taken by the agent is of the correct type.
This method is used to standardize different action types to a common format.
Arguments:
- action (:obj:`Union[np.ndarray, dict]`): The action taken by the agent.
Returns:
- action (:obj:`Union[np.ndarray, dict]`): The formatted action.
"""
if isinstance(action, int):
return action
elif isinstance(action, np.int64):
return int(action)
elif isinstance(action, np.ndarray):
if action.shape == ():
action = action.item()
elif action.shape == (1, ) and action.dtype == np.int64:
action = action.item()
return action
elif isinstance(action, dict):
for k, v in action.items():
action[k] = self._judge_action_type(v)
return action
elif isinstance(action, tnp.ndarray):
return self._judge_action_type(action.json())
else:
raise TypeError(
'`action` should be either int/np.ndarray or dict of int/np.ndarray, but got {}: {}'.format(
type(action), action
)
)
def random_action(self) -> np.ndarray:
"""
Overview:
Return a random action from the action space of the environment.
Returns:
- action (:obj:`np.ndarray`): The random action.
"""
random_action = self.action_space.sample()
if isinstance(random_action, np.ndarray):
pass
elif isinstance(random_action, int):
random_action = to_ndarray([random_action], dtype=np.int64)
elif isinstance(random_action, dict):
random_action = to_ndarray(random_action)
else:
raise TypeError(
'`random_action` should be either int/np.ndarray or dict of int/np.ndarray, but got {}: {}'.format(
type(random_action), random_action
)
)
return random_action
def _wrap_env(self, caller: str = 'collector') -> None:
"""
Overview:
Wrap the environment according to the configuration.
Arguments:
- caller (:obj:`str`): The caller of the environment, either ``collector`` or ``evaluator``. \
Different callers may need different wrappers. Default is 'collector'.
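Examples:
An illustrative sketch of a custom wrapper list; a plain callable such as a lambda is accepted alongside dict-style wrapper configs:
>>> cfg = {'env_id': 'CartPole-v0', 'env_wrapper': [lambda env: gym.wrappers.TimeLimit(env, max_episode_steps=200)]}
>>> env = DingEnvWrapper(cfg=cfg)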
"""
# wrapper_cfgs: Union[str, List]
wrapper_cfgs = self._cfg.env_wrapper
if isinstance(wrapper_cfgs, str):
wrapper_cfgs = get_default_wrappers(wrapper_cfgs, self._cfg.env_id, caller)
# self._wrapper_cfgs: List[Union[Callable, Dict]]
self._wrapper_cfgs = wrapper_cfgs
for wrapper in self._wrapper_cfgs:
# wrapper: Union[Callable, Dict]
if isinstance(wrapper, Dict):
self._env = create_env_wrapper(self._env, wrapper)
else: # Callable, such as lambda anonymous function
self._env = wrapper(self._env)
def __repr__(self) -> str:
"""
Overview:
Return the string representation of the instance.
Returns:
- str (:obj:`str`): The string representation of the instance.
"""
return "DI-engine Env({}), generated by DingEnvWrapper".format(self._cfg.env_id)
@staticmethod
def create_collector_env_cfg(cfg: dict) -> List[dict]:
"""
Overview:
Create a list of environment configurations for collectors based on the input configuration.
Arguments:
- cfg (:obj:`dict`): The input configuration dictionary.
Returns:
- env_cfgs (:obj:`List[dict]`): The list of environment configurations for collectors.
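Examples:
An illustrative sketch; ``collector_env_num`` must be present in the input config, and attribute-style access (e.g. ``EasyDict``) is assumed:
>>> cfg = EasyDict(env_id='CartPole-v0', collector_env_num=4)
>>> env_cfgs = DingEnvWrapper.create_collector_env_cfg(cfg)
>>> assert len(env_cfgs) == 4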
"""
actor_env_num = cfg.pop('collector_env_num')
cfg = copy.deepcopy(cfg)
cfg.is_train = True
return [cfg for _ in range(actor_env_num)]
@staticmethod
def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
"""
Overview:
Create a list of environment configurations for evaluators based on the input configuration.
Arguments:
- cfg (:obj:`dict`): The input configuration dictionary.
Returns:
- env_cfgs (:obj:`List[dict]`): The list of environment configurations for evaluators.
"""
evaluator_env_num = cfg.pop('evaluator_env_num')
cfg = copy.deepcopy(cfg)
cfg.is_train = False
return [cfg for _ in range(evaluator_env_num)]
def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
"""
Overview:
Enable the save replay functionality. The replay will be saved at the specified path.
Arguments:
- replay_path (:obj:`Optional[str]`): The path to save the replay, default is None.
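Examples:
An illustrative sketch; the recorder (``gym.wrappers.RecordVideo``) is attached at the next ``reset``:
>>> env.enable_save_replay('./video')
>>> obs = env.reset()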
"""
if replay_path is None:
replay_path = './video'
self._replay_path = replay_path
@property
def observation_space(self) -> gym.spaces.Space:
"""
Overview:
Return the observation space of the wrapped environment.
The observation space represents the range and shape of possible observations
that the environment can provide to the agent.
Note:
If the data type of the observation space is float64, it's converted to float32
for better compatibility with most machine learning libraries.
Returns:
- observation_space (gym.spaces.Space): The observation space of the environment.
"""
if self._observation_space.dtype == np.float64:
self._observation_space.dtype = np.float32
return self._observation_space
@property
def action_space(self) -> gym.spaces.Space:
"""
Overview:
Return the action space of the wrapped environment.
The action space represents the range and shape of possible actions
that the agent can take in the environment.
Returns:
- action_space (gym.spaces.Space): The action space of the environment.
"""
return self._action_space
@property
def reward_space(self) -> gym.spaces.Space:
"""
Overview:
Return the reward space of the wrapped environment.
The reward space represents the range and shape of possible rewards
that the agent can receive as a result of its actions.
Returns:
- reward_space (gym.spaces.Space): The reward space of the environment.
"""
return self._reward_space
def clone(self, caller: str = 'collector') -> BaseEnv:
"""
Overview:
Clone the current environment wrapper, creating a new environment with the same settings.
Arguments:
- caller (:obj:`str`): A string representing the caller of this method, either ``collector`` or ``evaluator``. \
Different callers may need different wrappers. Default is 'collector'.
Returns:
- DingEnvWrapper: A new instance of the environment with the same settings.
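Examples:
An illustrative sketch; the clone reuses the same config but wraps an independent copy of the raw env when it is deep-copyable:
>>> eval_env = env.clone(caller='evaluator')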
"""
try:
spec = copy.deepcopy(self._raw_env.spec)
raw_env = CloudPickleWrapper(self._raw_env)
raw_env = copy.deepcopy(raw_env).data
raw_env.__setattr__('spec', spec)
except Exception:
raw_env = self._raw_env
return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller)