AnnaMats's picture
Second Push
05c9ac2
from mlagents.torch_utils import torch
import warnings
from typing import Tuple, Optional, List
from mlagents.trainers.torch_entities.layers import (
LinearEncoder,
Initialization,
linear_layer,
LayerNorm,
)
from mlagents.trainers.torch_entities.model_serialization import exporting_to_onnx
from mlagents.trainers.exception import UnityTrainerException
def get_zero_entities_mask(entities: List[torch.Tensor]) -> List[torch.Tensor]:
    """
    Builds one padding mask per entity tensor, for use by the Attention layer.
    An entry of a mask is 1.0 when the sum of squares of the matching entity
    (taken over dimension 2) is below 0.01, i.e. the entity is treated as an
    all-zero padding observation, and 0.0 otherwise.
    :param entities: List of entity observation Tensors (indexed on 3 dimensions).
    :return: List of float mask Tensors, one per input Tensor, in the same order.
    """
    with torch.no_grad():
        if exporting_to_onnx.is_exporting():
            with warnings.catch_warnings():
                # PyTorch emits a TracerWarning because reading shape[n].item()
                # bakes the current shape into the trace. That is acceptable
                # here: the exported model is always run with inputs of the
                # same shape, so the warning is silenced.
                warnings.simplefilter("ignore")
                # ONNX only supports channel-first (NCHW) inputs, and Barracuda
                # expects the same layout, so the observations are transposed
                # and then reshaped back to their original dimensions.
                channel_first: List[torch.Tensor] = []
                for obs in entities:
                    swapped = torch.transpose(obs, 2, 1)
                    channel_first.append(
                        swapped.reshape(-1, obs.shape[1].item(), obs.shape[2].item())
                    )
                entities = channel_first
        # One mask per entities tensor; an entity is masked only if it is
        # (numerically) all zeros.
        key_masks: List[torch.Tensor] = []
        for ent in entities:
            squared_norm = torch.sum(ent**2, dim=2)
            key_masks.append((squared_norm < 0.01).float())
    return key_masks
