Spaces:
Sleeping
Sleeping
from itertools import chain | |
from gym_minigrid.minigrid import * | |
from gym_minigrid.register import register | |
from gym_minigrid.envs import DanceWithOneNPC8x8Env, CoinThief8x8Env, TalkItOutPolite8x8Env, ShowMe8x8Env, \ | |
DiverseExit8x8Env, Exiter8x8Env, Helper8x8Env | |
from gym_minigrid.envs import DanceWithOneNPCGrammar, CoinThiefGrammar, TalkItOutPoliteGrammar, DemonstrationGrammar, \ | |
EasyTeachingGamesGrammar, ExiterGrammar | |
import time | |
from collections import deque | |
class SocialEnvMetaGrammar(object): | |
def __init__(self, grammar_list, env_list): | |
self.templates = [] | |
self.things = [] | |
self.original_template_idx = [] | |
self.original_thing_idx = [] | |
self.meta_template_idx_to_env_name = {} | |
self.meta_thing_idx_to_env_name = {} | |
self.template_idx, self.thing_idx = 0, 0 | |
env_names = [e.__class__.__name__ for e in env_list] | |
for g, env_name in zip(grammar_list, env_names): | |
# add templates | |
self.templates += g.templates | |
# add things | |
self.things += g.things | |
# save original idx for both | |
self.original_template_idx += list(range(0, len(g.templates))) | |
self.original_thing_idx += list(range(0, len(g.things))) | |
# update meta_idx to env_names dictionaries | |
self.meta_template_idx_to_env_name.update(dict.fromkeys(list(range(self.template_idx, | |
self.template_idx + len(g.templates))), | |
env_name)) | |
self.template_idx += len(g.templates) | |
self.meta_thing_idx_to_env_name.update(dict.fromkeys(list(range(self.thing_idx, | |
self.thing_idx + len(g.things))), | |
env_name)) | |
self.thing_idx += len(g.things) | |
self.grammar_action_space = spaces.MultiDiscrete([len(self.templates), len(self.things)]) | |
def construct_utterance(self, action): | |
return self.templates[int(action[0])] + " " + self.things[int(action[1])] + " " | |
def random_utterance(self): | |
return np.random.choice(self.templates) + " " + np.random.choice(self.things) + " " | |
def construct_original_action(self, action, current_env_name): | |
template_env_name = self.meta_template_idx_to_env_name[int(action[0])] | |
thing_env_name = self.meta_thing_idx_to_env_name[int(action[1])] | |
if template_env_name == current_env_name and thing_env_name == current_env_name: | |
original_action = [self.original_template_idx[int(action[0])], self.original_thing_idx[int(action[1])]] | |
else: | |
original_action = [np.nan, np.nan] | |
return original_action | |
class SocialEnv(gym.Env): | |
""" | |
Meta-Environment containing all other environment (multi-task learning) | |
""" | |
def __init__( | |
self, | |
size=8, | |
hidden_npc=False, | |
is_test_env=False | |
): | |
# Number of cells (width and height) in the agent view | |
self.agent_view_size = 7 | |
# Number of object dimensions (i.e. number of channels in symbolic image) | |
self.nb_obj_dims = 4 | |
# Observations are dictionaries containing an | |
# encoding of the grid and a textual 'mission' string | |
self.observation_space = spaces.Box( | |
low=0, | |
high=255, | |
shape=(self.agent_view_size, self.agent_view_size, self.nb_obj_dims), | |
dtype='uint8' | |
) | |
self.observation_space = spaces.Dict({ | |
'image': self.observation_space | |
}) | |
self.hidden_npc = hidden_npc # TODO: implement hidden npc | |
# TODO get max step from env list | |
self.env_list = [DanceWithOneNPC8x8Env, CoinThief8x8Env, TalkItOutPolite8x8Env, ShowMe8x8Env, DiverseExit8x8Env, | |
Exiter8x8Env] | |
self.all_npc_utterance_actions = sorted(list(set(chain(*[e.all_npc_utterance_actions for e in self.env_list])))) | |
self.grammar_list = [DanceWithOneNPCGrammar, CoinThiefGrammar, TalkItOutPoliteGrammar, DemonstrationGrammar, | |
EasyTeachingGamesGrammar, ExiterGrammar] | |
if is_test_env: | |
self.env_list[-1] = Helper8x8Env | |
# instanciate all envs | |
self.env_list = [env() for env in self.env_list] | |
self.current_env = None | |
self.metaGrammar = SocialEnvMetaGrammar(self.grammar_list, self.env_list) | |
# Actions are discrete integer values | |
self.action_space = spaces.MultiDiscrete([len(MiniGridEnv.Actions), | |
*self.metaGrammar.grammar_action_space.nvec]) | |
self.actions = MiniGridEnv.Actions | |
self._window = None | |
def reset(self): | |
# select a new social environment at random, for each new episode | |
old_window = None | |
if self.current_env: # a previous env exists, save old window | |
old_window = self.current_env.window | |
# sample new environment | |
self.current_env = np.random.choice(self.env_list) | |
obs = self.current_env.reset() | |
# carry on window if this env is not the first | |
if old_window: | |
self.current_env.window = old_window | |
return obs | |
def seed(self, seed=1337): | |
# Seed the random number generator | |
for env in self.env_list: | |
env.seed(seed) | |
np.random.seed(seed) | |
return [seed] | |
def step(self, action): | |
assert (self.current_env) | |
if len(action) == 1: # agent cannot speak | |
utterance_action = [np.nan, np.nan] | |
else: | |
utterance_action = action[1:] | |
if len(action) >= 1 and not all(np.isnan(utterance_action)): # if agent speaks, contruct env-specific action | |
action[1:] = self.metaGrammar.construct_original_action(action[1:], self.current_env.__class__.__name__) | |
return self.current_env.step(action) | |
def window(self): | |
return self.current_env.window | |
def window(self, value): | |
self.current_env.window = value | |
def render(self, *args, **kwargs): | |
assert self.current_env | |
return self.current_env.render(*args, **kwargs) | |
def step_count(self): | |
return self.current_env.step_count | |
def get_mission(self): | |
return self.current_env.get_mission() | |
class SocialEnv8x8Env(SocialEnv): | |
def __init__(self, **kwargs): | |
super().__init__(size=8, **kwargs) | |
register( | |
id='MiniGrid-SocialEnv-5x5-v0', | |
entry_point='gym_minigrid.envs:SocialEnvEnv' | |
) | |
register( | |
id='MiniGrid-SocialEnv-8x8-v0', | |
entry_point='gym_minigrid.envs:SocialEnv8x8Env' | |
) | |