# gomoku / DI-engine / dizoo / metadrive / env / drive_wrapper.py
# Last commit by zjowowen: 079c32c ("init space")
from typing import Any, Dict, Optional
from easydict import EasyDict
import matplotlib.pyplot as plt
import gym
import copy
import numpy as np
from ding.envs.env.base_env import BaseEnvTimestep
from ding.torch_utils.data_helper import to_ndarray
from ding.utils.default_helper import deep_merge_dicts
from dizoo.metadrive.env.drive_utils import BaseDriveEnv
def draw_multi_channels_top_down_observation(obs, show_time=0.5):
    """
    Overview:
        Show every channel of a 5-channel top-down observation side by side in one matplotlib figure.
        A canvas timer closes the figure after ``show_time`` seconds.
    Arguments:
        - obs (np.ndarray): Observation whose LAST axis holds the 5 channels (each channel is drawn with
          ``obs[..., i]``), e.g. a channel-last image of shape ``(H, W, 5)``.
        - show_time (float): Number of seconds before the timer closes the figure.
    """
    num_channels = obs.shape[-1]
    assert num_channels == 5, "expected a 5-channel observation, got {} channels".format(num_channels)
    channel_names = [
        "Road and navigation", "Ego now and previous pos", "Neighbor at step t", "Neighbor at step t-1",
        "Neighbor at step t-2"
    ]
    fig, axs = plt.subplots(1, num_channels, figsize=(15, 4), dpi=80)
    # Close *this* figure when the timer fires. A bare plt.close() would close whichever figure
    # happens to be current at that moment, which need not be this one.
    timer = fig.canvas.new_timer(interval=int(show_time * 1000))  # interval is in milliseconds
    timer.add_callback(plt.close, fig)
    for i, (ax, name) in enumerate(zip(axs, channel_names)):
        ax.imshow(obs[..., i], cmap="bone")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(name)
    fig.suptitle("Multi-channels Top-down Observation")
    timer.start()
    plt.show()
    plt.close(fig)
