from typing import Any, Union, List
import copy
import numpy as np
import gym
import competitive_rl
from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
from ding.envs.common.env_element import EnvElement, EnvElementInfo
from ding.envs.common.common_function import affine_transform
from ding.torch_utils import to_ndarray, to_list
from .competitive_rl_env_wrapper import BuiltinOpponentWrapper, wrap_env
from ding.utils import ENV_REGISTRY

competitive_rl.register_competitive_envs()
"""
The observation spaces:
cPong-v0: Box(210, 160, 3)
cPongDouble-v0: Tuple(Box(210, 160, 3), Box(210, 160, 3))
cCarRacing-v0: Box(96, 96, 1)
cCarRacingDouble-v0: Box(96, 96, 1)
The action spaces:
cPong-v0: Discrete(3)
cPongDouble-v0: Tuple(Discrete(3), Discrete(3))
cCarRacing-v0: Box(2,)
cCarRacingDouble-v0: Dict(0:Box(2,), 1:Box(2,))
cPongTournament-v0
"""

COMPETITIVERL_INFO_DICT = {
'cPongDouble-v0': BaseEnvInfo(
agent_num=1,
obs_space=EnvElementInfo(
shape=(210, 160, 3),
# shape=(4, 84, 84),
value={
'min': 0,
'max': 255,
'dtype': np.float32
},
),
act_space=EnvElementInfo(
            shape=(1, ),  # different from https://github.com/cuhkrlcourse/competitive-rl#usage
value={
'min': 0,
'max': 3,
'dtype': np.float32
},
),
rew_space=EnvElementInfo(
shape=(1, ),
value={
'min': np.float32("-inf"),
'max': np.float32("inf"),
'dtype': np.float32
},
),
use_wrappers=None,
),
}


@ENV_REGISTRY.register('competitive_rl')
class CompetitiveRlEnv(BaseEnv):

    def __init__(self, cfg: dict) -> None:
self._cfg = cfg
self._env_id = self._cfg.env_id
        # opponent_type is used to control the builtin opponent agent, which is only used in the evaluator.
is_evaluator = self._cfg.get("is_evaluator", False)
opponent_type = None
if is_evaluator:
opponent_type = self._cfg.get("opponent_type", None)
self._builtin_wrap = self._env_id == "cPongDouble-v0" and is_evaluator and opponent_type == "builtin"
self._opponent = self._cfg.get('eval_opponent', 'RULE_BASED')
self._init_flag = False

    def reset(self) -> np.ndarray:
if not self._init_flag:
self._env = self._make_env(only_info=False)
self._init_flag = True
if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
np_seed = 100 * np.random.randint(1, 1000)
self._env.seed(self._seed + np_seed)
elif hasattr(self, '_seed'):
self._env.seed(self._seed)
obs = self._env.reset()
obs = to_ndarray(obs)
obs = self.process_obs(obs) # process
if self._builtin_wrap:
self._eval_episode_return = np.array([0.])
else:
self._eval_episode_return = np.array([0., 0.])
return obs

    def close(self) -> None:
if self._init_flag:
self._env.close()
self._init_flag = False

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
self._seed = seed
self._dynamic_seed = dynamic_seed
np.random.seed(self._seed)

    def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
action = to_ndarray(action)
action = self.process_action(action) # process
obs, rew, done, info = self._env.step(action)
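        # The builtin-opponent (single-agent) env returns a scalar reward, while the two-agent
        # env returns a tuple of per-agent rewards; normalize both to an ndarray before accumulating.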
if not isinstance(rew, tuple):
rew = [rew]
rew = np.array(rew)
self._eval_episode_return += rew
obs = to_ndarray(obs)
obs = self.process_obs(obs) # process
if done:
info['eval_episode_return'] = self._eval_episode_return
return BaseEnvTimestep(obs, rew, done, info)

    def info(self) -> BaseEnvInfo:
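        """
        Overview:
            Build the env info from the static ``COMPETITIVERL_INFO_DICT`` entry, apply the shape
            changes introduced by the env wrappers, and, when the builtin opponent is not used,
            expand the spaces to the two-agent layout (a leading agent dimension of 2).
        """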
if self._env_id in COMPETITIVERL_INFO_DICT:
info = copy.deepcopy(COMPETITIVERL_INFO_DICT[self._env_id])
info.use_wrappers = self._make_env(only_info=True)
obs_shape, act_shape, rew_shape = update_shape(
info.obs_space.shape, info.act_space.shape, info.rew_space.shape, info.use_wrappers.split('\n')
)
info.obs_space.shape = obs_shape
info.act_space.shape = act_shape
info.rew_space.shape = rew_shape
if not self._builtin_wrap:
info.obs_space.shape = (2, ) + info.obs_space.shape
info.act_space.shape = (2, )
info.rew_space.shape = (2, )
return info
else:
raise NotImplementedError('{} not found in COMPETITIVERL_INFO_DICT [{}]'\
.format(self._env_id, COMPETITIVERL_INFO_DICT.keys()))

    def _make_env(self, only_info=False):
return wrap_env(self._env_id, self._builtin_wrap, self._opponent, only_info=only_info)

    def __repr__(self) -> str:
return "DI-engine Competitve RL Env({})".format(self._cfg.env_id)

    @staticmethod
def create_collector_env_cfg(cfg: dict) -> List[dict]:
collector_cfg = copy.deepcopy(cfg)
collector_env_num = collector_cfg.pop('collector_env_num', 1)
collector_cfg.is_evaluator = False
return [collector_cfg for _ in range(collector_env_num)]

    @staticmethod
def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
evaluator_cfg = copy.deepcopy(cfg)
evaluator_env_num = evaluator_cfg.pop('evaluator_env_num', 1)
evaluator_cfg.is_evaluator = True
return [evaluator_cfg for _ in range(evaluator_env_num)]

    def process_action(self, action: np.ndarray) -> Union[tuple, dict, np.ndarray]:
        # In a double agent env, transform the action passed in from outside into a tuple or dict.
if self._env_id == "cPongDouble-v0" and not self._builtin_wrap:
return (action[0].squeeze(), action[1].squeeze())
elif self._env_id == "cCarRacingDouble-v0":
return {0: action[0].squeeze(), 1: action[1].squeeze()}
else:
return action.squeeze()

    def process_obs(self, obs: Union[tuple, np.ndarray]) -> Union[tuple, np.ndarray]:
        # Duplicate the observation for the car racing double agent env, to keep it aligned with the pong double agent env.
if self._env_id == "cCarRacingDouble-v0":
obs = np.stack([obs, copy.deepcopy(obs)])
return obs
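

if __name__ == "__main__":
    # Minimal usage sketch for manual inspection, not part of the DI-engine training pipeline.
    # Assumptions: ``easydict`` is installed (DI-engine configs are typically EasyDicts, and the
    # attribute-style access in ``__init__`` / ``create_*_env_cfg`` relies on it), and the builtin
    # opponent wrapper exposes a single-agent interface for 'cPongDouble-v0'.
    from easydict import EasyDict

    cfg = EasyDict(
        env_id='cPongDouble-v0',
        is_evaluator=True,
        opponent_type='builtin',  # single-agent play against the builtin opponent
        eval_opponent='RULE_BASED',
    )
    env = CompetitiveRlEnv(cfg)
    env.seed(0)
    obs = env.reset()  # obs shape depends on the wrappers applied by wrap_env
    obs, rew, done, info = env.step(np.array([0]))  # a valid Discrete(3) action
    print('reward:', rew, 'done:', done)
    env.close()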