""" Overview: This file implements the core modules of GTrXL Transformer as described in "Stabilizing Transformer for Reinforcement Learning" (https://arxiv.org/abs/1910.06764). """ from typing import Optional, Dict, List import warnings import numpy as np import torch import torch.nn as nn from ding.torch_utils.network.nn_module import fc_block, build_normalization, F class PositionalEmbedding(nn.Module): """ Overview: The PositionalEmbedding module implements the positional embedding used in the vanilla Transformer model. Interfaces: ``__init__``, ``forward`` .. note:: This implementation is adapted from https://github.com/kimiyoung/transformer-xl/blob/ \ master/pytorch/mem_transformer.py """ def __init__(self, embedding_dim: int): """ Overview: Initialize the PositionalEmbedding module. Arguments: - embedding_dim: (:obj:`int`): The dimensionality of the embeddings. """ super(PositionalEmbedding, self).__init__() self.embedding_dim = embedding_dim inv_freq = 1 / (10000 ** (torch.arange(0.0, embedding_dim, 2.0) / embedding_dim)) # (embedding_dim / 2) self.register_buffer('inv_freq', inv_freq) def forward(self, pos_seq: torch.Tensor) -> torch.Tensor: """ Overview: Compute positional embedding given a sequence of positions. Arguments: - pos_seq (:obj:`torch.Tensor`): The positional sequence, \ typically a 1D tensor of integers in the form of [seq_len-1, seq_len-2, ..., 1, 0], Returns: - pos_embedding (:obj:`torch.Tensor`): The computed positional embeddings. \ The shape of the tensor is (seq_len, 1, embedding_dim). """ sinusoid_inp = torch.outer(pos_seq, self.inv_freq) # For position embedding, the order of sin/cos is negligible. # This is because tokens are consumed by the matrix multiplication which is permutation-invariant. pos_embedding = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) return pos_embedding.unsqueeze(1) class GRUGatingUnit(torch.nn.Module): """ Overview: The GRUGatingUnit module implements the GRU gating mechanism used in the GTrXL model. Interfaces: ``__init__``, ``forward`` """ def __init__(self, input_dim: int, bg: float = 2.): """ Overview: Initialize the GRUGatingUnit module. Arguments: - input_dim (:obj:`int`): The dimensionality of the input. - bg (:obj:`bg`): The gate bias. By setting bg > 0 we can explicitly initialize the gating mechanism to \ be close to the identity map. This can greatly improve the learning speed and stability since it \ initializes the agent close to a Markovian policy (ignore attention at the beginning). """ super(GRUGatingUnit, self).__init__() self.Wr = torch.nn.Linear(input_dim, input_dim, bias=False) self.Ur = torch.nn.Linear(input_dim, input_dim, bias=False) self.Wz = torch.nn.Linear(input_dim, input_dim, bias=False) self.Uz = torch.nn.Linear(input_dim, input_dim, bias=False) self.Wg = torch.nn.Linear(input_dim, input_dim, bias=False) self.Ug = torch.nn.Linear(input_dim, input_dim, bias=False) self.bg = nn.Parameter(torch.full([input_dim], bg)) # bias self.sigmoid = torch.nn.Sigmoid() self.tanh = torch.nn.Tanh() def forward(self, x: torch.Tensor, y: torch.Tensor): """ Overview: Compute the output value using the GRU gating mechanism. Arguments: - x: (:obj:`torch.Tensor`): The first input tensor. - y: (:obj:`torch.Tensor`): The second input tensor. \ x and y should have the same shape and their last dimension should match the input_dim. Returns: - g: (:obj:`torch.Tensor`): The output of the GRU gating mechanism. \ The shape of g matches the shapes of x and y. """ r = self.sigmoid(self.Wr(y) + self.Ur(x)) z = self.sigmoid(self.Wz(y) + self.Uz(x) - self.bg) h = self.tanh(self.Wg(y) + self.Ug(torch.mul(r, x))) # element wise multiplication g = torch.mul(1 - z, x) + torch.mul(z, h) return g # x.shape == y.shape == g.shape class Memory: """ Overview: A class that stores the context used to add memory to Transformer. Interfaces: ``__init__``, ``init``, ``update``, ``get``, ``to`` .. note:: For details, refer to Transformer-XL: https://arxiv.org/abs/1901.02860 """ def __init__( self, memory_len: int = 20, batch_size: int = 64, embedding_dim: int = 256, layer_num: int = 3, memory: Optional[torch.Tensor] = None ) -> None: """ Overview: Initialize the Memory module. Arguments: - memory_len (:obj:`int`): The dimension of memory, i.e., how many past observations to use as memory. - batch_size (:obj:`int`): The dimension of each batch. - embedding_dim (:obj:`int`): The dimension of embedding, which is the dimension of a single observation \ after embedding. - layer_num (:obj:`int`): The number of transformer layers. - memory (:obj:`Optional[torch.Tensor]`): The initial memory. Default is None. """ super(Memory, self).__init__() self.embedding_dim = embedding_dim self.bs = batch_size self.layer_num = layer_num self.memory_len = memory_len self.memory = None self.init(memory) def init(self, memory: Optional[torch.Tensor] = None): """ Overview: Initialize memory with an input list of tensors or create it automatically given its dimensions. Arguments: - memory (:obj:`Optional[torch.Tensor]`): Input memory tensor with shape \ (layer_num, memory_len, bs, embedding_dim). Its shape is (layer_num, memory_len, bs, embedding_dim), \ where memory_len is length of memory, bs is batch size and embedding_dim is the dimension of embedding. """ if memory is not None: self.memory = memory layer_num_plus1, self.memory_len, self.bs, self.embedding_dim = memory.shape self.layer_num = layer_num_plus1 - 1 else: self.memory = torch.zeros( self.layer_num + 1, self.memory_len, self.bs, self.embedding_dim, dtype=torch.float ) def update(self, hidden_state: List[torch.Tensor]): """ Overview: Update the memory given a sequence of hidden states. Example for single layer: memory_len=3, hidden_size_len=2, bs=3 m00 m01 m02 h00 h01 h02 m20 m21 m22 m = m10 m11 m12 h = h10 h11 h12 => new_m = h00 h01 h02 m20 m21 m22 h10 h11 h12 Arguments: - hidden_state: (:obj:`List[torch.Tensor]`): The hidden states to update the memory. \ Each tensor in the list has shape (cur_seq, bs, embedding_dim), where cur_seq \ is the length of the sequence. Returns: - memory: (:obj:`Optional[torch.Tensor]`): The updated memory, with shape \ (layer_num, memory_len, bs, embedding_dim). """ if self.memory is None or hidden_state is None: raise ValueError('Failed to update memory! Memory would be None') # TODO add support of no memory sequence_len = hidden_state[0].shape[0] with torch.no_grad(): new_memory = [] end = self.memory_len + sequence_len beg = max(0, end - self.memory_len) for i in range(self.layer_num + 1): m = self.memory[i] h = hidden_state[i] cat = torch.cat([m, h], dim=0) new_memory.append(cat[beg:end].detach()) new_memory = torch.stack(new_memory, dim=0) self.memory = new_memory return new_memory def get(self): """ Overview: Get the current memory. Returns: - memory: (:obj:`Optional[torch.Tensor]`): The current memory, \ with shape (layer_num, memory_len, bs, embedding_dim). """ return self.memory def to(self, device: str = 'cpu'): """ Overview: Move the current memory to the specified device. Arguments: device (:obj:`str`): The device to move the memory to. Default is 'cpu'. """ self.memory = self.memory.to(device) class AttentionXL(torch.nn.Module): """ Overview: An implementation of the Attention mechanism used in the TransformerXL model. Interfaces: ``__init__``, ``forward`` """ def __init__(self, input_dim: int, head_dim: int, head_num: int, dropout: nn.Module) -> None: """ Overview: Initialize the AttentionXL module. Arguments: - input_dim (:obj:`int`): The dimensionality of the input features. - head_dim (:obj:`int`): The dimensionality of each attention head. - head_num (:obj:`int`): The number of attention heads. - dropout (:obj:`nn.Module`): The dropout layer to use """ super(AttentionXL, self).__init__() self.head_num = head_num self.head_dim = head_dim self.dropout = dropout self.attention_kv = fc_block(input_dim, head_dim * head_num * 2) # key, value self.attention_q = fc_block(input_dim, head_dim * head_num) # query (not computed with past hidden states) self.project = fc_block(head_dim * head_num, input_dim) # project attention output back to input_dim self.project_pos = fc_block(input_dim, head_dim * head_num) # project the positional embedding self.scale = 1 / (head_dim ** 0.5) # for scaled dot product attention def _rel_shift(self, x: torch.Tensor, zero_upper: bool = False) -> torch.Tensor: """ Overview: Perform a relative shift operation on the attention score matrix. Example: a00 a01 a02 0 a00 a01 a02 0 a00 a01 a02 0 a10 a02 0 0 a10 a11 a12 => 0 a10 a11 a12 => a02 0 a10 => a11 a12 0 => a11 a12 0 a20 a21 a22 0 a20 a21 a22 a11 a12 0 a20 a21 a22 a20 a21 a22 a20 a21 a22 1) Append one "column" of zeros to the left 2) Reshape the matrix from [3 x 4] into [4 x 3] 3) Remove the first "row" 4) Mask out the upper triangle (optional) .. note:: See the following material for better understanding: https://github.com/kimiyoung/transformer-xl/issues/8 https://arxiv.org/pdf/1901.02860.pdf (Appendix B) Arguments: - x (:obj:`torch.Tensor`): The input tensor with shape (cur_seq, full_seq, bs, head_num). - zero_upper (:obj:`bool`): If True, the upper-right triangle of the matrix is set to zero. Returns: - x (:obj:`torch.Tensor`): The input tensor after the relative shift operation, \ with shape (cur_seq, full_seq, bs, head_num). """ x_padded = F.pad(x, [1, 0]) # step 1 x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2)) # step 2 x = x_padded[:, :, 1:].view_as(x) # step 3 if zero_upper: ones = torch.ones((x.size(2), x.size(3))).unsqueeze(0).unsqueeze(0) x = x * torch.tril(ones.to(x.device), x.size(3) - x.size(2)) # step 4 return x def forward( self, inputs: torch.Tensor, pos_embedding: torch.Tensor, full_input: torch.Tensor, u: torch.nn.Parameter, v: torch.nn.Parameter, mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Overview: Compute the forward pass for the AttentionXL module. Arguments: - inputs (:obj:`torch.Tensor`): The attention input with shape (cur_seq, bs, input_dim). - pos_embedding (:obj:`torch.Tensor`): The positional embedding with shape (full_seq, 1, full_seq). - full_input (:obj:`torch.Tensor`): The concatenated memory and input tensor with shape \ (full_seq, bs, input_dim). - u (:obj:`torch.nn.Parameter`): The content parameter with shape (head_num, head_dim). - v (:obj:`torch.nn.Parameter`): The position parameter with shape (head_num, head_dim). - mask (:obj:`Optional[torch.Tensor]`): The attention mask with shape (cur_seq, full_seq, 1). \ If None, no masking is applied. Returns: - output (:obj:`torch.Tensor`): The output of the attention mechanism with shape (cur_seq, bs, input_dim). """ bs, cur_seq, full_seq = inputs.shape[1], inputs.shape[0], full_input.shape[0] prev_seq = full_seq - cur_seq kv = self.attention_kv(full_input) key, value = torch.chunk(kv, 2, dim=-1) # full_seq x bs x num_head*dim_head query = self.attention_q(inputs) # cur_seq x bs x num_head*dim_head r = self.project_pos(pos_embedding) # full_seq x 1 x num_head*dim_head key = key.view(full_seq, bs, self.head_num, self.head_dim) query = query.view(cur_seq, bs, self.head_num, self.head_dim) value = value.view(cur_seq + prev_seq, bs, self.head_num, self.head_dim) r = r.view(full_seq, self.head_num, self.head_dim) # (query + u) * key^T q_u = query + u content_attn = q_u.permute(1, 2, 0, 3) @ key.permute(1, 2, 3, 0) # bs x head_num x cur_seq x full_seq # (query + v) * R^T q_v = query + v position_attn = q_v.permute(1, 2, 0, 3) @ r.permute(1, 2, 0) # bs x head_num x cur_seq x full_seq position_attn = self._rel_shift(position_attn) attn = content_attn + position_attn # bs x head_num x cur_seq x full_seq attn.mul_(self.scale) # fills float('-inf') where mask is True to let softmax ignore those positions. if mask is not None and mask.any().item(): mask = mask.permute(2, 0, 1).unsqueeze(1) # 1 x 1 x cur_seq x full_seq assert mask.shape[2:] == attn.shape[2:] # check shape of mask attn = attn.masked_fill(mask, -float("inf")).type_as(attn) attn = F.softmax(attn, dim=-1) attn = self.dropout(attn) # multiply softmax output by value attn_vec = attn @ value.permute(1, 2, 0, 3) attn_vec = attn_vec.permute(2, 0, 1, 3) attn_vec = attn_vec.contiguous().view(cur_seq, bs, self.head_num * self.head_dim) # cur_seq x bs x head_num * head_dim output = self.dropout(self.project(attn_vec)) # cur_seq x bs x input_dim return output class GatedTransformerXLLayer(torch.nn.Module): """ Overview: This class implements the attention layer of GTrXL (Gated Transformer-XL). Interfaces: ``__init__``, ``forward`` """ def __init__( self, input_dim: int, head_dim: int, hidden_dim: int, head_num: int, mlp_num: int, dropout: nn.Module, activation: nn.Module, gru_gating: bool = True, gru_bias: float = 2. ) -> None: """ Overview: Initialize GatedTransformerXLLayer. Arguments: - input_dim (:obj:`int`): The dimension of the input tensor. - head_dim (:obj:`int`): The dimension of each head in the multi-head attention. - hidden_dim (:obj:`int`): The dimension of the hidden layer in the MLP. - head_num (:obj:`int`): The number of heads for the multi-head attention. - mlp_num (:obj:`int`): The number of MLP layers in the attention layer. - dropout (:obj:`nn.Module`): The dropout module used in the MLP and attention layers. - activation (:obj:`nn.Module`): The activation function to be used in the MLP layers. - gru_gating (:obj:`bool`, optional): Whether to use GRU gates. If False, replace GRU gates with \ residual connections. Default is True. - gru_bias (:obj:`float`, optional): The bias of the GRU gate. Default is 2. """ super(GatedTransformerXLLayer, self).__init__() self.dropout = dropout self.gating = gru_gating if self.gating is True: self.gate1 = GRUGatingUnit(input_dim, gru_bias) self.gate2 = GRUGatingUnit(input_dim, gru_bias) self.attention = AttentionXL( input_dim, head_dim, head_num, dropout, ) layers = [] dims = [input_dim] + [hidden_dim] * (mlp_num - 1) + [input_dim] for i in range(mlp_num): layers.append(fc_block(dims[i], dims[i + 1], activation=activation)) if i != mlp_num - 1: layers.append(self.dropout) layers.append(self.dropout) self.mlp = nn.Sequential(*layers) self.layernorm1 = build_normalization('LN')(input_dim) self.layernorm2 = build_normalization('LN')(input_dim) self.activation = activation def forward( self, inputs: torch.Tensor, pos_embedding: torch.Tensor, u: torch.nn.Parameter, v: torch.nn.Parameter, memory: torch.Tensor, mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Overview: Compute forward pass of GTrXL layer. Arguments: - inputs (:obj:`torch.Tensor`): The attention input tensor of shape (cur_seq, bs, input_dim). - pos_embedding (:obj:`torch.Tensor`): The positional embedding tensor of shape (full_seq, 1, full_seq). - u (:obj:`torch.nn.Parameter`): The content parameter tensor of shape (head_num, head_dim). - v (:obj:`torch.nn.Parameter`): The position parameter tensor of shape (head_num, head_dim). - memory (:obj:`torch.Tensor`): The memory tensor of shape (prev_seq, bs, input_dim). - mask (:obj:`Optional[torch.Tensor]`): The attention mask tensor of shape (cur_seq, full_seq, 1). Default is None. Returns: - output (:obj:`torch.Tensor`): layer output of shape (cur_seq, bs, input_dim) """ # concat memory with input across sequence dimension full_input = torch.cat([memory, inputs], dim=0) # full_seq x bs x input_dim x1 = self.layernorm1(full_input) a1 = self.dropout(self.attention(inputs, pos_embedding, x1, u, v, mask=mask)) a1 = self.activation(a1) # RELU after attention o1 = self.gate1(inputs, a1) if self.gating else inputs + a1 x2 = self.layernorm2(o1) m2 = self.dropout(self.mlp(x2)) o2 = self.gate2(o1, m2) if self.gating else o1 + m2 return o2 class GTrXL(nn.Module): """ Overview: GTrXL Transformer implementation as described in "Stabilizing Transformer for Reinforcement Learning" (https://arxiv.org/abs/1910.06764). Interfaces: ``__init__``, ``forward``, ``reset_memory``, ``get_memory`` """ def __init__( self, input_dim: int, head_dim: int = 128, embedding_dim: int = 256, head_num: int = 2, mlp_num: int = 2, layer_num: int = 3, memory_len: int = 64, dropout_ratio: float = 0., activation: nn.Module = nn.ReLU(), gru_gating: bool = True, gru_bias: float = 2., use_embedding_layer: bool = True, ) -> None: """Overview: Init GTrXL Model. Arguments: - input_dim (:obj:`int`): The dimension of the input observation. - head_dim (:obj:`int`, optional): The dimension of each head. Default is 128. - embedding_dim (:obj:`int`, optional): The dimension of the embedding. Default is 256. - head_num (:obj:`int`, optional): The number of heads for multi-head attention. Default is 2. - mlp_num (:obj:`int`, optional): The number of MLP layers in the attention layer. Default is 2. - layer_num (:obj:`int`, optional): The number of transformer layers. Default is 3. - memory_len (:obj:`int`, optional): The length of memory. Default is 64. - dropout_ratio (:obj:`float`, optional): The dropout ratio. Default is 0. - activation (:obj:`nn.Module`, optional): The activation function. Default is nn.ReLU(). - gru_gating (:obj:`bool`, optional): If False, replace GRU gates with residual connections. \ Default is True. - gru_bias (:obj:`float`, optional): The GRU gate bias. Default is 2.0. - use_embedding_layer (:obj:`bool`, optional): If False, don't use input embedding layer. Default is True. Raises: - AssertionError: If `embedding_dim` is not an even number. """ super(GTrXL, self).__init__() assert embedding_dim % 2 == 0, 'embedding_dim={} should be even'.format(input_dim) self.head_num = head_num self.head_dim = head_dim self.layer_num = layer_num if isinstance(input_dim, list): input_dim = np.prod(input_dim) self.use_embedding_layer = use_embedding_layer if use_embedding_layer: self.embedding = fc_block(input_dim, embedding_dim, activation=activation) self.activation = activation self.pos_embedding = PositionalEmbedding(embedding_dim) # memory to save hidden states of past segments # it will be initialized in the forward method to get its size dynamically self.memory = None self.memory_len = memory_len layers = [] dims = [embedding_dim] + [embedding_dim] * layer_num self.dropout = nn.Dropout(dropout_ratio) if dropout_ratio > 0 else nn.Identity() for i in range(layer_num): layers.append( GatedTransformerXLLayer( dims[i], head_dim, embedding_dim, head_num, mlp_num, self.dropout, self.activation, gru_gating, gru_bias ) ) self.layers = nn.Sequential(*layers) self.embedding_dim = embedding_dim # u and v are the parameters to compute global content bias and global positional bias self.u, self.v = ( torch.nn.Parameter(torch.zeros(self.head_num, self.head_dim)), torch.nn.Parameter(torch.zeros(self.head_num, self.head_dim)), ) self.att_mask = {} # create an attention mask for each different seq_len, in this way we don't need to create a # new one each time we call the forward method self.pos_embedding_dict = {} # create a pos embedding for each different seq_len def reset_memory(self, batch_size: Optional[int] = None, state: Optional[torch.Tensor] = None): """ Overview: Clear or set the memory of GTrXL. Arguments: - batch_size (:obj:`Optional[int]`): The batch size. Default is None. - state (:obj:`Optional[torch.Tensor]`): The input memory with shape \ (layer_num, memory_len, bs, embedding_dim). Default is None. """ self.memory = Memory(memory_len=self.memory_len, layer_num=self.layer_num, embedding_dim=self.embedding_dim) if batch_size is not None: self.memory = Memory(self.memory_len, batch_size, self.embedding_dim, self.layer_num) elif state is not None: self.memory.init(state) def get_memory(self): """ Overview: Returns the memory of GTrXL. Returns: - memory (:obj:`Optional[torch.Tensor]`): The output memory or None if memory has not been initialized. \ The shape is (layer_num, memory_len, bs, embedding_dim). """ if self.memory is None: return None else: return self.memory.get() def forward(self, x: torch.Tensor, batch_first: bool = False, return_mem: bool = True) -> Dict[str, torch.Tensor]: """ Overview: Performs a forward pass on the GTrXL. Arguments: - x (:obj:`torch.Tensor`): The input tensor with shape (seq_len, bs, input_size). - batch_first (:obj:`bool`, optional): If the input data has shape (bs, seq_len, input_size), \ set this parameter to True to transpose along the first and second dimension and obtain shape \ (seq_len, bs, input_size). This does not affect the output memory. Default is False. \ - return_mem (:obj:`bool`, optional): If False, return only the output tensor without dict. Default is True. Returns: - x (:obj:`Dict[str, torch.Tensor]`): A dictionary containing the transformer output of shape \ (seq_len, bs, embedding_size) and memory of shape (layer_num, seq_len, bs, embedding_size). """ if batch_first: x = torch.transpose(x, 1, 0) # bs x cur_seq x input_dim -> cur_seq x bs x input_dim cur_seq, bs = x.shape[:2] memory = None if self.memory is None else self.memory.get() if memory is None: self.reset_memory(bs) # (layer_num+1) x memory_len x batch_size x embedding_dim elif memory.shape[-2] != bs or memory.shape[-1] != self.embedding_dim: warnings.warn( "Memory {} and Input {} dimensions don't match," " this will cause the memory to be initialized to fit your input!".format( list(memory.shape[-2:]), [x.shape[-2]] + [self.embedding_dim] ) ) self.reset_memory(bs) self.memory.to(x.device) memory = self.memory.get() if self.use_embedding_layer: x = self.dropout(self.embedding(x)) prev_seq = self.memory_len full_seq = cur_seq + prev_seq if cur_seq in self.att_mask.keys(): attn_mask = self.att_mask[cur_seq] else: attn_mask = ( torch.triu( torch.ones((cur_seq, full_seq)), diagonal=1 + prev_seq, # fixed in train, eval, collect ).bool().unsqueeze(-1).to(x.device) ) # cur_seq x full_seq x 1 self.att_mask[cur_seq] = attn_mask if cur_seq in self.pos_embedding_dict.keys(): pos_embedding = self.pos_embedding_dict[cur_seq] else: pos_ips = torch.arange(full_seq - 1, -1, -1.0, dtype=torch.float) # full_seq pos_embedding = self.pos_embedding(pos_ips.to(x.device)) self.pos_embedding_dict[cur_seq] = pos_embedding pos_embedding = self.dropout(pos_embedding) # full_seq x 1 x embedding_dim hidden_state = [x] out = x for i in range(self.layer_num): layer = self.layers[i] out = layer( out, pos_embedding, self.u, self.v, mask=attn_mask, memory=memory[i], # (layer_num+1) x memory_len x batch_size x embedding_dim ) # cur_seq x bs x embedding_dim hidden_state.append(out.clone()) out = self.dropout(out) self.memory.update(hidden_state) # (layer_num+1) x memory_len x batch_size x embedding_dim if batch_first: out = torch.transpose(out, 1, 0) # cur_seq x bs x embedding_dim -> bs x cur_seq x embedding_dim if return_mem: output = {"logit": out, "memory": memory} # return the content of the memory before the last update else: output = {"logit": out} return output