File size: 11,931 Bytes
05c9ac2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
from mlagents.torch_utils import torch
import warnings
from typing import Tuple, Optional, List
from mlagents.trainers.torch_entities.layers import (
    LinearEncoder,
    Initialization,
    linear_layer,
    LayerNorm,
)
from mlagents.trainers.torch_entities.model_serialization import exporting_to_onnx
from mlagents.trainers.exception import UnityTrainerException


def get_zero_entities_mask(entities: List[torch.Tensor]) -> List[torch.Tensor]:
    """
    Build one padding mask per entities tensor.

    For every tensor in `entities`, the squared values are summed along
    dimension 2; the mask holds 1.0 where that sum is below 0.01 (the entry is
    treated as padding) and 0.0 everywhere else. The attention layer uses these
    masks to ignore padded observations. No gradient is tracked here.

    :param entities: List of entity tensors, reduced along dimension 2.
    :return: List of float mask tensors, one per input tensor.
    """
    with torch.no_grad():
        if exporting_to_onnx.is_exporting():
            with warnings.catch_warnings():
                # PyTorch emits a TracerWarning because shape[n].item() bakes the
                # current shape into the trace. This is acceptable: the exported
                # model is always run with inputs of the same shape.
                warnings.simplefilter("ignore")
                # While exporting to ONNX the entities are transposed, because
                # ONNX only supports NCHW (channel first) inputs, which is also
                # what Barracuda expects.
                transposed_entities = []
                for obs in entities:
                    swapped = torch.transpose(obs, 2, 1)
                    transposed_entities.append(
                        swapped.reshape(-1, obs.shape[1].item(), obs.shape[2].item())
                    )
                entities = transposed_entities

        # An entry is masked only when its squared norm is (almost) zero.
        key_masks: List[torch.Tensor] = []
        for ent in entities:
            squared_norm = torch.sum(ent**2, dim=2)
            key_masks.append((squared_norm < 0.01).float())
    return key_masks


