from functools import partial
import gym
from gym.spaces import Box
from gym.wrappers import TimeLimit
import numpy as np
from .multiagentenv import MultiAgentEnv
from .obsk import get_joints_at_kdist, get_parts_and_edges, build_obs
# using code from https://github.com/ikostrikov/pytorch-ddpg-naf
class NormalizedActions(gym.ActionWrapper):
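    """Action wrapper that rescales policy actions from the normalized range
    [-1, 1] to the wrapped environment's Box range [low, high]:

        env_action = (action + 1) / 2 * (high - low) + low

    ``_reverse_action`` applies the inverse mapping back to [-1, 1].
    """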
def _action(self, action):
action = (action + 1) / 2
action *= (self.action_space.high - self.action_space.low)
action += self.action_space.low
return action
def action(self, action_):
return self._action(action_)
def _reverse_action(self, action):
action -= self.action_space.low
action /= (self.action_space.high - self.action_space.low)
action = action * 2 - 1
return action
class MujocoMulti(MultiAgentEnv):
def __init__(self, batch_size=None, **kwargs):
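        """Multi-agent wrapper around a single MuJoCo gym environment.

        The configuration is read from kwargs["env_args"]: "scenario"
        (e.g. Ant-v2), "agent_conf" (e.g. '2x3'), "add_agent_id", and
        optionally "obs_add_global_pos", "agent_obsk", "agent_obsk_agents",
        "k_categories", "global_categories" and "env_version".
        """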
super().__init__(batch_size, **kwargs)
self.add_agent_id = kwargs["env_args"]["add_agent_id"]
self.scenario = kwargs["env_args"]["scenario"] # e.g. Ant-v2
self.agent_conf = kwargs["env_args"]["agent_conf"] # e.g. '2x3'
self.agent_partitions, self.mujoco_edges, self.mujoco_globals = get_parts_and_edges(
self.scenario, self.agent_conf
)
self.n_agents = len(self.agent_partitions)
self.n_actions = max([len(l) for l in self.agent_partitions])
self.obs_add_global_pos = kwargs["env_args"].get("obs_add_global_pos", False)
self.agent_obsk = kwargs["env_args"].get(
"agent_obsk", None
        )  # if None, agents are fully observable; otherwise k >= 0 means each agent observes joints (or agents) up to graph distance k
self.agent_obsk_agents = kwargs["env_args"].get(
"agent_obsk_agents", False
) # observe full k nearest agents (True) or just single joints (False)
if self.agent_obsk is not None:
self.k_categories_label = kwargs["env_args"].get("k_categories")
if self.k_categories_label is None:
if self.scenario in ["Ant-v2", "manyagent_ant"]:
self.k_categories_label = "qpos,qvel,cfrc_ext|qpos"
elif self.scenario in ["Humanoid-v2", "HumanoidStandup-v2"]:
self.k_categories_label = "qpos,qvel,cfrc_ext,cvel,cinert,qfrc_actuator|qpos"
elif self.scenario in ["Reacher-v2"]:
self.k_categories_label = "qpos,qvel,fingertip_dist|qpos"
elif self.scenario in ["coupled_half_cheetah"]:
self.k_categories_label = "qpos,qvel,ten_J,ten_length,ten_velocity|"
else:
self.k_categories_label = "qpos,qvel|qpos"
k_split = self.k_categories_label.split("|")
self.k_categories = [k_split[k if k < len(k_split) else -1].split(",") for k in range(self.agent_obsk + 1)]
self.global_categories_label = kwargs["env_args"].get("global_categories")
self.global_categories = self.global_categories_label.split(
","
) if self.global_categories_label is not None else []
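        # pre-compute, for every agent, which joints lie at each graph distance
        # (0..agent_obsk) from the agent's own joints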
if self.agent_obsk is not None:
self.k_dicts = [
get_joints_at_kdist(
agent_id,
self.agent_partitions,
self.mujoco_edges,
k=self.agent_obsk,
kagents=False,
) for agent_id in range(self.n_agents)
]
# load scenario from script
self.episode_limit = self.args.episode_limit
self.env_version = kwargs["env_args"].get("env_version", 2)
if self.env_version == 2:
try:
self.wrapped_env = NormalizedActions(gym.make(self.scenario))
except gym.error.Error: # env not in gym
if self.scenario in ["manyagent_ant"]:
from .manyagent_ant import ManyAgentAntEnv as this_env
elif self.scenario in ["manyagent_swimmer"]:
from .manyagent_swimmer import ManyAgentSwimmerEnv as this_env
elif self.scenario in ["coupled_half_cheetah"]:
from .coupled_half_cheetah import CoupledHalfCheetah as this_env
else:
                    raise NotImplementedError(f"Custom env '{self.scenario}' is not implemented!")
self.wrapped_env = NormalizedActions(
TimeLimit(this_env(**kwargs["env_args"]), max_episode_steps=self.episode_limit)
)
else:
            raise NotImplementedError(f"env_version {self.env_version} is not supported")
self.timelimit_env = self.wrapped_env.env
self.timelimit_env._max_episode_steps = self.episode_limit
if gym.version.VERSION > '0.22.0': # for compatibility
            # newer gym versions insert extra wrapper layers; peel them off to get the original, unwrapped env
self.env = self.timelimit_env.env.env.env.env
else:
self.env = self.timelimit_env.env
self.timelimit_env.reset()
self.obs_size = self.get_obs_size()
# COMPATIBILITY
self.n = self.n_agents
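        # gym-style per-agent spaces kept for compatibility; note that the Box
        # bounds below use n_agents as the dimension, so the true per-agent
        # observation shape should be taken from get_obs_size()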
self.observation_space = [
Box(low=np.array([-10] * self.n_agents), high=np.array([10] * self.n_agents)) for _ in range(self.n_agents)
]
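        # per-agent action spaces are obtained by slicing the wrapped env's flat
        # action bounds according to each agent's joint partition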
acdims = [len(ap) for ap in self.agent_partitions]
self.action_space = tuple(
[
Box(
self.env.action_space.low[sum(acdims[:a]):sum(acdims[:a + 1])],
self.env.action_space.high[sum(acdims[:a]):sum(acdims[:a + 1])]
) for a in range(self.n_agents)
]
)
def step(self, actions):
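        """Steps the wrapped single MuJoCo env with the concatenated per-agent
        actions; returns an obs dict ('agent_state', 'global_state'), the shared
        reward, the done flag and an info dict."""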
# need to remove dummy actions that arise due to unequal action vector sizes across agents
flat_actions = np.concatenate([actions[i][:self.action_space[i].low.shape[0]] for i in range(self.n_agents)])
obs_n, reward_n, done_n, info_n = self.wrapped_env.step(flat_actions)
self.steps += 1
info = {}
info.update(info_n)
if done_n:
if self.steps < self.episode_limit:
info["episode_limit"] = False # the next state will be masked out
else:
info["episode_limit"] = True # the next state will not be masked out
obs = {'agent_state': self.get_obs(), 'global_state': self.get_state()}
return obs, reward_n, done_n, info
def get_obs(self):
""" Returns all agent observat3ions in a list """
obs_n = []
for a in range(self.n_agents):
obs_n.append(self.get_obs_agent(a))
return np.array(obs_n).astype(np.float32)
def get_obs_agent(self, agent_id):
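        """Returns the local observation of agent_id: the full environment
        observation when agent_obsk is None, otherwise an observation built
        from joints within graph distance agent_obsk of the agent."""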
if self.agent_obsk is None:
return self.env._get_obs()
else:
return build_obs(
self.env,
self.k_dicts[agent_id],
self.k_categories,
self.mujoco_globals,
self.global_categories,
vec_len=getattr(self, "obs_size", None)
)
def get_obs_size(self):
""" Returns the shape of the observation """
if self.agent_obsk is None:
return self.get_obs_agent(0).size
else:
return max([len(self.get_obs_agent(agent_id)) for agent_id in range(self.n_agents)])
def get_state(self, team=None):
# TODO: May want global states for different teams (so cannot see what the other team is communicating e.g.)
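        # the global state is the full environment observation, replicated for
        # every agent and (if add_agent_id) concatenated with a one-hot agent id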
state_n = []
if self.add_agent_id:
state = self.env._get_obs()
for a in range(self.n_agents):
agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)
agent_id_feats[a] = 1.0
state_i = np.concatenate([state, agent_id_feats])
state_n.append(state_i)
else:
for a in range(self.n_agents):
state_n.append(self.env._get_obs())
return np.array(state_n).astype(np.float32)
def get_state_size(self):
""" Returns the shape of the state"""
return len(self.get_state())
def get_avail_actions(self): # all actions are always available
return np.ones(shape=(
self.n_agents,
self.n_actions,
))
def get_avail_agent_actions(self, agent_id):
""" Returns the available actions for agent_id """
return np.ones(shape=(self.n_actions, ))
def get_total_actions(self):
""" Returns the total number of actions an agent could ever take """
        return self.n_actions  # CAREFUL! - for continuous action spaces this is the per-agent action dimension, not a count of discrete actions
# return self.env.action_space.shape[0]
def get_stats(self):
return {}
# TODO: Temp hack
def get_agg_stats(self, stats):
return {}
def reset(self, **kwargs):
""" Returns initial observations and states"""
self.steps = 0
self.timelimit_env.reset()
obs = {'agent_state': self.get_obs(), 'global_state': self.get_state()}
return obs
def render(self, **kwargs):
self.env.render(**kwargs)
def close(self):
pass
#raise NotImplementedError
def seed(self, args):
pass
def get_env_info(self):
env_info = {
"state_shape": self.get_state_size(),
"obs_shape": self.get_obs_size(),
"n_actions": self.get_total_actions(),
"n_agents": self.n_agents,
"episode_limit": self.episode_limit,
"action_spaces": self.action_space,
"actions_dtype": np.float32,
"normalise_actions": False
}
return env_info
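

# A minimal usage sketch (hypothetical values; the env_args keys mirror what
# __init__ reads above, while "episode_limit" is assumed to be exposed by the
# MultiAgentEnv base class via self.args):
#
#     env = MujocoMulti(env_args=dict(
#         scenario="Ant-v2",
#         agent_conf="2x4",
#         agent_obsk=1,
#         add_agent_id=False,
#         episode_limit=1000,
#     ))
#     obs = env.reset()
#     actions = [space.sample() for space in env.action_space]
#     obs, reward, done, info = env.step(actions)
#     env_info = env.get_env_info()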