File size: 8,269 Bytes
c668e80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 |
"""
Implementation of "Attention is All You Need"
"""
import torch.nn as nn
from onmt.encoders.encoder import EncoderBase
from onmt.modules import MultiHeadedAttention
from onmt.modules.position_ffn import PositionwiseFeedForward
from onmt.modules.position_ffn import ActivationFunction
from onmt.utils.misc import sequence_mask
from onmt.modules.rmsnorm import RMSNorm
class TransformerEncoderLayer(nn.Module):
    """
    A single layer of the transformer encoder.

    The input is normalized by ``self.layer_norm`` and fed to self-attention.
    The attention output (after dropout) is then combined with the
    un-normalized input and the position-wise feed-forward sub-layer
    (see ``forward`` for the two residual layouts).

    Args:
        d_model (int): the dimension of keys/values/queries in
            MultiHeadedAttention, also the input size of
            the first-layer of the PositionwiseFeedForward.
        heads (int): the number of heads for MultiHeadedAttention.
        d_ff (int): size of the inner layer of the PositionwiseFeedForward.
        dropout (float): dropout probability (0-1.0) applied to the
            attention output; also forwarded to PositionwiseFeedForward.
        attention_dropout (float): dropout probability forwarded to
            MultiHeadedAttention.
        max_relative_positions (int): forwarded to MultiHeadedAttention.
        relative_positions_buckets (int): forwarded to MultiHeadedAttention.
        pos_ffn_activation_fn (ActivationFunction):
            activation function choice for PositionwiseFeedForward layer
        add_qkvbias (bool): forwarded to MultiHeadedAttention.
        num_kv (int): forwarded to MultiHeadedAttention.
        add_ffnbias (bool): forwarded to PositionwiseFeedForward.
        parallel_residual (bool): if True, self-attention and the
            feed-forward sub-layer both consume the same normalized input
            and their outputs are summed with the un-normalized input.
            Also forwarded to PositionwiseFeedForward.
        layer_norm (str): ``"standard"`` (``nn.LayerNorm``) or ``"rms"``
            (``RMSNorm``); also forwarded to PositionwiseFeedForward.
        norm_eps (float): epsilon of the normalization layers.
        use_ckpting (list or None): forwarded to both sub-layers;
            ``None`` (the default) is treated as an empty list.
        parallel_gpu (int): forwarded to both sub-layers.

    Raises:
        ValueError: if ``layer_norm`` is neither ``"standard"`` nor ``"rms"``.
    """

    def __init__(
        self,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        max_relative_positions=0,
        relative_positions_buckets=0,
        pos_ffn_activation_fn=ActivationFunction.relu,
        add_qkvbias=False,
        num_kv=0,
        add_ffnbias=True,
        parallel_residual=False,
        layer_norm="standard",
        norm_eps=1e-6,
        use_ckpting=None,
        parallel_gpu=1,
    ):
        super().__init__()
        # A fresh list per call instead of a shared mutable default.
        if use_ckpting is None:
            use_ckpting = []
        self.self_attn = MultiHeadedAttention(
            heads,
            d_model,
            dropout=attention_dropout,
            is_decoder=False,
            max_relative_positions=max_relative_positions,
            relative_positions_buckets=relative_positions_buckets,
            attn_type="self",
            add_qkvbias=add_qkvbias,
            num_kv=num_kv,
            use_ckpting=use_ckpting,
            parallel_gpu=parallel_gpu,
        )
        self.feed_forward = PositionwiseFeedForward(
            d_model,
            d_ff,
            dropout,
            pos_ffn_activation_fn,
            add_ffnbias,
            parallel_residual,
            layer_norm,
            norm_eps,
            use_ckpting=use_ckpting,
            parallel_gpu=parallel_gpu,
        )
        self.parallel_residual = parallel_residual
        if layer_norm == "standard":
            self.layer_norm = nn.LayerNorm(d_model, eps=norm_eps)
        elif layer_norm == "rms":
            self.layer_norm = RMSNorm(d_model, eps=norm_eps)
        else:
            raise ValueError(f"{layer_norm} layer norm type is not supported")
        self.dropout = nn.Dropout(dropout)

    def forward(self, layer_in, mask):
        """
        Args:
            layer_in (FloatTensor): ``(batch_size, src_len, model_dim)``
            mask (Tensor): ``(batch_size, 1, src_len, src_len)`` as built by
                ``TransformerEncoder.forward``. It is the negation of the
                sequence mask, so presumably True marks padded positions
                that attention must ignore.

        Returns:
            (FloatTensor):

            * layer_out ``(batch_size, src_len, model_dim)``
        """
        # The same normalized tensor is used as query, key and value.
        norm_layer_in = self.layer_norm(layer_in)
        context, _ = self.self_attn(
            norm_layer_in, norm_layer_in, norm_layer_in, mask=mask
        )
        if self.parallel_residual:
            # feed_forward applies residual, so we remove and apply residual
            # with the un-normed input instead.
            layer_out = (
                self.feed_forward(norm_layer_in)
                - norm_layer_in
                + layer_in
                + self.dropout(context)
            )
        else:
            layer_out = self.dropout(context) + layer_in
            layer_out = self.feed_forward(layer_out)
        return layer_out

    def update_dropout(self, dropout, attention_dropout):
        """Set new dropout probabilities on this layer and its sub-layers.

        Args:
            dropout (float): probability for the feed-forward sub-layer and
                the attention-output dropout.
            attention_dropout (float): probability for self-attention.
        """
        self.self_attn.update_dropout(attention_dropout)
        self.feed_forward.update_dropout(dropout)
        self.dropout.p = dropout
