grg's picture
Cleaned old git history
be5548b
from itertools import chain
from gym_minigrid.minigrid import *
from gym_minigrid.register import register
from gym_minigrid.envs import DanceWithOneNPC8x8Env, CoinThief8x8Env, TalkItOutPolite8x8Env, ShowMe8x8Env, \
DiverseExit8x8Env, Exiter8x8Env, Helper8x8Env
from gym_minigrid.envs import DanceWithOneNPCGrammar, CoinThiefGrammar, TalkItOutPoliteGrammar, DemonstrationGrammar, \
EasyTeachingGamesGrammar, ExiterGrammar
import time
from collections import deque
class SocialEnvMetaGrammar(object):
def __init__(self, grammar_list, env_list):
self.templates = []
self.things = []
self.original_template_idx = []
self.original_thing_idx = []
self.meta_template_idx_to_env_name = {}
self.meta_thing_idx_to_env_name = {}
self.template_idx, self.thing_idx = 0, 0
env_names = [e.__class__.__name__ for e in env_list]
for g, env_name in zip(grammar_list, env_names):
# add templates
self.templates += g.templates
# add things
self.things += g.things
# save original idx for both
self.original_template_idx += list(range(0, len(g.templates)))
self.original_thing_idx += list(range(0, len(g.things)))
# update meta_idx to env_names dictionaries
self.meta_template_idx_to_env_name.update(dict.fromkeys(list(range(self.template_idx,
self.template_idx + len(g.templates))),
env_name))
self.template_idx += len(g.templates)
self.meta_thing_idx_to_env_name.update(dict.fromkeys(list(range(self.thing_idx,
self.thing_idx + len(g.things))),
env_name))
self.thing_idx += len(g.things)
self.grammar_action_space = spaces.MultiDiscrete([len(self.templates), len(self.things)])
@classmethod
def construct_utterance(self, action):
return self.templates[int(action[0])] + " " + self.things[int(action[1])] + " "
@classmethod
def random_utterance(self):
return np.random.choice(self.templates) + " " + np.random.choice(self.things) + " "
def construct_original_action(self, action, current_env_name):
template_env_name = self.meta_template_idx_to_env_name[int(action[0])]
thing_env_name = self.meta_thing_idx_to_env_name[int(action[1])]
if template_env_name == current_env_name and thing_env_name == current_env_name:
original_action = [self.original_template_idx[int(action[0])], self.original_thing_idx[int(action[1])]]
else:
original_action = [np.nan, np.nan]
return original_action
class SocialEnv(gym.Env):
"""
Meta-Environment containing all other environment (multi-task learning)
"""
def __init__(
self,
size=8,
hidden_npc=False,
is_test_env=False
):
# Number of cells (width and height) in the agent view
self.agent_view_size = 7
# Number of object dimensions (i.e. number of channels in symbolic image)
self.nb_obj_dims = 4
# Observations are dictionaries containing an
# encoding of the grid and a textual 'mission' string
self.observation_space = spaces.Box(
low=0,
high=255,
shape=(self.agent_view_size, self.agent_view_size, self.nb_obj_dims),
dtype='uint8'
)
self.observation_space = spaces.Dict({
'image': self.observation_space
})
self.hidden_npc = hidden_npc # TODO: implement hidden npc
# TODO get max step from env list
self.env_list = [DanceWithOneNPC8x8Env, CoinThief8x8Env, TalkItOutPolite8x8Env, ShowMe8x8Env, DiverseExit8x8Env,
Exiter8x8Env]
self.all_npc_utterance_actions = sorted(list(set(chain(*[e.all_npc_utterance_actions for e in self.env_list]))))
self.grammar_list = [DanceWithOneNPCGrammar, CoinThiefGrammar, TalkItOutPoliteGrammar, DemonstrationGrammar,
EasyTeachingGamesGrammar, ExiterGrammar]
if is_test_env:
self.env_list[-1] = Helper8x8Env
# instanciate all envs
self.env_list = [env() for env in self.env_list]
self.current_env = None
self.metaGrammar = SocialEnvMetaGrammar(self.grammar_list, self.env_list)
# Actions are discrete integer values
self.action_space = spaces.MultiDiscrete([len(MiniGridEnv.Actions),
*self.metaGrammar.grammar_action_space.nvec])
self.actions = MiniGridEnv.Actions
self._window = None
def reset(self):
# select a new social environment at random, for each new episode
old_window = None
if self.current_env: # a previous env exists, save old window
old_window = self.current_env.window
# sample new environment
self.current_env = np.random.choice(self.env_list)
obs = self.current_env.reset()
# carry on window if this env is not the first
if old_window:
self.current_env.window = old_window
return obs
def seed(self, seed=1337):
# Seed the random number generator
for env in self.env_list:
env.seed(seed)
np.random.seed(seed)
return [seed]
def step(self, action):
assert (self.current_env)
if len(action) == 1: # agent cannot speak
utterance_action = [np.nan, np.nan]
else:
utterance_action = action[1:]
if len(action) >= 1 and not all(np.isnan(utterance_action)): # if agent speaks, contruct env-specific action
action[1:] = self.metaGrammar.construct_original_action(action[1:], self.current_env.__class__.__name__)
return self.current_env.step(action)
@property
def window(self):
return self.current_env.window
@window.setter
def window(self, value):
self.current_env.window = value
def render(self, *args, **kwargs):
assert self.current_env
return self.current_env.render(*args, **kwargs)
@property
def step_count(self):
return self.current_env.step_count
def get_mission(self):
return self.current_env.get_mission()
class SocialEnv8x8Env(SocialEnv):
def __init__(self, **kwargs):
super().__init__(size=8, **kwargs)
register(
id='MiniGrid-SocialEnv-5x5-v0',
entry_point='gym_minigrid.envs:SocialEnvEnv'
)
register(
id='MiniGrid-SocialEnv-8x8-v0',
entry_point='gym_minigrid.envs:SocialEnv8x8Env'
)