from collections import namedtuple

import numpy as np


def convert(dictionary):
    """ Converts a dict into a namedtuple so its keys can be read as attributes """
    return namedtuple('GenericDict', dictionary.keys())(**dictionary)
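
# Illustrative usage (added for clarity, not in the original file):
#   convert({"n_agents": 2, "seed": 7}).n_agents  ->  2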


class MultiAgentEnv(object):

    def __init__(self, batch_size=None, **kwargs):
        # Unpack arguments from sacred
        args = kwargs["env_args"]
        if isinstance(args, dict):
            args = convert(args)
        self.args = args

        if getattr(args, "seed", None) is not None:
            # Stored as _seed so the attribute does not shadow the seed() method below
            self._seed = args.seed
            self.rs = np.random.RandomState(self._seed)  # initialise numpy random state

    def step(self, actions):
        """ Returns reward, terminated, info """
        raise NotImplementedError

    def get_obs(self):
        """ Returns all agent observations in a list """
        raise NotImplementedError

    def get_obs_agent(self, agent_id):
        """ Returns observation for agent_id """
        raise NotImplementedError

    def get_obs_size(self):
        """ Returns the shape of the observation """
        raise NotImplementedError

    def get_state(self):
        """ Returns the global state """
        raise NotImplementedError

    def get_state_size(self):
        """ Returns the shape of the state """
        raise NotImplementedError

    def get_avail_actions(self):
        """ Returns the available actions of all agents in a list """
        raise NotImplementedError

    def get_avail_agent_actions(self, agent_id):
        """ Returns the available actions for agent_id """
        raise NotImplementedError

    def get_total_actions(self):
        """ Returns the total number of actions an agent could ever take """
        # TODO: This is only suitable for a discrete 1-dimensional action space for each agent
        raise NotImplementedError

    def get_stats(self):
        """ Returns environment-specific statistics """
        raise NotImplementedError

    # TODO: Temp hack
    def get_agg_stats(self, stats):
        return {}

    def reset(self):
        """ Returns initial observations and states """
        raise NotImplementedError

    def render(self):
        """ Renders the environment """
        raise NotImplementedError

    def close(self):
        """ Closes the environment """
        raise NotImplementedError

    def seed(self, seed):
        """ Sets the random seed for the environment """
        raise NotImplementedError

    def get_env_info(self):
        """ Returns a summary dict for the learner; n_agents and episode_limit
        are expected to be set by the concrete subclass """
        env_info = {
            "state_shape": self.get_state_size(),
            "obs_shape": self.get_obs_size(),
            "n_actions": self.get_total_actions(),
            "n_agents": self.n_agents,
            "episode_limit": self.episode_limit,
        }
        return env_info
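

# ---------------------------------------------------------------------------
# Illustrative sketch (added for clarity, not part of the original interface):
# a minimal concrete subclass, assuming a toy setting in which each of two
# agents observes the current timestep and every action is always available.
# All names and constants below (SimpleEnv, n_agents=2, n_actions=3, ...) are
# hypothetical.
# ---------------------------------------------------------------------------
class SimpleEnv(MultiAgentEnv):

    def __init__(self, batch_size=None, **kwargs):
        super(SimpleEnv, self).__init__(batch_size=batch_size, **kwargs)
        self.n_agents = 2
        self.n_actions = 3
        self.episode_limit = 10
        self._t = 0  # current timestep

    def reset(self):
        self._t = 0
        return self.get_obs(), self.get_state()

    def step(self, actions):
        self._t += 1
        reward = float(sum(actions))  # toy reward: sum of the chosen action indices
        terminated = self._t >= self.episode_limit
        return reward, terminated, {}

    def get_obs(self):
        return [self.get_obs_agent(i) for i in range(self.n_agents)]

    def get_obs_agent(self, agent_id):
        return np.array([self._t], dtype=np.float32)

    def get_obs_size(self):
        return 1

    def get_state(self):
        return np.array([self._t], dtype=np.float32)

    def get_state_size(self):
        return 1

    def get_avail_actions(self):
        return [self.get_avail_agent_actions(i) for i in range(self.n_agents)]

    def get_avail_agent_actions(self, agent_id):
        return [1] * self.n_actions  # every action is always available

    def get_total_actions(self):
        return self.n_actions


# Example usage:
#   env = SimpleEnv(env_args={"seed": 0})
#   env.get_env_info()
#   # -> {'state_shape': 1, 'obs_shape': 1, 'n_actions': 3,
#   #     'n_agents': 2, 'episode_limit': 10}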