|
from mlagents.torch_utils import torch |
|
import warnings |
|
from typing import Tuple, Optional, List |
|
from mlagents.trainers.torch_entities.layers import ( |
|
LinearEncoder, |
|
Initialization, |
|
linear_layer, |
|
LayerNorm, |
|
) |
|
from mlagents.trainers.torch_entities.model_serialization import exporting_to_onnx |
|
from mlagents.trainers.exception import UnityTrainerException |
|
|
|
|
|
def get_zero_entities_mask(entities: List[torch.Tensor]) -> List[torch.Tensor]: |
|
""" |
|
Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was |
|
all zeros (on dimension 2) and 0 otherwise. This is used in the Attention |
|
layer to mask the padding observations. |
|
""" |
|
with torch.no_grad(): |
|
|
|
        if exporting_to_onnx.is_exporting():
            with warnings.catch_warnings():
                # Silence the warnings PyTorch emits while this branch is traced
                # for ONNX export.
                warnings.simplefilter("ignore")
                # When exporting to ONNX, the entities arrive in a channels-first
                # (NCHW-like) layout expected by ONNX/Barracuda, so transpose them
                # back before building the mask.
                entities = [
                    torch.transpose(obs, 2, 1).reshape(
                        -1, obs.shape[1], obs.shape[2]
                    )
                    for obs in entities
                ]
|
|
|
|
|
key_masks: List[torch.Tensor] = [ |
|
(torch.sum(ent**2, axis=2) < 0.01).float() for ent in entities |
|
] |
|
return key_masks |
|
|
|
|
|
class MultiHeadAttention(torch.nn.Module): |
|
|
|
    NEG_INF = -1e6  # used in place of -inf when masking attention logits
