# dizoo/metadrive/env/drive_wrapper.py
from typing import Any, Dict, Optional
from easydict import EasyDict
import matplotlib.pyplot as plt
import gym
import copy
import numpy as np
from ding.envs.env.base_env import BaseEnvTimestep
from ding.torch_utils.data_helper import to_ndarray
from ding.utils.default_helper import deep_merge_dicts
from dizoo.metadrive.env.drive_utils import BaseDriveEnv


def draw_multi_channels_top_down_observation(obs, show_time=0.5):
num_channels = obs.shape[-1]
assert num_channels == 5
channel_names = [
"Road and navigation", "Ego now and previous pos", "Neighbor at step t", "Neighbor at step t-1",
"Neighbor at step t-2"
]
fig, axs = plt.subplots(1, num_channels, figsize=(15, 4), dpi=80)

    def close_event():
plt.close()
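
    # Close the figure automatically after ``show_time`` seconds so ``plt.show()`` does not block indefinitely.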
timer = fig.canvas.new_timer(interval=show_time * 1000)
timer.add_callback(close_event)
for i, name in enumerate(channel_names):
ax = axs[i]
ax.imshow(obs[..., i], cmap="bone")
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(name)
fig.suptitle("Multi-channels Top-down Observation")
timer.start()
plt.show()
plt.close()


class DriveEnvWrapper(gym.Wrapper):
"""
Overview:
        Environment wrapper that makes a ``gym.Env`` conform to DI-engine's environment interface, so that DI-engine
        utilities can be used. It overrides the ``step`` and ``reset`` methods of ``gym.Env``, while other methods are
        passed through unchanged.
Arguments:
- env (BaseDriveEnv): The environment to be wrapped.
- cfg (Dict): Config dict.
"""
config = dict()

    def __init__(self, env: BaseDriveEnv, cfg: Optional[Dict] = None, **kwargs) -> None:
if cfg is None:
self._cfg = self.__class__.default_config()
elif 'cfg_type' not in cfg:
self._cfg = self.__class__.default_config()
self._cfg = deep_merge_dicts(self._cfg, cfg)
else:
self._cfg = cfg
self.env = env
if not hasattr(self.env, 'reward_space'):
self.reward_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(1, ))
if 'show_bird_view' in self._cfg and self._cfg['show_bird_view'] is True:
self.show_bird_view = True
else:
self.show_bird_view = False
self.action_space = self.env.action_space

    def reset(self, *args, **kwargs) -> Any:
"""
Overview:
            Wrapper of the ``reset`` method of the env. The observation is converted to ``np.ndarray`` and the
            episode return accumulator is reset to zero.
Returns:
            - Any: Observation from the environment.
"""
obs = self.env.reset(*args, **kwargs)
obs = to_ndarray(obs, dtype=np.float32)
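        # Transpose image-like observations from HWC to CHW (channel-first) layout.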
if isinstance(obs, np.ndarray) and len(obs.shape) == 3:
obs = obs.transpose((2, 0, 1))
elif isinstance(obs, dict):
vehicle_state = obs['vehicle_state']
birdview = obs['birdview'].transpose((2, 0, 1))
obs = {'vehicle_state': vehicle_state, 'birdview': birdview}
self._eval_episode_return = 0.0
self._arrive_dest = False
return obs

    def step(self, action: Any = None) -> BaseEnvTimestep:
"""
Overview:
            Wrapper of the ``step`` method of the env. It converts the return value of ``gym.Env.step`` from an
            ``(obs, reward, done, info)`` tuple into the ``BaseEnvTimestep`` namedtuple defined in DI-engine. The
            action, observation and reward are also converted into ``np.ndarray``, and the episode return is
            accumulated for evaluation.
Arguments:
- action (Any, optional): Actions sent to env. Defaults to None.
Returns:
- BaseEnvTimestep: DI-engine format of env step returns.
"""
action = to_ndarray(action)
obs, rew, done, info = self.env.step(action)
if self.show_bird_view:
draw_multi_channels_top_down_observation(obs, show_time=0.5)
self._eval_episode_return += rew
obs = to_ndarray(obs, dtype=np.float32)
if isinstance(obs, np.ndarray) and len(obs.shape) == 3:
obs = obs.transpose((2, 0, 1))
elif isinstance(obs, dict):
vehicle_state = obs['vehicle_state']
birdview = obs['birdview'].transpose((2, 0, 1))
obs = {'vehicle_state': vehicle_state, 'birdview': birdview}
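        # Wrap the scalar reward in a list so it becomes a shape-(1, ) ``np.ndarray``.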
rew = to_ndarray([rew], dtype=np.float32)
if done:
info['eval_episode_return'] = self._eval_episode_return
return BaseEnvTimestep(obs, rew, done, info)

    @property
def observation_space(self):
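        # Fixed 5-channel, 84 x 84 top-down (bird-view) observation.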
return gym.spaces.Box(0, 1, shape=(5, 84, 84), dtype=np.float32)

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
self._seed = seed
self._dynamic_seed = dynamic_seed
np.random.seed(self._seed)

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
if replay_path is None:
replay_path = './video'
self._replay_path = replay_path
self.env = gym.wrappers.Monitor(self.env, self._replay_path, video_callable=lambda episode_id: True, force=True)

    @classmethod
def default_config(cls: type) -> EasyDict:
cfg = EasyDict(cls.config)
cfg.cfg_type = cls.__name__ + 'Config'
return copy.deepcopy(cfg)

    def __repr__(self) -> str:
return repr(self.env)

    def render(self):
self.env.render()

    def clone(self, caller: str):
cfg = copy.deepcopy(self._cfg)
return DriveEnvWrapper(self.env.clone(caller), cfg)
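

if __name__ == '__main__':
    # Minimal usage sketch, for illustration only. It assumes a concrete
    # ``BaseDriveEnv`` subclass is importable as below; the import path, the class
    # name ``MetaDrivePPOOriginEnv`` and its dict-config constructor are assumptions
    # that may need to be adapted to the actual env defined in this repo.
    from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv

    env = DriveEnvWrapper(MetaDrivePPOOriginEnv({'use_render': False}))
    obs = env.reset()
    timestep = env.step(env.action_space.sample())
    print(type(timestep.obs), timestep.rew, timestep.done)
    env.close()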