from typing import Any, Union, List
import copy
import numpy as np
import gym
from ding.envs import BaseEnv, BaseEnvTimestep
from ding.envs.common.common_function import affine_transform
from ding.torch_utils import to_ndarray, to_list
from ding.utils import ENV_REGISTRY
from .mujoco_multi import MujocoMulti
@ENV_REGISTRY.register('mujoco_multi')
class MujocoEnv(BaseEnv):
def __init__(self, cfg: dict) -> None:
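        # store the config; the underlying MujocoMulti env is constructed lazily in reset()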
self._cfg = cfg
self._init_flag = False
    def reset(self) -> dict:
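        # if a dynamic seed is enabled, derive a fresh seed for every episode so that
        # collected episodes differ while remaining reproducible from the base seed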
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
self._cfg.seed = self._seed + np_seed
elif hasattr(self, '_seed'):
self._cfg.seed = self._seed
if not self._init_flag:
self._env = MujocoMulti(env_args=self._cfg)
self._init_flag = True
obs = self._env.reset()
self._eval_episode_return = 0.
        # TODO: for scenario='Ant-v2' and agent_conf="2x4d", self._env.get_env_info() returns
        # {'state_shape': 2, 'obs_shape': 54, ...}; 'state_shape' is wrong, it should be 111.
self.env_info = self._env.get_env_info()
        # note that self._env.observation_space[agent].shape equals the 'state_shape' reported above
self._num_agents = self.env_info['n_agents']
self._agents = [i for i in range(self._num_agents)]
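        # observations are exposed as a dict space: a per-agent local observation
        # ('agent_state') and a shared global state ('global_state'), both unbounded Boxes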
self._observation_space = gym.spaces.Dict(
{
'agent_state': gym.spaces.Box(
low=float("-inf"), high=float("inf"), shape=obs['agent_state'].shape, dtype=np.float32
),
'global_state': gym.spaces.Box(
low=float("-inf"), high=float("inf"), shape=obs['global_state'].shape, dtype=np.float32
),
}
)
self._action_space = gym.spaces.Dict({agent: self._env.action_space[agent] for agent in self._agents})
        single_agent_action_space = self._env.action_space[self._agents[0]]
        if isinstance(single_agent_action_space, gym.spaces.Box):
            self._action_dim = single_agent_action_space.shape
        elif isinstance(single_agent_action_space, gym.spaces.Discrete):
            self._action_dim = (single_agent_action_space.n, )
        else:
            raise Exception('Only support `Box` or `Discrete` action space for a single agent.')
self._reward_space = gym.spaces.Dict(
{
agent: gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32)
for agent in self._agents
}
)
return obs
def close(self) -> None:
if self._init_flag:
self._env.close()
self._init_flag = False
def seed(self, seed: int, dynamic_seed: bool = True) -> None:
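        # with dynamic_seed=True (default), reset() re-seeds the env with
        # `seed + 100 * np.random.randint(1, 1000)` on every episode; set it to False
        # to keep the seed fixed, e.g. for evaluation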
self._seed = seed
self._dynamic_seed = dynamic_seed
np.random.seed(self._seed)
def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
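        # forward the joint action to the wrapped env; it returns a single scalar reward
        # shared by all agents, which is accumulated into the episode return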
action = to_ndarray(action)
obs, rew, done, info = self._env.step(action)
self._eval_episode_return += rew
        rew = to_ndarray([rew])  # wrap the scalar reward into an array with shape (1, )
if done:
info['eval_episode_return'] = self._eval_episode_return
return BaseEnvTimestep(obs, rew, done, info)
    def random_action(self) -> dict:
        random_action = self.action_space.sample()
        # each agent's action is drawn from a continuous Box space, so it is already a
        # float ndarray; do not cast it to int64 (that only suits discrete actions)
        random_action = to_ndarray(random_action)
        return random_action
@property
def num_agents(self) -> Any:
return self._num_agents
@property
def observation_space(self) -> gym.spaces.Space:
return self._observation_space
@property
def action_space(self) -> gym.spaces.Space:
return self._action_space
@property
def reward_space(self) -> gym.spaces.Space:
return self._reward_space
def __repr__(self) -> str:
return "DI-engine Multi-agent Mujoco Env({})".format(self._cfg.env_id)