""" Implementation of "Attention is All You Need" """ import torch.nn as nn from onmt.encoders.encoder import EncoderBase from onmt.modules import MultiHeadedAttention from onmt.modules.position_ffn import PositionwiseFeedForward from onmt.modules.position_ffn import ActivationFunction from onmt.utils.misc import sequence_mask from onmt.modules.rmsnorm import RMSNorm class TransformerEncoderLayer(nn.Module): """ A single layer of the transformer encoder. Args: d_model (int): the dimension of keys/values/queries in MultiHeadedAttention, also the input size of the first-layer of the PositionwiseFeedForward. heads (int): the number of head for MultiHeadedAttention. d_ff (int): the second-layer of the PositionwiseFeedForward. dropout (float): dropout probability(0-1.0). pos_ffn_activation_fn (ActivationFunction): activation function choice for PositionwiseFeedForward layer """ def __init__( self, d_model, heads, d_ff, dropout, attention_dropout, max_relative_positions=0, relative_positions_buckets=0, pos_ffn_activation_fn=ActivationFunction.relu, add_qkvbias=False, num_kv=0, add_ffnbias=True, parallel_residual=False, layer_norm="standard", norm_eps=1e-6, use_ckpting=[], parallel_gpu=1, ): super(TransformerEncoderLayer, self).__init__() self.self_attn = MultiHeadedAttention( heads, d_model, dropout=attention_dropout, is_decoder=False, max_relative_positions=max_relative_positions, relative_positions_buckets=relative_positions_buckets, attn_type="self", add_qkvbias=add_qkvbias, num_kv=num_kv, use_ckpting=use_ckpting, parallel_gpu=parallel_gpu, ) self.feed_forward = PositionwiseFeedForward( d_model, d_ff, dropout, pos_ffn_activation_fn, add_ffnbias, parallel_residual, layer_norm, norm_eps, use_ckpting=use_ckpting, parallel_gpu=parallel_gpu, ) self.parallel_residual = parallel_residual if layer_norm == "standard": self.layer_norm = nn.LayerNorm(d_model, eps=norm_eps) elif layer_norm == "rms": self.layer_norm = RMSNorm(d_model, eps=norm_eps) else: raise ValueError(f"{layer_norm} layer norm type is not supported") self.dropout = nn.Dropout(dropout) def forward(self, layer_in, mask): """ Args: layer_in (FloatTensor): ``(batch_size, src_len, model_dim)`` mask (LongTensor): ``(batch_size, 1, src_len)`` Returns: (FloatTensor): * layer_out ``(batch_size, src_len, model_dim)`` """ norm_layer_in = self.layer_norm(layer_in) context, _ = self.self_attn( norm_layer_in, norm_layer_in, norm_layer_in, mask=mask ) if self.parallel_residual: # feed_forward applies residual, so we remove and apply residual with un-normed layer_out = ( self.feed_forward(norm_layer_in) - norm_layer_in + layer_in + self.dropout(context) ) else: layer_out = self.dropout(context) + layer_in layer_out = self.feed_forward(layer_out) return layer_out def update_dropout(self, dropout, attention_dropout): self.self_attn.update_dropout(attention_dropout) self.feed_forward.update_dropout(dropout) self.dropout.p = dropout class TransformerEncoder(EncoderBase): """The Transformer encoder from "Attention is All You Need" :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` Args: num_layers (int): number of encoder layers d_model (int): size of the model heads (int): number of heads d_ff (int): size of the inner FF layer dropout (float): dropout parameters embeddings (onmt.modules.Embeddings): embeddings to use, should have positional encodings pos_ffn_activation_fn (ActivationFunction): activation function choice for PositionwiseFeedForward layer Returns: (torch.FloatTensor, torch.FloatTensor): * enc_out ``(batch_size, src_len, model_dim)`` * encoder 
class TransformerEncoder(EncoderBase):
    """The Transformer encoder from "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    Args:
        num_layers (int): number of encoder layers
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        dropout (float): dropout probability
        attention_dropout (float): dropout probability for attention weights
        embeddings (onmt.modules.Embeddings):
            embeddings to use, should have positional encodings
        pos_ffn_activation_fn (ActivationFunction):
            activation function choice for PositionwiseFeedForward layer

    Returns:
        (torch.FloatTensor, torch.FloatTensor):

        * enc_out ``(batch_size, src_len, model_dim)``
        * encoder final state: None in the case of Transformer
        * src_len ``(batch_size)``
    """

    def __init__(
        self,
        num_layers,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        embeddings,
        max_relative_positions,
        relative_positions_buckets,
        pos_ffn_activation_fn=ActivationFunction.relu,
        add_qkvbias=False,
        num_kv=0,
        add_ffnbias=True,
        parallel_residual=False,
        layer_norm="standard",
        norm_eps=1e-6,
        use_ckpting=[],
        parallel_gpu=1,
    ):
        super(TransformerEncoder, self).__init__()

        self.embeddings = embeddings
        self.transformer = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    d_model,
                    heads,
                    d_ff,
                    dropout,
                    attention_dropout,
                    max_relative_positions=max_relative_positions,
                    relative_positions_buckets=relative_positions_buckets,
                    pos_ffn_activation_fn=pos_ffn_activation_fn,
                    add_qkvbias=add_qkvbias,
                    num_kv=num_kv,
                    add_ffnbias=add_ffnbias,
                    parallel_residual=parallel_residual,
                    layer_norm=layer_norm,
                    norm_eps=norm_eps,
                    use_ckpting=use_ckpting,
                    parallel_gpu=parallel_gpu,
                )
                for i in range(num_layers)
            ]
        )
        if layer_norm == "standard":
            self.layer_norm = nn.LayerNorm(d_model, eps=norm_eps)
        elif layer_norm == "rms":
            self.layer_norm = RMSNorm(d_model, eps=norm_eps)
        else:
            raise ValueError(f"{layer_norm} layer norm type is not supported")

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor."""
        return cls(
            opt.enc_layers,
            opt.enc_hid_size,
            opt.heads,
            opt.transformer_ff,
            opt.dropout[0] if type(opt.dropout) is list else opt.dropout,
            opt.attention_dropout[0]
            if type(opt.attention_dropout) is list
            else opt.attention_dropout,
            embeddings,
            opt.max_relative_positions,
            opt.relative_positions_buckets,
            pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
            add_qkvbias=opt.add_qkvbias,
            num_kv=opt.num_kv,
            add_ffnbias=opt.add_ffnbias,
            parallel_residual=opt.parallel_residual,
            layer_norm=opt.layer_norm,
            norm_eps=opt.norm_eps,
            use_ckpting=opt.use_ckpting,
            parallel_gpu=opt.world_size
            if opt.parallel_mode == "tensor_parallel"
            else 1,
        )

    def forward(self, src, src_len=None):
        """See :func:`EncoderBase.forward()`"""
        enc_out = self.embeddings(src)
        mask = ~sequence_mask(src_len).unsqueeze(1)
        mask = mask.unsqueeze(1)
        mask = mask.expand(-1, -1, mask.size(3), -1)
        # mask is now (batch x 1 x slen x slen)
        # 1 to be expanded to number of heads in MHA
        # Run the forward pass of every layer of the transformer.
        for layer in self.transformer:
            enc_out = layer(enc_out, mask)
        enc_out = self.layer_norm(enc_out)

        return enc_out, None, src_len

    def update_dropout(self, dropout, attention_dropout):
        self.embeddings.update_dropout(dropout)
        for layer in self.transformer:
            layer.update_dropout(dropout, attention_dropout)
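

# ---------------------------------------------------------------------------
# Illustrative end-to-end sketch, only run when the module is executed
# directly. It assumes `src` is a (batch_size, src_len) LongTensor of token
# ids and uses a toy stand-in for onmt.modules.Embeddings (the `_ToyEmbeddings`
# class and all hyper-parameter values below are examples, not part of this
# module); a real run would use Embeddings with positional encodings.
if __name__ == "__main__":
    import torch

    class _ToyEmbeddings(nn.Module):
        """Minimal stand-in: token ids -> (batch_size, src_len, d_model)."""

        def __init__(self, vocab_size, d_model):
            super().__init__()
            self.lut = nn.Embedding(vocab_size, d_model)
            self.dropout = nn.Dropout(0.1)

        def forward(self, src):
            return self.dropout(self.lut(src))

        def update_dropout(self, dropout):
            self.dropout.p = dropout

    d_model = 64
    encoder = TransformerEncoder(
        num_layers=2,
        d_model=d_model,
        heads=4,
        d_ff=256,
        dropout=0.1,
        attention_dropout=0.1,
        embeddings=_ToyEmbeddings(vocab_size=100, d_model=d_model),
        max_relative_positions=0,
        relative_positions_buckets=0,
    )
    src = torch.randint(0, 100, (2, 7))
    src_len = torch.tensor([7, 5])
    enc_out, enc_final, src_len = encoder(src, src_len)
    print(enc_out.shape)  # torch.Size([2, 7, 64]); enc_final is None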