from typing import Any, Dict, Optional
from easydict import EasyDict
import matplotlib.pyplot as plt
import gym
import copy
import numpy as np
from ding.envs.env.base_env import BaseEnvTimestep
from ding.torch_utils.data_helper import to_ndarray
from ding.utils.default_helper import deep_merge_dicts
from dizoo.metadrive.env.drive_utils import BaseDriveEnv


def draw_multi_channels_top_down_observation(obs, show_time=0.5):
    """
    Overview:
        Plot every channel of a 5-channel top-down observation side by side in one
        matplotlib figure, and close that figure automatically after ``show_time`` seconds.
    Arguments:
        - obs: Array-like object with a ``shape`` attribute whose last axis holds exactly \
            5 channels; channel ``i`` is read as ``obs[..., i]``.
        - show_time (float): Time in seconds before the figure is closed by a timer. \
            Defaults to 0.5.
    """
    num_channels = obs.shape[-1]
    assert num_channels == 5
    channel_names = [
        "Road and navigation",
        "Ego now and previous pos",
        "Neighbor at step t",
        "Neighbor at step t-1",
        "Neighbor at step t-2",
    ]
    fig, axs = plt.subplots(1, num_channels, figsize=(15, 4), dpi=80)

    def close_event():
        # Close this figure explicitly: a bare ``plt.close()`` closes whichever figure is
        # current, which may be an unrelated one once this figure is already gone.
        plt.close(fig)

    # Matplotlib timer intervals are expressed in milliseconds.
    timer = fig.canvas.new_timer(interval=int(show_time * 1000))
    timer.add_callback(close_event)

    for ax, (i, name) in zip(axs, enumerate(channel_names)):
        ax.imshow(obs[..., i], cmap="bone")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(name)
    fig.suptitle("Multi-channels Top-down Observation")

    timer.start()
    plt.show()
    # No-op if the timer already closed the figure; never touches other figures.
    plt.close(fig)


