File size: 6,900 Bytes
be5548b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
from itertools import chain
from gym_minigrid.minigrid import *
from gym_minigrid.register import register
 
from gym_minigrid.envs import DanceWithOneNPC8x8Env, CoinThief8x8Env, TalkItOutPolite8x8Env, ShowMe8x8Env, \
    DiverseExit8x8Env, Exiter8x8Env, Helper8x8Env
from gym_minigrid.envs import DanceWithOneNPCGrammar, CoinThiefGrammar, TalkItOutPoliteGrammar, DemonstrationGrammar, \
    EasyTeachingGamesGrammar, ExiterGrammar
import time
from collections import deque


class SocialEnvMetaGrammar(object):
    """
    Merge the grammars of several social environments into one "meta" grammar.

    The ``templates`` and ``things`` of every grammar are concatenated, so an utterance
    action ``[template_idx, thing_idx]`` indexes into the merged lists. For every merged
    (meta) index the class records the index that word had in its own grammar and the
    class name of the environment the grammar is paired with. This lets a meta action be
    translated back into a single environment's action space
    (see ``construct_original_action``).
    """

    def __init__(self, grammar_list, env_list):
        """
        :param grammar_list: grammar objects exposing ``templates`` and ``things``
            sequences; ``grammar_list[i]`` is paired with ``env_list[i]``.
        :param env_list: environments paired with ``grammar_list``; only the class
            name of each entry is used.
        :raises ValueError: if the two lists have different lengths (``zip`` would
            otherwise silently drop the unmatched tail and mislabel nothing/everything
            after it).
        """
        grammar_list = list(grammar_list)
        env_names = [e.__class__.__name__ for e in env_list]
        if len(grammar_list) != len(env_names):
            raise ValueError(
                "grammar_list and env_list must have the same length "
                "(got {} and {})".format(len(grammar_list), len(env_names)))

        # Merged vocabularies, in grammar_list order.
        self.templates = []
        self.things = []
        # original_*_idx[meta_idx] -> index of that word inside its own grammar.
        self.original_template_idx = []
        self.original_thing_idx = []

        # meta_*_idx_to_env_name[meta_idx] -> class name of the env owning that word.
        self.meta_template_idx_to_env_name = {}
        self.meta_thing_idx_to_env_name = {}
        # Running offsets into the merged lists; after the loop they equal
        # len(self.templates) and len(self.things).
        self.template_idx, self.thing_idx = 0, 0

        for g, env_name in zip(grammar_list, env_names):
            n_templates, n_things = len(g.templates), len(g.things)

            # add templates / things
            self.templates += g.templates
            self.things += g.things

            # remember each word's index within its own grammar
            self.original_template_idx += list(range(n_templates))
            self.original_thing_idx += list(range(n_things))

            # map the freshly appended meta indices to their owning env
            self.meta_template_idx_to_env_name.update(
                dict.fromkeys(range(self.template_idx, self.template_idx + n_templates), env_name))
            self.meta_thing_idx_to_env_name.update(
                dict.fromkeys(range(self.thing_idx, self.thing_idx + n_things), env_name))

            self.template_idx += n_templates
            self.thing_idx += n_things

        self.grammar_action_space = spaces.MultiDiscrete([len(self.templates), len(self.things)])

    def construct_utterance(self, action):
        """
        Build the utterance string for a meta-grammar action.

        This is an instance method because ``templates``/``things`` are instance
        attributes created in ``__init__``.

        :param action: ``[template_idx, thing_idx]`` in the merged vocabulary; entries
            may be floats and are converted with ``int``.
        :return: ``"<template> <thing> "`` (with a trailing space).
        """
        return self.templates[int(action[0])] + " " + self.things[int(action[1])] + " "

    def random_utterance(self):
        """
        Return ``"<template> <thing> "`` with both parts drawn uniformly at random
        (via ``np.random.choice``) from the merged vocabulary.
        """
        return np.random.choice(self.templates) + " " + np.random.choice(self.things) + " "

    def construct_original_action(self, action, current_env_name):
        """
        Translate a meta-grammar action into the grammar of ``current_env_name``.

        :param action: ``[template_idx, thing_idx]`` in the merged vocabulary.
        :param current_env_name: class name of the environment that will receive it.
        :return: ``[template_idx, thing_idx]`` in that environment's own grammar, or
            ``[np.nan, np.nan]`` (the same marker ``SocialEnv.step`` uses for "agent does
            not speak") when the template or the thing belongs to another environment.
        """
        template_idx, thing_idx = int(action[0]), int(action[1])
        template_env_name = self.meta_template_idx_to_env_name[template_idx]
        thing_env_name = self.meta_thing_idx_to_env_name[thing_idx]

        if template_env_name == current_env_name and thing_env_name == current_env_name:
            return [self.original_template_idx[template_idx], self.original_thing_idx[thing_idx]]
        return [np.nan, np.nan]