class DriveEnvWrapper(gym.Wrapper):
    """
    Overview:
        Environment wrapper to make ``gym.Env`` align with DI-engine definitions, so as to use utilities in DI-engine.
        It overrides ``step`` and ``reset`` (converting observations, rewards and the step result into DI-engine
        formats, and adding ``eval_episode_return`` to ``info`` at episode end), while others are straightly delivered.
    Arguments:
        - env (BaseDriveEnv): The environment to be wrapped.
        - cfg (Dict): Config dict. ``None`` means ``default_config()``; a dict without ``cfg_type`` is merged on top
          of ``default_config()``; a dict that already has ``cfg_type`` is used as-is.
    """
    config = dict()

    def __init__(self, env: BaseDriveEnv, cfg: Optional[Dict] = None, **kwargs) -> None:
        if cfg is None:
            self._cfg = self.__class__.default_config()
        elif 'cfg_type' not in cfg:
            # Partial user config: fill in missing keys from the class defaults.
            self._cfg = deep_merge_dicts(self.__class__.default_config(), cfg)
        else:
            self._cfg = cfg
        self.env = env
        if not hasattr(self.env, 'reward_space'):
            self.reward_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(1, ))
        # Only an explicit ``True`` enables the per-step bird-view plot.
        self.show_bird_view = self._cfg.get('show_bird_view', False) is True
        self.action_space = self.env.action_space

    @staticmethod
    def _format_obs(obs: Any) -> Any:
        """
        Overview:
            Convert a raw env observation into the format returned by ``reset`` and ``step``.
            Values become ``np.float32`` arrays. A 3-dim array has its last axis moved to the front
            (channel-last -> channel-first). A dict is rebuilt from only its ``vehicle_state`` and
            ``birdview`` entries, with ``birdview`` transposed the same way; other keys are dropped.
        Arguments:
            - obs (Any): Raw observation from the wrapped env.
        Returns:
            - Any: The converted observation.
        """
        obs = to_ndarray(obs, dtype=np.float32)
        if isinstance(obs, np.ndarray) and obs.ndim == 3:
            return obs.transpose((2, 0, 1))
        if isinstance(obs, dict):
            return {'vehicle_state': obs['vehicle_state'], 'birdview': obs['birdview'].transpose((2, 0, 1))}
        return obs

    def reset(self, *args, **kwargs) -> Any:
        """
        Overview:
            Wrapper of ``reset`` method in env. The observation is converted by ``_format_obs`` and the
            per-episode bookkeeping (return accumulator, arrival flag) is cleared.
        Returns:
            - Any: Observations from environment
        """
        obs = self._format_obs(self.env.reset(*args, **kwargs))
        self._eval_episode_return = 0.0
        self._arrive_dest = False
        return obs

    def step(self, action: Any = None) -> BaseEnvTimestep:
        """
        Overview:
            Wrapper of ``step`` method in env. This aims to convert the returns of ``gym.Env`` step method into
            that of ``ding.envs.BaseEnv``, from ``(obs, reward, done, info)`` tuple to a ``BaseEnvTimestep``
            namedtuple defined in DI-engine. Actions, observations and reward are converted into ``np.ndarray``;
            when the episode is done, the accumulated return is stored in ``info['eval_episode_return']``.
        Arguments:
            - action (Any, optional): Actions sent to env. Defaults to None.
        Returns:
            - BaseEnvTimestep: DI-engine format of env step returns.
        """
        action = to_ndarray(action)
        obs, rew, done, info = self.env.step(action)
        if self.show_bird_view:
            # Draw the raw (still channel-last) image; dict observations carry it under 'birdview'.
            bird_view = obs['birdview'] if isinstance(obs, dict) else obs
            draw_multi_channels_top_down_observation(bird_view, show_time=0.5)
        # Accumulate the raw scalar reward before it is wrapped into an array.
        self._eval_episode_return += rew
        obs = self._format_obs(obs)
        rew = to_ndarray([rew], dtype=np.float32)
        if done:
            info['eval_episode_return'] = self._eval_episode_return
        return BaseEnvTimestep(obs, rew, done, info)

    @property
    def observation_space(self):
        """Fixed channel-first observation space of shape ``(5, 84, 84)`` with values in ``[0, 1]``."""
        # NOTE(review): hard-coded; it does not describe dict observations - confirm intended.
        return gym.spaces.Box(0, 1, shape=(5, 84, 84), dtype=np.float32)

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        """
        Overview:
            Record the seed settings and seed numpy's global RNG. The wrapped env is not seeded here.
        Arguments:
            - seed (int): Seed value.
            - dynamic_seed (bool): Stored as ``self._dynamic_seed``; not used inside this class.
        """
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        """
        Overview:
            Wrap the env in ``gym.wrappers.Monitor`` so that every episode is recorded.
        Arguments:
            - replay_path (Optional[str]): Output directory, ``'./video'`` by default.
        """
        if replay_path is None:
            replay_path = './video'
        self._replay_path = replay_path
        self.env = gym.wrappers.Monitor(self.env, self._replay_path, video_callable=lambda episode_id: True, force=True)

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """Return a deep copy of the class ``config`` as an ``EasyDict`` tagged with ``cfg_type``."""
        cfg = EasyDict(cls.config)
        cfg.cfg_type = cls.__name__ + 'Config'
        return copy.deepcopy(cfg)

    def __repr__(self) -> str:
        return repr(self.env)

    def render(self, *args, **kwargs):
        """Forward to the wrapped env's ``render`` and return its result (e.g. an image in array modes)."""
        return self.env.render(*args, **kwargs)

    def clone(self, caller: str):
        """Create a new ``DriveEnvWrapper`` around ``self.env.clone(caller)`` with a deep copy of this config."""
        cfg = copy.deepcopy(self._cfg)
        return DriveEnvWrapper(self.env.clone(caller), cfg)