import cv2
import gym
import os.path as osp
import numpy as np
from typing import Union, Optional
from collections import deque
from competitive_rl.pong.builtin_policies import (
    get_builtin_agent_names, single_obs_space, single_act_space, get_random_policy, get_rule_based_policy
)
from competitive_rl.utils.policy_serving import Policy
def get_compute_action_function_ours(agent_name, num_envs=1):
    """Return a built-in pong opponent by name: a pretrained ``Policy`` loaded from the
    bundled checkpoints, a random policy, or the rule-based policy.
    """
    resource_dir = osp.join(osp.dirname(__file__), "resources", "pong")
if agent_name == "STRONG":
return Policy(
single_obs_space,
single_act_space,
num_envs,
osp.join(resource_dir, "checkpoint-strong.pkl"),
use_light_model=False
)
if agent_name == "MEDIUM":
return Policy(
single_obs_space,
single_act_space,
num_envs,
osp.join(resource_dir, "checkpoint-medium.pkl"),
use_light_model=True
)
if agent_name == "ALPHA_PONG":
return Policy(
single_obs_space,
single_act_space,
num_envs,
osp.join(resource_dir, "checkpoint-alphapong.pkl"),
use_light_model=False
)
if agent_name == "WEAK":
return Policy(
single_obs_space,
single_act_space,
num_envs,
osp.join(resource_dir, "checkpoint-weak.pkl"),
use_light_model=True
)
if agent_name == "RANDOM":
return get_random_policy(num_envs)
if agent_name == "RULE_BASED":
return get_rule_based_policy(num_envs)
raise ValueError("Unknown agent name: {}".format(agent_name))
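# Example (a hedged sketch, not part of the original interface): fetching a built-in
# opponent by name and querying it for an action. ``Policy`` objects are assumed to be
# callable on the opponent's observation, mirroring how ``BuiltinOpponentWrapper.step``
# uses them below; the pretrained agents additionally require their checkpoint files
# under ``resources/pong``.
#
#     opponent = get_compute_action_function_ours("RULE_BASED", num_envs=1)
#     opponent_action = opponent(opponent_obs)  # ``opponent_obs`` comes from the env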
class BuiltinOpponentWrapper(gym.Wrapper):
    """Reduce the two-player competitive_rl pong env to a single-agent env by letting
    one of the built-in policies control the second player.
    """

    def __init__(self, env: 'gym.Env', num_envs: int = 1) -> None:  # noqa
super().__init__(env)
self.agents = {
agent_name: get_compute_action_function_ours(agent_name, num_envs)
for agent_name in get_builtin_agent_names()
}
self.agent_names = list(self.agents)
self.prev_opponent_obs = None
self.current_opponent_name = "RULE_BASED"
self.current_opponent = self.agents[self.current_opponent_name]
self.observation_space = env.observation_space[0]
self.action_space = env.action_space[0]
self.num_envs = num_envs
def reset_opponent(self, agent_name: str) -> None:
assert agent_name in self.agent_names, (agent_name, self.agent_names)
self.current_opponent_name = agent_name
self.current_opponent = self.agents[self.current_opponent_name]
    def step(self, action):
        # The underlying two-player env expects (learner action, opponent action); the
        # opponent acts on the observation it received at the previous step.
        tuple_action = (action.item(), self.current_opponent(self.prev_opponent_obs))
obs, rew, done, info = self.env.step(tuple_action)
self.prev_opponent_obs = obs[1]
# if done.ndim == 2:
# done = done[:, 0]
# return obs[0], rew[:, 0].reshape(-1, 1), done.reshape(-1, 1), info
return obs[0], rew[0], done, info
def reset(self):
obs = self.env.reset()
self.prev_opponent_obs = obs[1]
return obs[0]
def seed(self, s):
self.env.seed(s)
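# Usage sketch (assumption: ``cPongDouble-v0`` is the two-player pong id registered by
# competitive_rl; substitute whatever id your installation provides):
#
#     env = BuiltinOpponentWrapper(gym.make('cPongDouble-v0'))
#     env.reset_opponent('WEAK')
#     obs = env.reset()
#     obs, rew, done, info = env.step(np.array(env.action_space.sample()))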
def wrap_env(env_id, builtin_wrap, opponent, frame_stack=4, warp_frame=True, only_info=False):
    """Configure a competitive_rl environment in DeepMind Atari style. The resulting
    observation is channel-first: (c, h, w) instead of (h, w, c).

    :param str env_id: the competitive_rl environment id.
    :param bool builtin_wrap: reduce the two-player env to a single-agent env played
        against a built-in opponent.
    :param str opponent: name of the built-in opponent, see ``get_builtin_agent_names``.
    :param int frame_stack: number of frames to stack (0 disables the wrapper).
    :param bool warp_frame: wrap the grayscale + resize observation wrapper.
    :param bool only_info: if True, return the names of the wrappers that would be
        applied instead of the wrapped environment.
    :return: the wrapped environment, or a string describing the wrappers.
    """
if not only_info:
env = gym.make(env_id)
if builtin_wrap:
env = BuiltinOpponentWrapper(env)
env.reset_opponent(opponent)
if warp_frame:
env = WarpFrameWrapperCompetitveRl(env, builtin_wrap)
if frame_stack:
env = FrameStackWrapperCompetitiveRl(env, frame_stack, builtin_wrap)
return env
else:
wrapper_info = ''
if builtin_wrap:
wrapper_info += BuiltinOpponentWrapper.__name__ + '\n'
if warp_frame:
            wrapper_info += WarpFrameWrapperCompetitveRl.__name__ + '\n'
if frame_stack:
            wrapper_info += FrameStackWrapperCompetitiveRl.__name__ + '\n'
return wrapper_info
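# Usage sketch: building a single-agent, frame-stacked pong env against the rule-based
# opponent, and listing the wrappers without instantiating anything (the env id is an
# assumption, see the note above ``BuiltinOpponentWrapper``):
#
#     env = wrap_env('cPongDouble-v0', builtin_wrap=True, opponent='RULE_BASED')
#     print(wrap_env(None, builtin_wrap=True, opponent=None, only_info=True))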
class WarpFrameWrapperCompetitveRl(gym.ObservationWrapper):
"""Warp frames to 84x84 as done in the Nature paper and later work.
:param gym.Env env: the environment to wrap.
"""
def __init__(self, env, builtin_wrap):
super().__init__(env)
self.size = 84
obs_space = env.observation_space
self.builtin_wrap = builtin_wrap
if builtin_wrap:
# single player
self.observation_space = gym.spaces.Box(
low=np.min(obs_space.low),
high=np.max(obs_space.high),
shape=(self.size, self.size),
dtype=obs_space.dtype
)
else:
# double player
self.observation_space = gym.spaces.tuple.Tuple(
[
gym.spaces.Box(
low=np.min(obs_space[0].low),
high=np.max(obs_space[0].high),
shape=(self.size, self.size),
dtype=obs_space[0].dtype
) for _ in range(len(obs_space))
]
)
    def observation(self, frame):
        """Convert the raw RGB frame(s) to grayscale and resize to 84x84."""
if self.builtin_wrap:
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
else:
frames = []
for one_frame in frame:
one_frame = cv2.cvtColor(one_frame, cv2.COLOR_RGB2GRAY)
one_frame = cv2.resize(one_frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
frames.append(one_frame)
return frames
class FrameStackWrapperCompetitiveRl(gym.Wrapper):
"""Stack n_frames last frames.
:param gym.Env env: the environment to wrap.
:param int n_frames: the number of frames to stack.
"""
def __init__(self, env, n_frames, builtin_wrap):
super().__init__(env)
self.n_frames = n_frames
self.builtin_wrap = builtin_wrap
obs_space = env.observation_space
if self.builtin_wrap:
self.frames = deque([], maxlen=n_frames)
shape = (n_frames, ) + obs_space.shape
self.observation_space = gym.spaces.Box(
low=np.min(obs_space.low), high=np.max(obs_space.high), shape=shape, dtype=obs_space.dtype
)
else:
self.frames = [deque([], maxlen=n_frames) for _ in range(len(obs_space))]
shape = (n_frames, ) + obs_space[0].shape
self.observation_space = gym.spaces.tuple.Tuple(
[
gym.spaces.Box(
low=np.min(obs_space[0].low),
high=np.max(obs_space[0].high),
shape=shape,
dtype=obs_space[0].dtype
) for _ in range(len(obs_space))
]
)
def reset(self):
if self.builtin_wrap:
obs = self.env.reset()
for _ in range(self.n_frames):
self.frames.append(obs)
return self._get_ob(self.frames)
else:
obs = self.env.reset()
for i, one_obs in enumerate(obs):
for _ in range(self.n_frames):
self.frames[i].append(one_obs)
return np.stack([self._get_ob(self.frames[i]) for i in range(len(obs))])
def step(self, action):
obs, reward, done, info = self.env.step(action)
if self.builtin_wrap:
self.frames.append(obs)
return self._get_ob(self.frames), reward, done, info
else:
for i, one_obs in enumerate(obs):
self.frames[i].append(one_obs)
return np.stack([self._get_ob(self.frames[i]) for i in range(len(obs))], axis=0), reward, done, info
    @staticmethod
    def _get_ob(frames):
        # The original wrapper uses `LazyFrames`, but since the frames are copied into a
        # numpy buffer here, that optimisation has no effect.
        return np.stack(frames, axis=0)
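

if __name__ == "__main__":
    # Minimal smoke test; a hedged sketch rather than part of the original module.
    # Assumption: competitive_rl registers a two-player pong env under 'cPongDouble-v0';
    # adapt the id to your installation if it differs.
    env = wrap_env('cPongDouble-v0', builtin_wrap=True, opponent='RULE_BASED', frame_stack=4)
    obs = env.reset()
    print('stacked observation shape:', np.asarray(obs).shape)  # expected: (4, 84, 84)
    for _ in range(5):
        action = np.array(env.action_space.sample())
        obs, rew, done, info = env.step(action)
        if done:
            obs = env.reset()
    env.close()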