class SocialEnv(gym.Env):
    """
    Meta-Environment containing all other environments (multi-task learning).

    On every ``reset`` one of the wrapped environments is sampled uniformly at random
    and becomes ``current_env``. ``step``, ``render``, ``window``, ``step_count`` and
    ``get_mission`` are forwarded to it. Utterance actions are expressed in the merged
    vocabulary of ``SocialEnvMetaGrammar`` and translated in ``step`` to the sampled
    environment's own grammar.
    """

    def __init__(
            self,
            size=8,
            hidden_npc=False,
            is_test_env=False

    ):
        """
        :param size: currently unused; the wrapped environments are the ``*8x8Env``
            classes, each constructed with its own defaults.
        :param hidden_npc: stored on the instance but not used yet (see TODO).
        :param is_test_env: if True, the last wrapped environment (``Exiter8x8Env``)
            is replaced by ``Helper8x8Env``.
        """
        # Number of cells (width and height) in the agent view
        self.agent_view_size = 7

        # Number of object dimensions (i.e. number of channels in symbolic image)
        self.nb_obj_dims = 4

        # Observations are dictionaries containing an encoding of the grid
        image_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.agent_view_size, self.agent_view_size, self.nb_obj_dims),
            dtype='uint8'
        )
        self.observation_space = spaces.Dict({
            'image': image_space
        })

        self.hidden_npc = hidden_npc  # TODO: implement hidden npc

        # TODO get max step from env list

        self.env_list = [DanceWithOneNPC8x8Env, CoinThief8x8Env, TalkItOutPolite8x8Env, ShowMe8x8Env,
                         DiverseExit8x8Env, Exiter8x8Env]
        self.all_npc_utterance_actions = sorted(
            set(chain.from_iterable(e.all_npc_utterance_actions for e in self.env_list)))
        self.grammar_list = [DanceWithOneNPCGrammar, CoinThiefGrammar, TalkItOutPoliteGrammar, DemonstrationGrammar,
                             EasyTeachingGamesGrammar, ExiterGrammar]

        # NOTE(review): the NPC-utterance vocabulary and grammar list above are built
        # from the training env list *before* this swap, so train and test instances
        # share identical vocabularies/action spaces (Helper8x8Env is paired with
        # ExiterGrammar). Presumably intentional -- confirm.
        if is_test_env:
            self.env_list[-1] = Helper8x8Env

        # instantiate all envs
        self.env_list = [env() for env in self.env_list]

        self.current_env = None

        self.metaGrammar = SocialEnvMetaGrammar(self.grammar_list, self.env_list)

        # Actions: [primitive action, meta template index, meta thing index]
        self.action_space = spaces.MultiDiscrete([len(MiniGridEnv.Actions),
                                                  *self.metaGrammar.grammar_action_space.nvec])
        self.actions = MiniGridEnv.Actions

        self._window = None

    def reset(self):
        """
        Sample a new wrapped environment (one per episode), reset it and return its
        initial observation. The render window of the previous environment, if any,
        is handed over so rendering keeps using the same window.
        """
        old_window = None
        if self.current_env is not None:  # a previous env exists, save old window
            old_window = self.current_env.window

        # sample new environment
        self.current_env = np.random.choice(self.env_list)
        obs = self.current_env.reset()

        # carry on window if this env is not the first
        if old_window is not None:
            self.current_env.window = old_window
        return obs

    def seed(self, seed=1337):
        """
        Seed every wrapped environment with ``seed`` and reseed numpy's global RNG
        (which ``reset`` uses to sample the environment). Returns ``[seed]``.
        """
        for env in self.env_list:
            env.seed(seed)
        np.random.seed(seed)
        return [seed]

    def step(self, action):
        """
        Forward one step to the current environment.

        :param action: ``[primitive_action]`` for an agent that cannot speak, or
            ``[primitive_action, template_idx, thing_idx]`` where the last two index
            the meta grammar (``np.nan`` for both means "do not speak").
        :return: whatever ``self.current_env.step`` returns.

        The caller's ``action`` object is never modified.
        """
        assert self.current_env, "reset() must be called before step()"

        if len(action) > 1:
            utterance_action = action[1:]
            if not all(np.isnan(utterance_action)):  # agent speaks: map to env-specific indices
                original_action = self.metaGrammar.construct_original_action(
                    utterance_action, self.current_env.__class__.__name__)
                # Build a new list rather than writing into ``action[1:]``: that would
                # mutate the caller's object, fail for tuples, and fail for integer
                # numpy arrays whenever the fallback value is NaN.
                action = [action[0], *original_action]

        return self.current_env.step(action)

    @property
    def window(self):
        """Render window of the current environment."""
        return self.current_env.window

    @window.setter
    def window(self, value):
        self.current_env.window = value

    def render(self, *args, **kwargs):
        """Render the current environment; all arguments are forwarded."""
        assert self.current_env
        return self.current_env.render(*args, **kwargs)

    @property
    def step_count(self):
        """Step counter of the current environment."""
        return self.current_env.step_count

    def get_mission(self):
        """Mission of the current environment."""
        return self.current_env.get_mission()


class SocialEnv8x8Env(SocialEnv):
    """``SocialEnv`` with ``size`` fixed to 8; other keyword arguments pass through."""

    def __init__(self, **kwargs):
        super(SocialEnv8x8Env, self).__init__(size=8, **kwargs)


# NOTE(review): the "5x5" id is historical. SocialEnv ignores ``size`` and always wraps
# the *8x8Env variants, so both ids build the same kind of meta-environment.
register(
    id='MiniGrid-SocialEnv-5x5-v0',
    # This module defines ``SocialEnv``, not ``SocialEnvEnv`` (presumably a copy-paste
    # slip); point at the class that exists so gym.make can resolve it.
    entry_point='gym_minigrid.envs:SocialEnv'
)

register(
    id='MiniGrid-SocialEnv-8x8-v0',
    entry_point='gym_minigrid.envs:SocialEnv8x8Env'
)