class MultiHeadAttention(torch.nn.Module):
    """
    Multi Head Attention module without trainable parameters. The regular Torch
    implementation is not used since Barracuda does not support some operators
    it relies on.
    """

    # Added to the logits of masked keys so that softmax gives them ~0 weight.
    NEG_INF = -1e6

    def __init__(self, embedding_size: int, num_heads: int):
        """
        Takes as input to the forward method 3 tensors:
        - query: of dimensions (batch_size, number_of_queries, embedding_size)
        - key: of dimensions (batch_size, number_of_keys, embedding_size)
        - value: of dimensions (batch_size, number_of_keys, embedding_size)
        The forward method will return 2 tensors:
        - The output: (batch_size, number_of_queries, embedding_size)
        - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
        :param embedding_size: The size of the embeddings that will be generated (should be
        divisible by num_heads; the stored embedding size is rounded down to
        head_size * num_heads)
        :param num_heads: The number of heads of the attention module
        """
        super().__init__()
        self.n_heads = num_heads
        self.head_size: int = embedding_size // self.n_heads
        self.embedding_size: int = self.head_size * self.n_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        n_q: int,
        n_k: int,
        key_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute scaled dot-product attention over `n_heads` heads.

        :param query: Query tensor, reshaped to (b, n_q, n_heads, head_size).
        :param key: Key tensor, reshaped to (b, n_k, n_heads, head_size). It is
            not modified by this method.
        :param value: Value tensor, reshaped to (b, n_k, n_heads, head_size).
        :param n_q: Number of queries.
        :param n_k: Number of keys.
        :param key_mask: Optional tensor reshaped to (b, 1, 1, n_k); keys whose
            mask value is 1 have their logits replaced by NEG_INF.
        :return: Tuple of the attended values (b, n_q, embedding_size) and the
            attention matrix (b, n_heads, n_q, n_k).
        """
        b = -1  # the batch size

        query = query.reshape(
            b, n_q, self.n_heads, self.head_size
        )  # (b, n_q, h, emb / h)
        key = key.reshape(b, n_k, self.n_heads, self.head_size)  # (b, n_k, h, emb / h)
        value = value.reshape(
            b, n_k, self.n_heads, self.head_size
        )  # (b, n_k, h, emb / h)

        query = query.permute([0, 2, 1, 3])  # (b, h, n_q, emb / h)
        # The next few lines are equivalent to : key.permute([0, 2, 3, 1])
        # This is a hack, ONNX will compress two permute operations and
        # Barracuda will not like seeing `permute([0,2,3,1])`
        key = key.permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
        # Out-of-place on purpose: at this point `key` is still a view of the
        # caller's tensor, so `-=` / `+=` would write into the caller's data and
        # fail for leaf tensors that require grad. The subtract/add pair still
        # sits between the two permutes.
        key = key - 1
        key = key + 1
        key = key.permute([0, 1, 3, 2])  # (b, h, emb / h, n_k)

        qk = torch.matmul(query, key)  # (b, h, n_q, n_k)

        # Note that the logits are scaled by sqrt(embedding_size), not sqrt(head_size).
        if key_mask is None:
            qk = qk / (self.embedding_size**0.5)
        else:
            key_mask = key_mask.reshape(b, 1, 1, n_k)
            qk = (1 - key_mask) * qk / (
                self.embedding_size**0.5
            ) + key_mask * self.NEG_INF

        att = torch.softmax(qk, dim=3)  # (b, h, n_q, n_k)

        value = value.permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
        value_attention = torch.matmul(att, value)  # (b, h, n_q, emb / h)

        value_attention = value_attention.permute([0, 2, 1, 3])  # (b, n_q, h, emb / h)
        value_attention = value_attention.reshape(
            b, n_q, self.embedding_size
        )  # (b, n_q, emb)

        return value_attention, att


class EntityEmbedding(torch.nn.Module):
    """
    A module used to embed entities before passing them to a self-attention block.
    Used in conjunction with ResidualSelfAttention to encode information about a self
    and additional entities. Can also concatenate self to entities for ego-centric self-
    attention. Inspired by architecture used in https://arxiv.org/pdf/1909.07528.pdf.
    """

    def __init__(
        self,
        entity_size: int,
        entity_num_max_elements: Optional[int],
        embedding_size: int,
    ):
        """
        Constructs an EntityEmbedding module.
        :param entity_size: Size of other entities.
        :param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
            Needs to be assigned in order for model to be exportable to ONNX and Barracuda.
        :param embedding_size: Embedding size for the entity encoder.
        """
        super().__init__()
        # No "self" observation is concatenated until add_self_embedding is called.
        self.self_size: int = 0
        self.entity_size: int = entity_size
        # -1 means "unrestricted": the number of entities is read from the input.
        self.entity_num_max_elements: int = -1
        if entity_num_max_elements is not None:
            self.entity_num_max_elements = entity_num_max_elements
        self.embedding_size = embedding_size
        self.self_ent_encoder = self._build_encoder()

    def _build_encoder(self) -> torch.nn.Module:
        """
        Creates the encoder applied to each (optionally self-augmented) entity, sized
        from the current self_size, entity_size and embedding_size.
        """
        # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
        return LinearEncoder(
            self.self_size + self.entity_size,
            1,
            self.embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=(0.125 / self.embedding_size) ** 0.5,
        )

    def add_self_embedding(self, size: int) -> None:
        """
        Enables ego-centric embedding: x_self will be concatenated to every entity in
        forward. The encoder is replaced by a freshly initialized one whose input size
        accounts for the extra `size` features.
        :param size: Size of the "self" observation.
        """
        self.self_size = size
        self.self_ent_encoder = self._build_encoder()

    def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
        """
        Encodes the entities, concatenating x_self to each of them when a self
        embedding has been added.
        :param x_self: The "self" observation, reshaped to (-1, 1, self_size). Only read
            when self_size > 0.
        :param entities: The entities tensor; outside of ONNX export its dimension 1 is
            the number of entities and its dimension 2 the entity features.
        :return: The output of the entity encoder.
        :raises UnityTrainerException: If exporting to ONNX without a set maximum number
            of elements.
        """
        num_entities = self.entity_num_max_elements
        if num_entities < 0:
            if exporting_to_onnx.is_exporting():
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't have a set max "
                    "number of elements."
                )
            num_entities = entities.shape[1]

        if exporting_to_onnx.is_exporting():
            # When exporting to ONNX, we want to transpose the entities. This is
            # because ONNX only support input in NCHW (channel first) format.
            # Barracuda also expect to get data in NCHW.
            entities = torch.transpose(entities, 2, 1).reshape(
                -1, num_entities, self.entity_size
            )

        if self.self_size > 0:
            expanded_self = x_self.reshape(-1, 1, self.self_size)
            # Repeat self once per entity so it can be concatenated feature-wise.
            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
            # Concatenate all observations with self
            entities = torch.cat([expanded_self, entities], dim=2)
        # Encode entities
        encoded_entities = self.self_ent_encoder(entities)
        return encoded_entities


class ResidualSelfAttention(torch.nn.Module):
    """
    Residual self attention inspired from https://arxiv.org/pdf/1909.07528.pdf. Can be used
    with an EntityEmbedding module, to apply multi head self attention to encode information
    about a "Self" and a list of relevant "Entities".
    """

    # Keeps the average pooling finite when every entity is masked.
    EPSILON = 1e-7

    def __init__(
        self,
        embedding_size: int,
        entity_num_max_elements: Optional[int] = None,
        num_heads: int = 4,
    ):
        """
        Constructs a ResidualSelfAttention module.
        :param embedding_size: Embedding size for attention mechanism and
            Q, K, V encoders.
        :param entity_num_max_elements: The maximum number of elements passed to the
            attention (used as both the number of queries and of keys). Pass None to
            not restrict the number of elements; however, this will make the module
            unexportable to ONNX/Barracuda.
        :param num_heads: Number of heads for Multi Head Self-Attention
        """
        super().__init__()
        self.max_num_ent: Optional[int] = entity_num_max_elements

        self.attention = MultiHeadAttention(
            num_heads=num_heads, embedding_size=embedding_size
        )

        def make_projection() -> torch.nn.Module:
            # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
            return linear_layer(
                embedding_size,
                embedding_size,
                kernel_init=Initialization.Normal,
                kernel_gain=(0.125 / embedding_size) ** 0.5,
            )

        # Creation order is kept (q, k, v, out) so that submodule registration and
        # random initialization happen in the same sequence as before.
        self.fc_q = make_projection()
        self.fc_k = make_projection()
        self.fc_v = make_projection()
        self.fc_out = make_projection()
        self.embedding_norm = LayerNorm()
        self.residual_norm = LayerNorm()

    def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:
        """
        Applies normalized residual self attention to `inp` and average-pools the
        result over the entities that are not masked.
        :param inp: The embedded entities; dimension 1 is the entity dimension.
        :param key_masks: List of mask tensors, concatenated along dimension 1. A value
            of 1 marks an entity to ignore, both as an attention key and in the pooling.
        :return: The masked average over dimension 1 of the attention output.
        :raises UnityTrainerException: If exporting to ONNX without a set maximum number
            of elements.
        """
        # Gather the maximum number of entities information
        mask = torch.cat(key_masks, dim=1)

        inp = self.embedding_norm(inp)
        # Feed to self attention
        query = self.fc_q(inp)  # (b, n_q, emb)
        key = self.fc_k(inp)  # (b, n_k, emb)
        value = self.fc_v(inp)  # (b, n_k, emb)

        # Only use max num if provided
        if self.max_num_ent is not None:
            num_ent = self.max_num_ent
        else:
            num_ent = inp.shape[1]
            if exporting_to_onnx.is_exporting():
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't have a set max "
                    "number of elements."
                )

        output, _ = self.attention(query, key, value, num_ent, num_ent, mask)
        # Residual
        output = self.fc_out(output) + inp
        output = self.residual_norm(output)
        # Average Pooling
        numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1)
        denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
        output = numerator / denominator
        return output