Spaces:
Running
Running
File size: 6,900 Bytes
be5548b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
from itertools import chain
from gym_minigrid.minigrid import *
from gym_minigrid.register import register
from gym_minigrid.envs import DanceWithOneNPC8x8Env, CoinThief8x8Env, TalkItOutPolite8x8Env, ShowMe8x8Env, \
DiverseExit8x8Env, Exiter8x8Env, Helper8x8Env
from gym_minigrid.envs import DanceWithOneNPCGrammar, CoinThiefGrammar, TalkItOutPoliteGrammar, DemonstrationGrammar, \
EasyTeachingGamesGrammar, ExiterGrammar
import time
from collections import deque
class SocialEnvMetaGrammar(object):
def __init__(self, grammar_list, env_list):
self.templates = []
self.things = []
self.original_template_idx = []
self.original_thing_idx = []
self.meta_template_idx_to_env_name = {}
self.meta_thing_idx_to_env_name = {}
self.template_idx, self.thing_idx = 0, 0
env_names = [e.__class__.__name__ for e in env_list]
for g, env_name in zip(grammar_list, env_names):
# add templates
self.templates += g.templates
# add things
self.things += g.things
# save original idx for both
self.original_template_idx += list(range(0, len(g.templates)))
self.original_thing_idx += list(range(0, len(g.things)))
# update meta_idx to env_names dictionaries
self.meta_template_idx_to_env_name.update(dict.fromkeys(list(range(self.template_idx,
self.template_idx + len(g.templates))),
env_name))
self.template_idx += len(g.templates)
self.meta_thing_idx_to_env_name.update(dict.fromkeys(list(range(self.thing_idx,
self.thing_idx + len(g.things))),
env_name))
self.thing_idx += len(g.things)
self.grammar_action_space = spaces.MultiDiscrete([len(self.templates), len(self.things)])
@classmethod
def construct_utterance(self, action):
return self.templates[int(action[0])] + " " + self.things[int(action[1])] + " "
@classmethod
def random_utterance(self):
return np.random.choice(self.templates) + " " + np.random.choice(self.things) + " "
def construct_original_action(self, action, current_env_name):
template_env_name = self.meta_template_idx_to_env_name[int(action[0])]
thing_env_name = self.meta_thing_idx_to_env_name[int(action[1])]
if template_env_name == current_env_name and thing_env_name == current_env_name:
original_action = [self.original_template_idx[int(action[0])], self.original_thing_idx[int(action[1])]]
else:
original_action = [np.nan, np.nan]
return original_action
class SocialEnv(gym.Env):
"""
Meta-Environment containing all other environment (multi-task learning)
"""
def __init__(
self,
size=8,
hidden_npc=False,
is_test_env=False
):
# Number of cells (width and height) in the agent view
self.agent_view_size = 7
# Number of object dimensions (i.e. number of channels in symbolic image)
self.nb_obj_dims = 4
# Observations are dictionaries containing an
# encoding of the grid and a textual 'mission' string
self.observation_space = spaces.Box(
low=0,
high=255,
shape=(self.agent_view_size, self.agent_view_size, self.nb_obj_dims),
dtype='uint8'
)
self.observation_space = spaces.Dict({
'image': self.observation_space
})
self.hidden_npc = hidden_npc # TODO: implement hidden npc
# TODO get max step from env list
self.env_list = [DanceWithOneNPC8x8Env, CoinThief8x8Env, TalkItOutPolite8x8Env, ShowMe8x8Env, DiverseExit8x8Env,
Exiter8x8Env]
self.all_npc_utterance_actions = sorted(list(set(chain(*[e.all_npc_utterance_actions for e in self.env_list]))))
self.grammar_list = [DanceWithOneNPCGrammar, CoinThiefGrammar, TalkItOutPoliteGrammar, DemonstrationGrammar,
EasyTeachingGamesGrammar, ExiterGrammar]
if is_test_env:
self.env_list[-1] = Helper8x8Env
# instanciate all envs
self.env_list = [env() for env in self.env_list]
self.current_env = None
self.metaGrammar = SocialEnvMetaGrammar(self.grammar_list, self.env_list)
# Actions are discrete integer values
self.action_space = spaces.MultiDiscrete([len(MiniGridEnv.Actions),
*self.metaGrammar.grammar_action_space.nvec])
self.actions = MiniGridEnv.Actions
self._window = None
def reset(self):
# select a new social environment at random, for each new episode
old_window = None
if self.current_env: # a previous env exists, save old window
old_window = self.current_env.window
# sample new environment
self.current_env = np.random.choice(self.env_list)
obs = self.current_env.reset()
# carry on window if this env is not the first
if old_window:
self.current_env.window = old_window
return obs
def seed(self, seed=1337):
# Seed the random number generator
for env in self.env_list:
env.seed(seed)
np.random.seed(seed)
return [seed]
def step(self, action):
assert (self.current_env)
if len(action) == 1: # agent cannot speak
utterance_action = [np.nan, np.nan]
else:
utterance_action = action[1:]
if len(action) >= 1 and not all(np.isnan(utterance_action)): # if agent speaks, contruct env-specific action
action[1:] = self.metaGrammar.construct_original_action(action[1:], self.current_env.__class__.__name__)
return self.current_env.step(action)
@property
def window(self):
return self.current_env.window
@window.setter
def window(self, value):
self.current_env.window = value
def render(self, *args, **kwargs):
assert self.current_env
return self.current_env.render(*args, **kwargs)
@property
def step_count(self):
return self.current_env.step_count
def get_mission(self):
return self.current_env.get_mission()
class SocialEnv8x8Env(SocialEnv):
def __init__(self, **kwargs):
super().__init__(size=8, **kwargs)
register(
id='MiniGrid-SocialEnv-5x5-v0',
entry_point='gym_minigrid.envs:SocialEnvEnv'
)
register(
id='MiniGrid-SocialEnv-8x8-v0',
entry_point='gym_minigrid.envs:SocialEnv8x8Env'
)
|