class DriveEnvWrapper(gym.Wrapper):
    """
    Overview:
        Environment wrapper to make ``gym.Env`` align with DI-engine definitions, so as to use utilities in
        DI-engine. It changes ``step``, ``reset`` and ``info`` method of ``gym.Env``, while others are
        straightly delivered.

    Arguments:
        - env (BaseDriveEnv): The environment to be wrapped.
        - cfg (Dict): Config dict.
    """
    config = dict()

    def __init__(self, env: BaseDriveEnv, cfg: Optional[Dict] = None, **kwargs) -> None:
        """
        Overview:
            Resolve the wrapper config and attach the wrapped environment.
        Arguments:
            - env (BaseDriveEnv): The environment to be wrapped.
            - cfg (Optional[Dict]): User config. ``None`` selects the default config; a config \
                without a ``cfg_type`` key is merged over the default config; a config that \
                already has ``cfg_type`` is used as is.
            - kwargs: Accepted for interface compatibility and ignored.
        """
        if cfg is None:
            self._cfg = self.__class__.default_config()
        elif 'cfg_type' not in cfg:
            self._cfg = deep_merge_dicts(self.__class__.default_config(), cfg)
        else:
            self._cfg = cfg

        # NOTE(review): ``gym.Wrapper.__init__`` is deliberately not called here, as in the
        # original code -- presumably because ``observation_space`` is a read-only property
        # on this class. Confirm before changing.
        self.env = env
        if not hasattr(self.env, 'reward_space'):
            self.reward_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(1, ))
        self.show_bird_view = self._cfg.get('show_bird_view', False) is True
        self.action_space = self.env.action_space

        # Episode bookkeeping is initialised here so that ``step`` does not fail with a
        # missing attribute if it is called before the first ``reset``.
        self._eval_episode_return = 0.0
        self._arrive_dest = False

    @staticmethod
    def _format_obs(obs: Any) -> Any:
        """
        Overview:
            Convert a raw observation to float32 ``np.ndarray`` data and move the last axis of
            image-like data to the front (``(2, 0, 1)`` transpose).
        Arguments:
            - obs (Any): Raw observation; either array-like or a dict with the keys \
                ``'vehicle_state'`` and ``'birdview'``.
        Returns:
            - Any: The converted observation. 3-dimensional arrays are transposed, dicts are \
                rebuilt with a transposed ``'birdview'``, anything else is returned as converted.
        """
        obs = to_ndarray(obs, dtype=np.float32)
        if isinstance(obs, np.ndarray) and len(obs.shape) == 3:
            return obs.transpose((2, 0, 1))
        if isinstance(obs, dict):
            return {
                'vehicle_state': obs['vehicle_state'],
                'birdview': obs['birdview'].transpose((2, 0, 1)),
            }
        return obs

    def reset(self, *args, **kwargs) -> Any:
        """
        Overview:
            Wrapper of ``reset`` method in env. The observations are converted to ``np.ndarray``
            and the episode return accumulator is reset to zero.
        Returns:
            - Any: Observations from environment
        """
        obs = self._format_obs(self.env.reset(*args, **kwargs))
        self._eval_episode_return = 0.0
        self._arrive_dest = False
        return obs

    def step(self, action: Any = None) -> BaseEnvTimestep:
        """
        Overview:
            Wrapper of ``step`` method in env. This aims to convert the returns of ``gym.Env`` step method into
            that of ``ding.envs.BaseEnv``, from ``(obs, reward, done, info)`` tuple to a ``BaseEnvTimestep``
            namedtuple defined in DI-engine. It will also convert actions, observations and reward into
            ``np.ndarray``, and check legality if action contains control signal.

        Arguments:
            - action (Any, optional): Actions sent to env. Defaults to None.

        Returns:
            - BaseEnvTimestep: DI-engine format of env step returns.
        """
        action = to_ndarray(action)
        obs, rew, done, info = self.env.step(action)

        if self.show_bird_view:
            # Dict observations carry the image under 'birdview'; the drawing helper needs
            # the array itself, not the dict.
            birdview = obs['birdview'] if isinstance(obs, dict) else obs
            draw_multi_channels_top_down_observation(birdview, show_time=0.5)

        self._eval_episode_return += rew
        obs = self._format_obs(obs)
        rew = to_ndarray([rew], dtype=np.float32)
        if done:
            # Report the accumulated (unconverted) return of the finished episode.
            info['eval_episode_return'] = self._eval_episode_return
        return BaseEnvTimestep(obs, rew, done, info)

    @property
    def observation_space(self):
        """
        Overview:
            Fixed observation space reported by the wrapper: a float32 ``Box`` in ``[0, 1]``
            with shape ``(5, 84, 84)``.
        """
        return gym.spaces.Box(0, 1, shape=(5, 84, 84), dtype=np.float32)

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        """
        Overview:
            Store the seed settings on the wrapper and seed numpy's global random generator.
            The wrapped environment is not seeded by this method.
        Arguments:
            - seed (int): Seed value.
            - dynamic_seed (bool): Stored as ``_dynamic_seed``; not otherwise used here.
        """
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        """
        Overview:
            Wrap the inner environment in ``gym.wrappers.Monitor`` so that every episode is
            recorded as a video.
        Arguments:
            - replay_path (Optional[str]): Output directory. Defaults to ``'./video'``.
        """
        # NOTE(review): ``gym.wrappers.Monitor`` appears to exist only in older gym
        # releases -- confirm against the pinned gym version.
        if replay_path is None:
            replay_path = './video'
        self._replay_path = replay_path
        self.env = gym.wrappers.Monitor(self.env, self._replay_path, video_callable=lambda episode_id: True, force=True)

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """
        Overview:
            Build an independent copy of the class-level ``config`` as an ``EasyDict``, tagged
            with ``cfg_type`` set to ``<ClassName>Config``.
        Returns:
            - EasyDict: The default config.
        """
        cfg = EasyDict(cls.config)
        cfg.cfg_type = cls.__name__ + 'Config'
        return copy.deepcopy(cfg)

    def __repr__(self) -> str:
        return repr(self.env)

    def render(self, *args, **kwargs):
        """
        Overview:
            Render the wrapped environment, forwarding all arguments and returning whatever
            the wrapped ``render`` returns.
        """
        return self.env.render(*args, **kwargs)

    def clone(self, caller: str):
        """
        Overview:
            Create a new ``DriveEnvWrapper`` around ``self.env.clone(caller)`` with a deep copy
            of this wrapper's config.
        Arguments:
            - caller (str): Passed through to the wrapped environment's ``clone``.
        """
        cfg = copy.deepcopy(self._cfg)
        return DriveEnvWrapper(self.env.clone(caller), cfg)