class MultiHeadAttention(torch.nn.Module):
    """
    Multi Head Attention module. The regular Torch implementation is not used
    because Barracuda does not support some of the operators it relies on.
    The forward method takes 3 tensors:
     - query: of dimensions (batch_size, number_of_queries, embedding_size)
     - key: of dimensions (batch_size, number_of_keys, embedding_size)
     - value: of dimensions (batch_size, number_of_keys, embedding_size)
    and returns 2 tensors:
     - The output: (batch_size, number_of_queries, embedding_size)
     - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
    """

    # Score assigned to masked keys before the softmax.
    NEG_INF = -1e6

    def __init__(self, embedding_size: int, num_heads: int):
        """
        Constructs a MultiHeadAttention module.
        :param embedding_size: The size of the embeddings that will be generated (should be
        dividable by the num_heads). The stored embedding_size is
        (embedding_size // num_heads) * num_heads.
        :param num_heads: The number of heads of the attention module
        """
        super().__init__()
        self.n_heads = num_heads
        self.head_size: int = embedding_size // self.n_heads
        self.embedding_size: int = self.head_size * self.n_heads

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        n_q: int,
        n_k: int,
        key_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Applies multi head attention.
        :param query: Query tensor, reshaped to (b, n_q, h, emb / h).
        :param key: Key tensor, reshaped to (b, n_k, h, emb / h).
        :param value: Value tensor, reshaped to (b, n_k, h, emb / h).
        :param n_q: Number of queries.
        :param n_k: Number of keys.
        :param key_mask: Optional tensor reshaped to (b, 1, 1, n_k); keys whose mask
        entry is 1 receive a score of NEG_INF before the softmax.
        :return: Tuple of the attended values (b, n_q, emb) and the attention
        matrix (b, h, n_q, n_k).
        """
        batch = -1  # the batch size is inferred by reshape
        heads, per_head = self.n_heads, self.head_size

        # Split the embedding dimension into heads.
        query = query.reshape(batch, n_q, heads, per_head).permute(
            [0, 2, 1, 3]
        )  # (b, h, n_q, emb / h)
        key = key.reshape(batch, n_k, heads, per_head)  # (b, n_k, h, emb / h)
        value = value.reshape(batch, n_k, heads, per_head)  # (b, n_k, h, emb / h)

        # The next few lines are equivalent to : key.permute([0, 2, 3, 1])
        # This is a hack: ONNX would compress two consecutive permute operations
        # and Barracuda will not like seeing `permute([0,2,3,1])`, so a no-op
        # in-place subtraction/addition is inserted between the two permutes.
        key = key.permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
        key -= 1
        key += 1
        key = key.permute([0, 1, 3, 2])  # (b, h, emb / h, n_k)

        scores = torch.matmul(query, key)  # (b, h, n_q, n_k)
        scale = self.embedding_size**0.5
        if key_mask is not None:
            key_mask = key_mask.reshape(batch, 1, 1, n_k)
            # Zero out the masked scores, then push them to NEG_INF so that the
            # softmax gives them (almost) no weight.
            scores = (1 - key_mask) * scores / scale + key_mask * self.NEG_INF
        else:
            scores = scores / scale

        att = torch.softmax(scores, dim=3)  # (b, h, n_q, n_k)

        value = value.permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
        attended = torch.matmul(att, value)  # (b, h, n_q, emb / h)
        # Merge the heads back into a single embedding dimension.
        attended = attended.permute([0, 2, 1, 3])  # (b, n_q, h, emb / h)
        attended = attended.reshape(batch, n_q, self.embedding_size)  # (b, n_q, emb)
        return attended, att
class EntityEmbedding(torch.nn.Module):
    """
    A module used to embed entities before passing them to a self-attention block.
    Used in conjunction with ResidualSelfAttention to encode information about a self
    and additional entities. Can also concatenate self to entities for ego-centric self-
    attention. Inspired by architecture used in https://arxiv.org/pdf/1909.07528.pdf.
    """

    def __init__(
        self,
        entity_size: int,
        entity_num_max_elements: Optional[int],
        embedding_size: int,
    ):
        """
        Constructs an EntityEmbedding module.
        :param entity_size: Size of other entities.
        :param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
        Needs to be assigned in order for model to be exportable to ONNX and Barracuda.
        :param embedding_size: Embedding size for the entity encoder.
        """
        super().__init__()
        # No "self" observation is concatenated until add_self_embedding is called.
        self.self_size: int = 0
        self.entity_size: int = entity_size
        # -1 is the sentinel for "no maximum number of elements".
        self.entity_num_max_elements: int = (
            -1 if entity_num_max_elements is None else entity_num_max_elements
        )
        self.embedding_size = embedding_size
        self.self_ent_encoder = self._build_encoder(self.entity_size)

    def _build_encoder(self, input_size: int):
        """
        Creates the LinearEncoder that maps inputs of size input_size to embedding_size.
        """
        # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
        gain = (0.125 / self.embedding_size) ** 0.5
        return LinearEncoder(
            input_size,
            1,
            self.embedding_size,
            kernel_init=Initialization.Normal,
            kernel_gain=gain,
        )

    def add_self_embedding(self, size: int) -> None:
        """
        Enables concatenation of a "self" observation to every entity and replaces
        the encoder with a new one sized for the concatenated input.
        :param size: Size of the "self" observation.
        """
        self.self_size = size
        self.self_ent_encoder = self._build_encoder(self.self_size + self.entity_size)

    def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
        """
        Encodes the entities, concatenating x_self to each of them first when a
        self embedding was added.
        :param x_self: The "self" observation; only read when self_size > 0.
        :param entities: The entities Tensor; its dimension 1 is used as the number
        of entities when no maximum was set.
        :return: The output of the entity encoder.
        """
        num_entities = self.entity_num_max_elements
        if num_entities < 0:
            # Without a fixed maximum the count comes from the input itself,
            # which is not possible when exporting.
            if exporting_to_onnx.is_exporting():
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't have a set max "
                    "number of elements."
                )
            num_entities = entities.shape[1]

        if exporting_to_onnx.is_exporting():
            # When exporting to ONNX, we want to transpose the entities. This is
            # because ONNX only support input in NCHW (channel first) format.
            # Barracuda also expect to get data in NCHW.
            swapped = torch.transpose(entities, 2, 1)
            entities = swapped.reshape(-1, num_entities, self.entity_size)

        if self.self_size > 0:
            # Repeat the self observation once per entity, then concatenate it
            # with every entity along the feature dimension.
            single_self = x_self.reshape(-1, 1, self.self_size)
            tiled_self = torch.cat([single_self] * num_entities, dim=1)
            entities = torch.cat([tiled_self, entities], dim=2)

        # Encode entities
        return self.self_ent_encoder(entities)
class ResidualSelfAttention(torch.nn.Module):
    """
    Residual self attention inspired from https://arxiv.org/pdf/1909.07528.pdf. Can be used
    with an EntityEmbedding module, to apply multi head self attention to encode information
    about a "Self" and a list of relevant "Entities".
    """

    # Added to the pooling denominator so it is never exactly zero.
    EPSILON = 1e-7

    def __init__(
        self,
        embedding_size: int,
        entity_num_max_elements: Optional[int] = None,
        num_heads: int = 4,
    ):
        """
        Constructs a ResidualSelfAttention module.
        :param embedding_size: Embedding size for attention mechanism and
        Q, K, V encoders.
        :param entity_num_max_elements: The maximum number of elements in the entity
        sequence. Pass None to not restrict the number of elements; however, this
        will make the module unexportable to ONNX/Barracuda.
        :param num_heads: Number of heads for Multi Head Self-Attention
        """
        super().__init__()
        self.max_num_ent: Optional[int] = entity_num_max_elements

        self.attention = MultiHeadAttention(
            num_heads=num_heads, embedding_size=embedding_size
        )

        # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
        projection_gain = (0.125 / embedding_size) ** 0.5

        def make_projection():
            # Every projection (Q, K, V and output) maps embedding_size to
            # embedding_size with the same initialization.
            return linear_layer(
                embedding_size,
                embedding_size,
                kernel_init=Initialization.Normal,
                kernel_gain=projection_gain,
            )

        # Created in this order so that parameter registration and random
        # initialization follow the sequence q, k, v, out.
        self.fc_q = make_projection()
        self.fc_k = make_projection()
        self.fc_v = make_projection()
        self.fc_out = make_projection()
        self.embedding_norm = LayerNorm()
        self.residual_norm = LayerNorm()

    def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:
        """
        Applies normalized multi head self attention with a residual connection
        to inp, then averages the result over the entities that are not masked.
        :param inp: The embedded entities; dimension 1 is used as the number of
        entities when no maximum was set.
        :param key_masks: List of mask Tensors, concatenated along dimension 1;
        an entry of 1 marks an entity that is ignored.
        :return: The pooled output Tensor (the entity dimension is summed out).
        :raises UnityTrainerException: When exporting to ONNX without a set
        maximum number of elements.
        """
        # Gather the masks of all entity groups into a single mask
        mask = torch.cat(key_masks, dim=1)

        inp = self.embedding_norm(inp)
        # Feed to self attention
        query = self.fc_q(inp)  # (b, n_q, emb)
        key = self.fc_k(inp)  # (b, n_k, emb)
        value = self.fc_v(inp)  # (b, n_k, emb)

        # Only use max num if provided
        if self.max_num_ent is None:
            num_ent = inp.shape[1]
            if exporting_to_onnx.is_exporting():
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't have a set max "
                    "number of elements."
                )
        else:
            num_ent = self.max_num_ent

        attended, _ = self.attention(query, key, value, num_ent, num_ent, mask)

        # Residual
        residual = self.fc_out(attended) + inp
        residual = self.residual_norm(residual)

        # Average Pooling over the entities that are not masked
        keep = (1 - mask).reshape(-1, num_ent, 1)
        pooled_sum = torch.sum(residual * keep, dim=1)
        kept_count = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
        return pooled_sum / kept_count