class TransformerEncoder(EncoderBase):
    """The Transformer encoder from "Attention is All You Need"
    :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`

    Args:
        num_layers (int): number of encoder layers
        d_model (int): size of the model
        heads (int): number of heads
        d_ff (int): size of the inner FF layer
        dropout (float): dropout parameters
        attention_dropout (float): dropout probability inside self-attention
        embeddings (onmt.modules.Embeddings):
            embeddings to use, should have positional encodings
        max_relative_positions (int): forwarded to every encoder layer.
        relative_positions_buckets (int): forwarded to every encoder layer.
        pos_ffn_activation_fn (ActivationFunction):
            activation function choice for PositionwiseFeedForward layer
        add_qkvbias (bool): forwarded to every encoder layer.
        num_kv (int): forwarded to every encoder layer.
        add_ffnbias (bool): forwarded to every encoder layer.
        parallel_residual (bool): forwarded to every encoder layer.
        layer_norm (str): ``"standard"`` or ``"rms"``; used for the layers
            and for the final normalization.
        norm_eps (float): epsilon of the normalization layers.
        use_ckpting (list or None): forwarded to every encoder layer;
            ``None`` (the default) is treated as an empty list.
        parallel_gpu (int): forwarded to every encoder layer.

    Returns:
        (torch.FloatTensor, None, torch.Tensor):

        * enc_out ``(batch_size, src_len, model_dim)``
        * encoder final state: None in the case of Transformer
        * src_len ``(batch_size)``

    Raises:
        ValueError: if ``layer_norm`` is neither ``"standard"`` nor ``"rms"``.
    """

    def __init__(
        self,
        num_layers,
        d_model,
        heads,
        d_ff,
        dropout,
        attention_dropout,
        embeddings,
        max_relative_positions,
        relative_positions_buckets,
        pos_ffn_activation_fn=ActivationFunction.relu,
        add_qkvbias=False,
        num_kv=0,
        add_ffnbias=True,
        parallel_residual=False,
        layer_norm="standard",
        norm_eps=1e-6,
        use_ckpting=None,
        parallel_gpu=1,
    ):
        super().__init__()
        # A fresh list per call instead of a shared mutable default.
        if use_ckpting is None:
            use_ckpting = []
        self.embeddings = embeddings
        self.transformer = nn.ModuleList(
            [
                TransformerEncoderLayer(
                    d_model,
                    heads,
                    d_ff,
                    dropout,
                    attention_dropout,
                    max_relative_positions=max_relative_positions,
                    relative_positions_buckets=relative_positions_buckets,
                    pos_ffn_activation_fn=pos_ffn_activation_fn,
                    add_qkvbias=add_qkvbias,
                    num_kv=num_kv,
                    add_ffnbias=add_ffnbias,
                    parallel_residual=parallel_residual,
                    layer_norm=layer_norm,
                    norm_eps=norm_eps,
                    use_ckpting=use_ckpting,
                    parallel_gpu=parallel_gpu,
                )
                for _ in range(num_layers)
            ]
        )
        # Final normalization applied to the output of the last layer.
        if layer_norm == "standard":
            self.layer_norm = nn.LayerNorm(d_model, eps=norm_eps)
        elif layer_norm == "rms":
            self.layer_norm = RMSNorm(d_model, eps=norm_eps)
        else:
            raise ValueError(f"{layer_norm} layer norm type is not supported")

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor: build the encoder from an options object.

        When ``opt.dropout`` / ``opt.attention_dropout`` are lists, only the
        first value is used. ``parallel_gpu`` is ``opt.world_size`` in
        ``"tensor_parallel"`` mode and 1 otherwise.
        """
        return cls(
            opt.enc_layers,
            opt.enc_hid_size,
            opt.heads,
            opt.transformer_ff,
            opt.dropout[0] if isinstance(opt.dropout, list) else opt.dropout,
            opt.attention_dropout[0]
            if isinstance(opt.attention_dropout, list)
            else opt.attention_dropout,
            embeddings,
            opt.max_relative_positions,
            opt.relative_positions_buckets,
            pos_ffn_activation_fn=opt.pos_ffn_activation_fn,
            add_qkvbias=opt.add_qkvbias,
            num_kv=opt.num_kv,
            add_ffnbias=opt.add_ffnbias,
            parallel_residual=opt.parallel_residual,
            layer_norm=opt.layer_norm,
            norm_eps=opt.norm_eps,
            use_ckpting=opt.use_ckpting,
            parallel_gpu=opt.world_size
            if opt.parallel_mode == "tensor_parallel"
            else 1,
        )

    def forward(self, src, src_len=None):
        """See :func:`EncoderBase.forward()`

        Args:
            src: source input passed to ``self.embeddings``.
            src_len (Tensor, optional): ``(batch_size)`` length of each
                sequence. If None, every sequence is assumed to span the
                full time dimension of the embedded input (nothing masked).

        Returns:
            (enc_out, None, src_len) with enc_out
            ``(batch_size, src_len, model_dim)``.
        """
        enc_out = self.embeddings(src)
        if src_len is None:
            # Previously sequence_mask(None) crashed. Treat every position
            # as valid: length = time dimension of enc_out for each batch row.
            src_len = enc_out.new_full(
                (enc_out.size(0),), enc_out.size(1)
            ).long()
        mask = ~sequence_mask(src_len).unsqueeze(1)
        mask = mask.unsqueeze(1)
        mask = mask.expand(-1, -1, mask.size(3), -1)
        # mask is now (batch x 1 x slen x slen)
        # 1 to be expanded to number of heads in MHA
        # Run the forward pass of every layer of the transformer.
        for layer in self.transformer:
            enc_out = layer(enc_out, mask)
        enc_out = self.layer_norm(enc_out)
        return enc_out, None, src_len

    def update_dropout(self, dropout, attention_dropout):
        """Propagate new dropout probabilities to embeddings and all layers."""
        self.embeddings.update_dropout(dropout)
        for layer in self.transformer:
            layer.update_dropout(dropout, attention_dropout)
|