import cv2
import gym
import os.path as osp
import numpy as np
from typing import Union, Optional
from collections import deque

from competitive_rl.pong.builtin_policies import (
    get_builtin_agent_names, single_obs_space, single_act_space,
    get_random_policy, get_rule_based_policy
)
from competitive_rl.utils.policy_serving import Policy


def get_compute_action_function_ours(agent_name, num_envs=1):
    """Return the compute-action callable for the named built-in Pong agent."""
    resource_dir = osp.join(osp.dirname(__file__), "resources", "pong")
    if agent_name == "STRONG":
        return Policy(
            single_obs_space,
            single_act_space,
            num_envs,
            osp.join(resource_dir, "checkpoint-strong.pkl"),
            use_light_model=False
        )
    if agent_name == "MEDIUM":
        return Policy(
            single_obs_space,
            single_act_space,
            num_envs,
            osp.join(resource_dir, "checkpoint-medium.pkl"),
            use_light_model=True
        )
    if agent_name == "ALPHA_PONG":
        return Policy(
            single_obs_space,
            single_act_space,
            num_envs,
            osp.join(resource_dir, "checkpoint-alphapong.pkl"),
            use_light_model=False
        )
    if agent_name == "WEAK":
        return Policy(
            single_obs_space,
            single_act_space,
            num_envs,
            osp.join(resource_dir, "checkpoint-weak.pkl"),
            use_light_model=True
        )
    if agent_name == "RANDOM":
        return get_random_policy(num_envs)
    if agent_name == "RULE_BASED":
        return get_rule_based_policy(num_envs)
    raise ValueError("Unknown agent name: {}".format(agent_name))
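

# Illustrative sketch: the returned policy is used as a callable on observations
# (see BuiltinOpponentWrapper.step below). Whether it also accepts a batched array
# of single_obs_space samples, as shown here, is an assumption.
#
#   policy = get_compute_action_function_ours("MEDIUM", num_envs=2)
#   batch_obs = np.stack([single_obs_space.sample() for _ in range(2)])
#   opponent_actions = policy(batch_obs)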


class BuiltinOpponentWrapper(gym.Wrapper):
    """Turn a two-player environment into a single-agent environment by letting a
    built-in policy control the second player."""

    def __init__(self, env: gym.Env, num_envs: int = 1) -> None:
        super().__init__(env)
        self.agents = {
            agent_name: get_compute_action_function_ours(agent_name, num_envs)
            for agent_name in get_builtin_agent_names()
        }
        self.agent_names = list(self.agents)
        self.prev_opponent_obs = None
        self.current_opponent_name = "RULE_BASED"
        self.current_opponent = self.agents[self.current_opponent_name]
        # Expose only the first player's spaces; the opponent is handled internally.
        self.observation_space = env.observation_space[0]
        self.action_space = env.action_space[0]
        self.num_envs = num_envs

    def reset_opponent(self, agent_name: str) -> None:
        """Switch the built-in opponent that controls the second player."""
        assert agent_name in self.agent_names, (agent_name, self.agent_names)
        self.current_opponent_name = agent_name
        self.current_opponent = self.agents[self.current_opponent_name]

    def step(self, action):
        # Pair the learning agent's action with the opponent's action, which is
        # computed from the opponent's observation of the previous step.
        tuple_action = (action.item(), self.current_opponent(self.prev_opponent_obs))
        obs, rew, done, info = self.env.step(tuple_action)
        self.prev_opponent_obs = obs[1]
        return obs[0], rew[0], done, info

    def reset(self):
        obs = self.env.reset()
        self.prev_opponent_obs = obs[1]
        return obs[0]

    def seed(self, s):
        self.env.seed(s)
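

# Sketch of intended use, based on the methods above. The two-player env id
# "cPongDouble-v0" is an assumption about how competitive_rl registers its Pong
# environment, not something this file guarantees.
#
#   env = BuiltinOpponentWrapper(gym.make("cPongDouble-v0"))
#   env.reset_opponent("WEAK")                     # any name in env.agent_names
#   obs = env.reset()                              # first player's observation only
#   obs, rew, done, info = env.step(np.array(0))   # action must support .item()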


def wrap_env(env_id, builtin_wrap, opponent, frame_stack=4, warp_frame=True, only_info=False):
    """Configure environment for DeepMind-style Atari. The observation is
    channel-first: (c, h, w) instead of (h, w, c).

    :param str env_id: the atari environment id.
    :param bool builtin_wrap: wrap the environment with BuiltinOpponentWrapper so a
        built-in policy controls the second player.
    :param str opponent: name of the built-in opponent to play against.
    :param int frame_stack: wrap the frame stacking wrapper.
    :param bool warp_frame: wrap the grayscale + resize observation wrapper.
    :param bool only_info: if True, do not build the environment; only return the
        names of the wrappers that would be applied.
    :return: the wrapped atari environment, or the wrapper info string.
    """
    if not only_info:
        env = gym.make(env_id)
        if builtin_wrap:
            env = BuiltinOpponentWrapper(env)
            env.reset_opponent(opponent)
        if warp_frame:
            env = WarpFrameWrapperCompetitveRl(env, builtin_wrap)
        if frame_stack:
            env = FrameStackWrapperCompetitiveRl(env, frame_stack, builtin_wrap)
        return env
    else:
        wrapper_info = ''
        if builtin_wrap:
            wrapper_info += BuiltinOpponentWrapper.__name__ + '\n'
        if warp_frame:
            wrapper_info += WarpFrameWrapperCompetitveRl.__name__ + '\n'
        if frame_stack:
            wrapper_info += FrameStackWrapperCompetitiveRl.__name__ + '\n'
        return wrapper_info
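

# For example (illustrative; reuses the assumed "cPongDouble-v0" id mentioned above),
# wrap_env("cPongDouble-v0", True, "RULE_BASED", only_info=True) would return
# "BuiltinOpponentWrapper\nWarpFrameWrapperCompetitveRl\nFrameStackWrapperCompetitiveRl\n"
# without creating the environment.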


class WarpFrameWrapperCompetitveRl(gym.ObservationWrapper):
    """Warp frames to 84x84 as done in the Nature paper and later work.

    :param gym.Env env: the environment to wrap.
    :param bool builtin_wrap: whether the env yields a single observation (builtin
        opponent) or a tuple of observations, one per player.
    """

    def __init__(self, env, builtin_wrap):
        super().__init__(env)
        self.size = 84
        obs_space = env.observation_space
        self.builtin_wrap = builtin_wrap
        if builtin_wrap:
            # Single-agent observation: one warped frame of shape (size, size).
            self.observation_space = gym.spaces.Box(
                low=np.min(obs_space.low),
                high=np.max(obs_space.high),
                shape=(self.size, self.size),
                dtype=obs_space.dtype
            )
        else:
            # Two-player observation: one warped frame per player.
            self.observation_space = gym.spaces.Tuple(
                [
                    gym.spaces.Box(
                        low=np.min(obs_space[0].low),
                        high=np.max(obs_space[0].high),
                        shape=(self.size, self.size),
                        dtype=obs_space[0].dtype
                    ) for _ in range(len(obs_space))
                ]
            )

    def observation(self, frame):
        """Convert the raw RGB frame(s) to grayscale and resize to 84x84."""
        if self.builtin_wrap:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
            return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
        else:
            frames = []
            for one_frame in frame:
                one_frame = cv2.cvtColor(one_frame, cv2.COLOR_RGB2GRAY)
                one_frame = cv2.resize(one_frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
                frames.append(one_frame)
            return frames


class FrameStackWrapperCompetitiveRl(gym.Wrapper):
    """Stack n_frames last frames.

    :param gym.Env env: the environment to wrap.
    :param int n_frames: the number of frames to stack.
    :param bool builtin_wrap: whether the env yields a single observation (builtin
        opponent) or a tuple of observations, one per player.
    """

    def __init__(self, env, n_frames, builtin_wrap):
        super().__init__(env)
        self.n_frames = n_frames
        self.builtin_wrap = builtin_wrap
        obs_space = env.observation_space
        if self.builtin_wrap:
            # One deque holding the last n_frames single-agent observations.
            self.frames = deque([], maxlen=n_frames)
            shape = (n_frames,) + obs_space.shape
            self.observation_space = gym.spaces.Box(
                low=np.min(obs_space.low), high=np.max(obs_space.high), shape=shape, dtype=obs_space.dtype
            )
        else:
            # One deque per player.
            self.frames = [deque([], maxlen=n_frames) for _ in range(len(obs_space))]
            shape = (n_frames,) + obs_space[0].shape
            self.observation_space = gym.spaces.Tuple(
                [
                    gym.spaces.Box(
                        low=np.min(obs_space[0].low),
                        high=np.max(obs_space[0].high),
                        shape=shape,
                        dtype=obs_space[0].dtype
                    ) for _ in range(len(obs_space))
                ]
            )

    def reset(self):
        if self.builtin_wrap:
            obs = self.env.reset()
            # Fill the stack with the initial frame so the first observation
            # already has the full (n_frames, ...) shape.
            for _ in range(self.n_frames):
                self.frames.append(obs)
            return self._get_ob(self.frames)
        else:
            obs = self.env.reset()
            for i, one_obs in enumerate(obs):
                for _ in range(self.n_frames):
                    self.frames[i].append(one_obs)
            return np.stack([self._get_ob(self.frames[i]) for i in range(len(obs))])

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        if self.builtin_wrap:
            self.frames.append(obs)
            return self._get_ob(self.frames), reward, done, info
        else:
            for i, one_obs in enumerate(obs):
                self.frames[i].append(one_obs)
            return np.stack([self._get_ob(self.frames[i]) for i in range(len(obs))], axis=0), reward, done, info

    @staticmethod
    def _get_ob(frames):
        # Stack the deque of frames into a single (n_frames, ...) array.
        return np.stack(frames, axis=0)
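

# Minimal end-to-end sketch of how these pieces fit together. The env id
# "cPongDouble-v0" and the opponent name "RULE_BASED" are assumptions about what
# competitive_rl registers; substitute whatever ids/names your installation provides.
if __name__ == "__main__":
    env = wrap_env("cPongDouble-v0", builtin_wrap=True, opponent="RULE_BASED", frame_stack=4)
    obs = env.reset()
    print("stacked observation shape:", np.asarray(obs).shape)  # expected (4, 84, 84)
    for _ in range(10):
        # BuiltinOpponentWrapper.step calls .item() on the action, so pass a
        # zero-dimensional numpy array rather than a plain int.
        obs, rew, done, info = env.step(np.array(env.action_space.sample()))
        if done:
            obs = env.reset()
    env.close()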