from functools import partial

import gym
from gym.spaces import Box
from gym.wrappers import TimeLimit
import numpy as np

from .multiagentenv import MultiAgentEnv
from .obsk import get_joints_at_kdist, get_parts_and_edges, build_obs


# using code from https://github.com/ikostrikov/pytorch-ddpg-naf
class NormalizedActions(gym.ActionWrapper):
    """Rescale actions between [-1, 1] and the wrapped env's action-space bounds."""

    def _action(self, action):
        """Map ``action`` linearly from [-1, 1] onto [low, high] of ``self.action_space``.

        Values are not clipped; inputs outside [-1, 1] map outside [low, high].
        """
        action = (action + 1) / 2
        action *= (self.action_space.high - self.action_space.low)
        action += self.action_space.low
        return action

    def action(self, action_):
        # Hook used by gym.ActionWrapper to transform an action before it is
        # passed to the wrapped env; delegates to the rescaling helper.
        return self._action(action_)

    def _reverse_action(self, action):
        """Map ``action`` linearly from [low, high] back onto [-1, 1].

        Returns a new array; the caller's ``action`` is not modified.
        """
        low = self.action_space.low
        high = self.action_space.high
        action = (action - low) / (high - low)
        return action * 2 - 1


class MujocoMulti(MultiAgentEnv):
    """Multi-agent view of a single MuJoCo robot whose joints are split between agents.

    The robot named by ``env_args["scenario"]`` is partitioned according to
    ``env_args["agent_conf"]`` via ``get_parts_and_edges``. Each agent controls
    a contiguous slice of the joint action vector. When ``agent_obsk`` is set,
    each agent's observation is built by ``build_obs`` from a neighbourhood
    computed by ``get_joints_at_kdist`` with ``k=agent_obsk``. Otherwise every
    agent sees the full single-agent observation.
    """

    def __init__(self, batch_size=None, **kwargs):
        super().__init__(batch_size, **kwargs)
        env_args = kwargs["env_args"]
        self.add_agent_id = env_args["add_agent_id"]
        self.scenario = env_args["scenario"]  # e.g. Ant-v2
        self.agent_conf = env_args["agent_conf"]  # e.g. '2x3'

        self.agent_partitions, self.mujoco_edges, self.mujoco_globals = get_parts_and_edges(
            self.scenario, self.agent_conf
        )

        self.n_agents = len(self.agent_partitions)
        # Partitions can differ in size; n_actions is the largest per-agent action dim.
        self.n_actions = max(len(part) for part in self.agent_partitions)
        self.obs_add_global_pos = env_args.get("obs_add_global_pos", False)

        # if None, fully observable else k>=0 implies observe nearest k agents or joints
        self.agent_obsk = env_args.get("agent_obsk", None)
        # observe full k nearest agents (True) or just single joints (False)
        self.agent_obsk_agents = env_args.get("agent_obsk_agents", False)

        self.global_categories_label = env_args.get("global_categories")
        self.global_categories = (
            self.global_categories_label.split(",")
            if self.global_categories_label is not None
            else []
        )

        if self.agent_obsk is not None:
            self.k_categories_label = env_args.get("k_categories")
            if self.k_categories_label is None:
                self.k_categories_label = self._default_k_categories_label(self.scenario)
            # "a,b|c" -> one category list per k in 0..agent_obsk, consumed by
            # build_obs (presumably: what to observe for joints at distance k).
            # Distances past the last "|" group reuse the final group.
            k_split = self.k_categories_label.split("|")
            self.k_categories = [
                k_split[min(k, len(k_split) - 1)].split(",")
                for k in range(self.agent_obsk + 1)
            ]

            # NOTE(review): kagents is hard-coded to False, so agent_obsk_agents is
            # not consulted here.
            self.k_dicts = [
                get_joints_at_kdist(
                    agent_id,
                    self.agent_partitions,
                    self.mujoco_edges,
                    k=self.agent_obsk,
                    kagents=False,
                )
                for agent_id in range(self.n_agents)
            ]

        # load scenario from script
        # self.args is presumably populated by MultiAgentEnv.__init__ — confirm.
        self.episode_limit = self.args.episode_limit

        self.env_version = env_args.get("env_version", 2)
        if self.env_version == 2:
            self.wrapped_env = self._make_wrapped_env(env_args)
        else:
            raise NotImplementedError("not implemented!")

        self.timelimit_env = self.wrapped_env.env
        self.timelimit_env._max_episode_steps = self.episode_limit
        # The innermost env is the one providing _get_obs() and passed to build_obs.
        # ``unwrapped`` strips however many wrapper layers this gym version's make()
        # added (and the single TimeLimit around custom envs). The previous
        # hard-coded ``.env.env.env.env`` chain was gated on a lexicographic
        # version-string comparison and broke for custom envs.
        self.env = self.timelimit_env.unwrapped
        self.timelimit_env.reset()
        # Initialised here too so step() works even if reset() was never called.
        self.steps = 0
        self.obs_size = self.get_obs_size()

        # COMPATIBILITY
        self.n = self.n_agents
        # NOTE(review): placeholder spaces of length n_agents with bounds +-10; they
        # do not describe the real per-agent observation (see get_obs_size()).
        self.observation_space = [
            Box(low=np.array([-10] * self.n_agents), high=np.array([10] * self.n_agents))
            for _ in range(self.n_agents)
        ]
        # Slice the joint action bounds into contiguous per-agent ranges, in
        # partition order.
        low = self.env.action_space.low
        high = self.env.action_space.high
        action_spaces = []
        start = 0
        for part in self.agent_partitions:
            end = start + len(part)
            action_spaces.append(Box(low[start:end], high[start:end]))
            start = end
        self.action_space = tuple(action_spaces)

    @staticmethod
    def _default_k_categories_label(scenario):
        """Return the default ``k_categories`` spec used when ``env_args`` gives none."""
        if scenario in ["Ant-v2", "manyagent_ant"]:
            return "qpos,qvel,cfrc_ext|qpos"
        if scenario in ["Humanoid-v2", "HumanoidStandup-v2"]:
            return "qpos,qvel,cfrc_ext,cvel,cinert,qfrc_actuator|qpos"
        if scenario in ["Reacher-v2"]:
            return "qpos,qvel,fingertip_dist|qpos"
        if scenario in ["coupled_half_cheetah"]:
            return "qpos,qvel,ten_J,ten_length,ten_velocity|"
        return "qpos,qvel|qpos"

    def _make_wrapped_env(self, env_args):
        """Build the action-normalised, time-limited env for ``self.scenario``.

        Gym-registered scenarios come from ``gym.make``. Otherwise one of the
        custom envs next to this module is built from ``env_args`` and wrapped
        in ``TimeLimit(max_episode_steps=self.episode_limit)``.

        Raises:
            NotImplementedError: the scenario is neither registered in gym nor a
                known custom env.
        """
        try:
            return NormalizedActions(gym.make(self.scenario))
        except gym.error.Error:
            # env not in gym
            if self.scenario in ["manyagent_ant"]:
                from .manyagent_ant import ManyAgentAntEnv as this_env
            elif self.scenario in ["manyagent_swimmer"]:
                from .manyagent_swimmer import ManyAgentSwimmerEnv as this_env
            elif self.scenario in ["coupled_half_cheetah"]:
                from .coupled_half_cheetah import CoupledHalfCheetah as this_env
            else:
                raise NotImplementedError('Custom env not implemented!')
            return NormalizedActions(
                TimeLimit(this_env(**env_args), max_episode_steps=self.episode_limit)
            )

    def step(self, actions):
        """Apply one joint action and advance the wrapped env by one step.

        Args:
            actions: per-agent action vectors (``actions[i]`` for agent ``i``).
                Entries beyond agent i's action dimension are padding and are
                dropped.

        Returns:
            ``(obs, reward, done, info)``, where:

            - ``obs`` is ``{'agent_state': get_obs(), 'global_state': get_state()}``;
            - ``reward`` and ``done`` come straight from the wrapped env;
            - ``info`` is a copy of the wrapped env's info. When ``done`` it also
              has ``"episode_limit"``, True iff the step limit was reached.
        """
        # need to remove dummy actions that arise due to unequal action vector sizes across agents
        flat_actions = np.concatenate(
            [actions[i][:space.low.shape[0]] for i, space in enumerate(self.action_space)]
        )
        _obs, reward_n, done_n, info_n = self.wrapped_env.step(flat_actions)
        self.steps += 1

        info = dict(info_n)
        if done_n:
            # False: genuine terminal, the next state will be masked out.
            # True: truncated by the step limit, the next state will not be masked out.
            info["episode_limit"] = self.steps >= self.episode_limit

        obs = {'agent_state': self.get_obs(), 'global_state': self.get_state()}
        return obs, reward_n, done_n, info

    def get_obs(self):
        """Return all agent observations stacked with ``np.array`` and cast to float32.

        Stacking into one 2-D array assumes equal per-agent lengths; presumably
        ``build_obs(vec_len=obs_size)`` guarantees this — confirm.
        """
        obs_n = [self.get_obs_agent(a) for a in range(self.n_agents)]
        return np.array(obs_n).astype(np.float32)

    def get_obs_agent(self, agent_id):
        """Return the observation of ``agent_id``.

        Returns the full single-agent observation when ``agent_obsk`` is None,
        otherwise the ``build_obs`` neighbourhood observation.
        """
        if self.agent_obsk is None:
            return self.env._get_obs()
        # obs_size is absent during the first call made from get_obs_size() in __init__.
        return build_obs(
            self.env,
            self.k_dicts[agent_id],
            self.k_categories,
            self.mujoco_globals,
            self.global_categories,
            vec_len=getattr(self, "obs_size", None),
        )

    def get_obs_size(self):
        """Return the (maximum) length of a single agent's observation."""
        if self.agent_obsk is None:
            return self.get_obs_agent(0).size
        return max(len(self.get_obs_agent(agent_id)) for agent_id in range(self.n_agents))

    def get_state(self, team=None):
        """Return the global state replicated once per agent, as a float32 array.

        Row ``a`` is the unwrapped env's ``_get_obs()``. When ``add_agent_id`` is
        set, the row is followed by a one-hot agent id of length ``n_agents``.
        ``team`` is currently ignored.
        """
        # TODO: May want global states for different teams (so cannot see what the other team is communicating e.g.)
        state = self.env._get_obs()
        if self.add_agent_id:
            one_hot_ids = np.eye(self.n_agents, dtype=np.float32)
            state_n = [np.concatenate([state, one_hot_ids[a]]) for a in range(self.n_agents)]
        else:
            state_n = [state] * self.n_agents
        return np.array(state_n).astype(np.float32)

    def get_state_size(self):
        """Return the length of one agent's global-state vector (one row of get_state())."""
        # get_state() stacks one row per agent, so len(get_state()) would be
        # n_agents. The state size is the width of a single row, mirroring
        # get_obs_size().
        return len(self.get_state()[0])

    def get_avail_actions(self):
        """Return an all-ones (n_agents, n_actions) mask; every action is always available."""
        return np.ones(shape=(self.n_agents, self.n_actions))

    def get_avail_agent_actions(self, agent_id):
        """Return an all-ones mask of length ``n_actions`` for ``agent_id``."""
        return np.ones(shape=(self.n_actions,))

    def get_total_actions(self):
        """Return the largest per-agent action dimension.

        CAREFUL: for these continuous-control envs this is an action-space
        dimension, not a count of discrete actions.
        """
        return self.n_actions

    def get_stats(self):
        return {}

    # TODO: Temp hack
    def get_agg_stats(self, stats):
        return {}

    def reset(self, **kwargs):
        """Reset the step counter and the env; return the initial observation dict."""
        self.steps = 0
        self.timelimit_env.reset()
        return {'agent_state': self.get_obs(), 'global_state': self.get_state()}

    def render(self, **kwargs):
        self.env.render(**kwargs)

    def close(self):
        """Close the wrapped env (gym wrappers forward close() to the inner env)."""
        self.wrapped_env.close()

    def seed(self, args):
        # No-op: seeding is not forwarded to the wrapped env.
        pass

    def get_env_info(self):
        """Return the environment description dict consumed by the training framework."""
        return {
            "state_shape": self.get_state_size(),
            "obs_shape": self.get_obs_size(),
            "n_actions": self.get_total_actions(),
            "n_agents": self.n_agents,
            "episode_limit": self.episode_limit,
            "action_spaces": self.action_space,
            "actions_dtype": np.float32,
            "normalise_actions": False,
        }