|
|
|
def __init__(self, embedding_size: int, num_heads: int): |
|
""" |
|
Multi Head Attention module. We do not use the regular Torch implementation since |
|
Barracuda does not support some operators it uses. |
|
Takes as input to the forward method 3 tensors: |
|
- query: of dimensions (batch_size, number_of_queries, embedding_size) |
|
- key: of dimensions (batch_size, number_of_keys, embedding_size) |
|
- value: of dimensions (batch_size, number_of_keys, embedding_size) |
|
The forward method will return 2 tensors: |
|
- The output: (batch_size, number_of_queries, embedding_size) |
|
- The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys) |
|
        :param embedding_size: The size of the embeddings that will be generated (should be
        divisible by num_heads)
        :param num_heads: The number of heads of the attention module
|
""" |
|
super().__init__() |
|
self.n_heads = num_heads |
|
self.head_size: int = embedding_size // self.n_heads |
|
self.embedding_size: int = self.head_size * self.n_heads |
|
|
|
def forward( |
|
self, |
|
query: torch.Tensor, |
|
key: torch.Tensor, |
|
value: torch.Tensor, |
|
n_q: int, |
|
n_k: int, |
|
key_mask: Optional[torch.Tensor] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
        b = -1  # the batch size; -1 lets reshape infer it from the input
|
|
|
        query = query.reshape(
            b, n_q, self.n_heads, self.head_size
        )  # (b, n_q, h, emb / h)
        key = key.reshape(b, n_k, self.n_heads, self.head_size)  # (b, n_k, h, emb / h)
        value = value.reshape(
            b, n_k, self.n_heads, self.head_size
        )  # (b, n_k, h, emb / h)
|
|
|
        query = query.permute([0, 2, 1, 3])  # (b, h, n_q, emb / h)
        # The next lines compute the equivalent of key.permute([0, 2, 3, 1]).
        # The redundant -1 / +1 appears intended to keep the exporter from fusing
        # the two permutes into a single transpose that Barracuda cannot import.
        key = key.permute([0, 2, 1, 3])
        key -= 1
        key += 1
        key = key.permute([0, 1, 3, 2])  # (b, h, emb / h, n_k)

        qk = torch.matmul(query, key)  # (b, h, n_q, n_k)
|
|
|
if key_mask is None: |
|
qk = qk / (self.embedding_size**0.5) |
|
else: |
|
            key_mask = key_mask.reshape(b, 1, 1, n_k)
            # Scale the unmasked logits and push the masked ones to a large
            # negative value so they get ~zero weight after the softmax.
            qk = (1 - key_mask) * qk / (
                self.embedding_size**0.5
            ) + key_mask * self.NEG_INF
|
|
|
att = torch.softmax(qk, dim=3) |
|
|
|
value = value.permute([0, 2, 1, 3]) |
|
value_attention = torch.matmul(att, value) |
|
|
|
value_attention = value_attention.permute([0, 2, 1, 3]) |
|
value_attention = value_attention.reshape( |
|
b, n_q, self.embedding_size |
|
) |
|
|
|
return value_attention, att |
|
|
|
|
|
class EntityEmbedding(torch.nn.Module): |
|
""" |
|
A module used to embed entities before passing them to a self-attention block. |
|
Used in conjunction with ResidualSelfAttention to encode information about a self |
|
and additional entities. Can also concatenate self to entities for ego-centric self- |
|
attention. Inspired by architecture used in https://arxiv.org/pdf/1909.07528.pdf. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
entity_size: int, |
|
entity_num_max_elements: Optional[int], |
|
embedding_size: int, |
|
): |
|
""" |
|
Constructs an EntityEmbedding module. |
|
        :param entity_size: Size of other entities.
        :param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
            Needs to be assigned in order for the model to be exportable to ONNX and Barracuda.
        :param embedding_size: Embedding size for the entity encoder.
        Call add_self_embedding to concatenate a "self" encoding to every entity for
        ego-centric self-attention.
|
""" |
|
super().__init__() |
|
self.self_size: int = 0 |
|
self.entity_size: int = entity_size |
|
self.entity_num_max_elements: int = -1 |
|
if entity_num_max_elements is not None: |
|
self.entity_num_max_elements = entity_num_max_elements |
|
self.embedding_size = embedding_size |
|
|
|
self.self_ent_encoder = LinearEncoder( |
|
self.entity_size, |
|
1, |
|
self.embedding_size, |
|
kernel_init=Initialization.Normal, |
|
kernel_gain=(0.125 / self.embedding_size) ** 0.5, |
|
) |
|
|
|
def add_self_embedding(self, size: int) -> None: |
|
self.self_size = size |
|
self.self_ent_encoder = LinearEncoder( |
|
self.self_size + self.entity_size, |
|
1, |
|
self.embedding_size, |
|
kernel_init=Initialization.Normal, |
|
kernel_gain=(0.125 / self.embedding_size) ** 0.5, |
|
) |
|
|
|
def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor: |
|
num_entities = self.entity_num_max_elements |
|
if num_entities < 0: |
|
if exporting_to_onnx.is_exporting(): |
|
raise UnityTrainerException( |
|
"Trying to export an attention mechanism that doesn't have a set max \ |
|
number of elements." |
|
) |
|
num_entities = entities.shape[1] |
|
|
|
        if exporting_to_onnx.is_exporting():
            # When exporting to ONNX, the entities arrive in the channels-first
            # (NCHW-like) layout expected by ONNX/Barracuda, so transpose them
            # back to (batch, entity, feature) before encoding.
            entities = torch.transpose(entities, 2, 1).reshape(
                -1, num_entities, self.entity_size
            )
|
|
|
        if self.self_size > 0:
            expanded_self = x_self.reshape(-1, 1, self.self_size)
            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
            # Concatenate the self encoding to every entity (ego-centric attention)
            entities = torch.cat([expanded_self, entities], dim=2)
|
|
|
encoded_entities = self.self_ent_encoder(entities) |
|
return encoded_entities |
|
|
|
|
|
class ResidualSelfAttention(torch.nn.Module): |
|
""" |
|
    Residual self attention, inspired by https://arxiv.org/pdf/1909.07528.pdf. Can be used
|
with an EntityEmbedding module, to apply multi head self attention to encode information |
|
about a "Self" and a list of relevant "Entities". |
|
""" |
|
|
|
    EPSILON = 1e-7  # avoids division by zero in the masked average pooling
|
|
|
def __init__( |
|
self, |
|
embedding_size: int, |
|
entity_num_max_elements: Optional[int] = None, |
|
num_heads: int = 4, |
|
): |
|
""" |
|
Constructs a ResidualSelfAttention module. |
|
        :param embedding_size: Embedding size for the attention mechanism and the
        Q, K, V encoders.
        :param entity_num_max_elements: The maximum number of elements in the entity
        sequence. Pass None to not restrict the number of elements; however, this
        will make the module unexportable to ONNX/Barracuda.
|
:param num_heads: Number of heads for Multi Head Self-Attention |
|
""" |
|
super().__init__() |
|
self.max_num_ent: Optional[int] = None |
|
if entity_num_max_elements is not None: |
|
self.max_num_ent = entity_num_max_elements |
|
|
|
self.attention = MultiHeadAttention( |
|
num_heads=num_heads, embedding_size=embedding_size |
|
) |
|
|
|
|
|
self.fc_q = linear_layer( |
|
embedding_size, |
|
embedding_size, |
|
kernel_init=Initialization.Normal, |
|
kernel_gain=(0.125 / embedding_size) ** 0.5, |
|
) |
|
self.fc_k = linear_layer( |
|
embedding_size, |
|
embedding_size, |
|
kernel_init=Initialization.Normal, |
|
kernel_gain=(0.125 / embedding_size) ** 0.5, |
|
) |
|
self.fc_v = linear_layer( |
|
embedding_size, |
|
embedding_size, |
|
kernel_init=Initialization.Normal, |
|
kernel_gain=(0.125 / embedding_size) ** 0.5, |
|
) |
|
self.fc_out = linear_layer( |
|
embedding_size, |
|
embedding_size, |
|
kernel_init=Initialization.Normal, |
|
kernel_gain=(0.125 / embedding_size) ** 0.5, |
|
) |
|
self.embedding_norm = LayerNorm() |
|
self.residual_norm = LayerNorm() |
|
|
|
def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor: |
|
|
|
        # Gather the per-observation masks into a single (batch, total_entities) mask
        mask = torch.cat(key_masks, dim=1)
|
|
|
inp = self.embedding_norm(inp) |
|
|
|
query = self.fc_q(inp) |
|
key = self.fc_k(inp) |
|
value = self.fc_v(inp) |
|
|
|
|
|
if self.max_num_ent is not None: |
|
num_ent = self.max_num_ent |
|
else: |
|
num_ent = inp.shape[1] |
|
if exporting_to_onnx.is_exporting(): |
|
raise UnityTrainerException( |
|
"Trying to export an attention mechanism that doesn't have a set max \ |
|
number of elements." |
|
) |
|
|
|
output, _ = self.attention(query, key, value, num_ent, num_ent, mask) |
|
|
|
        # Residual connection followed by normalization
        output = self.fc_out(output) + inp
        output = self.residual_norm(output)
|
|
|
        # Average pooling over the entities that are not masked out
        numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1)
        denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
        output = numerator / denominator
|
return output |
|
|