|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math |
|
import numpy as np |
|
from typing import Dict, Optional, Tuple |
|
import torch |
|
from torch import Tensor, nn |
|
import torch.nn.functional as F |
|
from torch.nn import LayerNorm, Parameter |
|
from beats.modules import ( |
|
GradMultiply, |
|
SamePad, |
|
get_activation_fn, |
|
GLU_Linear, |
|
quant_noise, |
|
) |
|
|
|
|
|
class TransformerEncoder(nn.Module):
    """Stack of ``TransformerSentenceEncoderLayer`` blocks preceded by a
    convolutional positional embedding.

    ``args`` is a config object. Attributes read here: ``dropout``,
    ``encoder_embed_dim``, ``conv_pos``, ``conv_pos_groups``,
    ``encoder_ffn_embed_dim``, ``encoder_attention_heads``,
    ``attention_dropout``, ``activation_dropout``, ``activation_fn``,
    ``layer_norm_first``, ``deep_norm``, ``gru_rel_pos``, ``encoder_layers``
    and ``encoder_layerdrop``. Optional: ``relative_position_embedding``
    (which, when present, also requires ``num_buckets`` and ``max_distance``)
    and ``layer_wise_gradient_decay_ratio`` (defaults to 1).
    """

    def __init__(self, args):
        super().__init__()

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        # Grouped 1-D convolution over the time axis, used as a positional embedding.
        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        # The legacy weight-norm API registers ``weight_g`` / ``weight_v``
        # parameters; switching to ``nn.utils.parametrizations`` would rename
        # them and break loading of state dicts saved from this module.
        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        if hasattr(args, "relative_position_embedding"):
            self.relative_position_embedding = args.relative_position_embedding
            self.num_buckets = args.num_buckets
            self.max_distance = args.max_distance
        else:
            self.relative_position_embedding = False
            self.num_buckets = 0
            self.max_distance = 0

        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                    deep_norm=args.deep_norm,
                    has_relative_attention_bias=self.relative_position_embedding,
                    num_buckets=self.num_buckets,
                    max_distance=self.max_distance,
                    gru_rel_pos=args.gru_rel_pos,
                    encoder_layers=args.encoder_layers,
                )
                for _ in range(args.encoder_layers)
            ]
        )
        if self.relative_position_embedding:
            # Every layer reuses layer 0's relative-position bias table, so the
            # table is a single parameter shared across the whole stack.
            for i in range(1, args.encoder_layers):
                del self.layers[i].self_attn.relative_attention_bias
                self.layers[i].self_attn.relative_attention_bias = self.layers[0].self_attn.relative_attention_bias

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        self.apply(init_bert_params)

        if args.deep_norm:
            # DeepNorm initialisation: q/k projections keep gain 1, while the
            # value/output projections and the FFN use gain (8 * num_layers) ** -1/4.
            deep_norm_beta = math.pow(8 * args.encoder_layers, -1 / 4)
            for i in range(args.encoder_layers):
                nn.init.xavier_normal_(self.layers[i].self_attn.k_proj.weight, gain=1)
                nn.init.xavier_normal_(self.layers[i].self_attn.v_proj.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].self_attn.q_proj.weight, gain=1)
                nn.init.xavier_normal_(self.layers[i].self_attn.out_proj.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].fc1.weight, gain=deep_norm_beta)
                nn.init.xavier_normal_(self.layers[i].fc2.weight, gain=deep_norm_beta)

        # Gradient multiplier applied to the activations entering every layer
        # (skipped when it equals 1).
        self.layer_wise_gradient_decay_ratio = getattr(args, "layer_wise_gradient_decay_ratio", 1)

    def forward(self, x, padding_mask=None, layer=None):
        """Encode ``x`` and apply the final LayerNorm in pre-LN mode.

        Args:
            x: features of shape ``(batch, time, embedding_dim)``.
            padding_mask: optional ``(batch, time)`` mask, truthy at padded frames.
            layer: optional index of the layer whose output should be returned;
                when given, the final LayerNorm is not applied.

        Returns:
            ``(x, layer_results)`` as produced by :meth:`extract_features`.
        """
        x, layer_results = self.extract_features(x, padding_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, tgt_layer=None):
        """Add the convolutional positional embedding and run the layer stack.

        Args:
            x: features of shape ``(batch, time, embedding_dim)``; this tensor
                is not modified.
            padding_mask: optional ``(batch, time)`` mask, truthy at padded
                frames. Padded frames are zeroed before the positional conv and
                the mask is forwarded to every layer as the key padding mask.
            tgt_layer: optional layer index at which to stop early.

        Returns:
            ``(x, layer_results)``. ``x`` is ``(batch, time, embedding_dim)``.
            ``layer_results`` is empty unless ``tgt_layer`` is given; then it
            holds ``(x, z)`` pairs in time-first layout: the layer input
            followed by the output of each layer up to ``tgt_layer``.
        """
        if padding_mask is not None:
            # Out-of-place: the previous ``x[padding_mask] = 0`` overwrote the
            # caller's tensor and fails on leaf tensors that require grad.
            x = x.masked_fill(padding_mask.unsqueeze(-1).to(torch.bool), 0)

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # Layers operate time-first: (time, batch, channel).
        x = x.transpose(0, 1)

        layer_results = []
        z = None
        if tgt_layer is not None:
            layer_results.append((x, z))
        r = None
        pos_bias = None
        for i, layer in enumerate(self.layers):
            if self.layer_wise_gradient_decay_ratio != 1.0:
                x = GradMultiply.apply(x, self.layer_wise_gradient_decay_ratio)
            # LayerDrop: while training, each layer is skipped with probability
            # ``self.layerdrop``.
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                # The relative-position bias computed by the first executed
                # layer is passed on to the following layers.
                x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, pos_bias=pos_bias)
            if tgt_layer is not None:
                layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # Back to batch-first: (batch, time, channel).
        x = x.transpose(0, 1)

        return x, layer_results
|
|
|
|
|
class TransformerSentenceEncoderLayer(nn.Module):
    """One Transformer encoder layer: self-attention followed by a feed-forward block.

    With ``layer_norm_first=True`` the layer is pre-LN: normalize, apply the
    sublayer, add the residual. Otherwise it is post-LN: apply the sublayer,
    add the residual, normalize. In post-LN mode with ``deep_norm`` set, the
    residual is scaled by ``(2 * encoder_layers) ** 0.25`` before each add.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        deep_norm: bool = False,
        has_relative_attention_bias: bool = False,
        num_buckets: int = 0,
        max_distance: int = 0,
        rescale_init: bool = False,
        gru_rel_pos: bool = False,
        encoder_layers: int = 0,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        self.activation_name = activation_fn
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
            has_relative_attention_bias=has_relative_attention_bias,
            num_buckets=num_buckets,
            max_distance=max_distance,
            rescale_init=rescale_init,
            gru_rel_pos=gru_rel_pos,
        )

        self.dropout1 = nn.Dropout(dropout)  # after self-attention
        self.dropout2 = nn.Dropout(self.activation_dropout)  # after the FFN activation
        self.dropout3 = nn.Dropout(dropout)  # after fc2

        self.layer_norm_first = layer_norm_first

        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)

        if self.activation_name == "glu":
            # fc1 is a gated linear unit (swish gate) and applies its own activation.
            self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
        else:
            self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        self.final_layer_norm = LayerNorm(self.embedding_dim)

        self.deep_norm = deep_norm
        if self.deep_norm:
            self.deep_norm_alpha = math.pow(2 * encoder_layers, 1 / 4)
        else:
            self.deep_norm_alpha = 1

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
        pos_bias=None
    ):
        """Run the layer.

        Args:
            x: input, used as query, key and value of ``self_attn``, which
                expects time-first ``(time, batch, channel)`` layout.
            self_attn_mask: optional additive attention mask.
            self_attn_padding_mask: optional ``(batch, time)`` key padding mask.
            need_weights: whether ``self_attn`` should return attention weights.
            pos_bias: optional relative-position bias to reuse (computed by
                ``self_attn`` when None and it owns a bias table).

        Returns:
            ``(x, attn, pos_bias)``: the layer output, the attention weights
            (None unless requested) and the position bias used.
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                # Previously hard-coded to False here, ignoring the argument.
                need_weights=need_weights,
                attn_mask=self_attn_mask,
                position_bias=pos_bias
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            x = self._feed_forward(x)
            x = residual + x
        else:
            x, attn, pos_bias = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=need_weights,
                attn_mask=self_attn_mask,
                position_bias=pos_bias
            )

            x = self.dropout1(x)
            x = residual * self.deep_norm_alpha + x

            x = self.self_attn_layer_norm(x)

            residual = x
            x = self._feed_forward(x)
            x = residual * self.deep_norm_alpha + x
            x = self.final_layer_norm(x)

        return x, attn, pos_bias

    def _feed_forward(self, x):
        """Position-wise FFN: fc1 (+ activation) -> dropout2 -> fc2 -> dropout3."""
        if self.activation_name == "glu":
            x = self.fc1(x)
        else:
            x = self.activation_fn(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return self.dropout3(x)
|
|
|
|
|
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.

    Inputs are time-first, ``(time, batch, channel)``. Besides scaled
    dot-product attention, this module can add a learned relative-position
    bias (``has_relative_attention_bias``; distances are bucketed using
    ``num_buckets`` / ``max_distance``). It can also scale that bias by a gate
    computed from the query (``gru_rel_pos``).
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        has_relative_attention_bias=False,
        num_buckets=32,
        max_distance=128,
        gru_rel_pos=False,
        rescale_init=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = nn.Dropout(dropout)

        self.has_relative_attention_bias = has_relative_attention_bias
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        if self.has_relative_attention_bias:
            # One learned scalar per (distance bucket, head).
            self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

        self.head_dim = embed_dim // num_heads
        self.q_head_dim = self.head_dim
        self.k_head_dim = self.head_dim
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        # ``rescale_init`` removes the bias from the key projection only.
        k_bias = True
        if rescale_init:
            k_bias = False

        k_embed_dim = embed_dim
        q_embed_dim = embed_dim

        self.k_proj = quant_noise(
            nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
        )

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        if add_bias_kv:
            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self.gru_rel_pos = gru_rel_pos
        if self.gru_rel_pos:
            # Maps each head's query to 8 values, summed into two gate logits.
            self.grep_linear = nn.Linear(self.q_head_dim, 8)
            self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the projections.

        q/k/v use gain 1/sqrt(2) when they share the input dimension. Biases of
        ``out_proj`` are zeroed and the optional key/value biases and
        relative-position table are Xavier-normal initialised.
        """
        if self.qkv_same_dim:
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)
        if self.has_relative_attention_bias:
            nn.init.xavier_normal_(self.relative_attention_bias.weight)

    def _relative_positions_bucket(self, relative_positions, bidirectional=True):
        """Map signed relative positions (key - query) to bucket indices.

        When bidirectional, half the buckets are used for positive offsets.
        Within each half, distances below ``num_buckets // 2`` get their own
        bucket. Larger distances share logarithmically spaced buckets up to
        ``max_distance``, and everything beyond lands in the last bucket.
        """
        num_buckets = self.num_buckets
        max_distance = self.max_distance
        relative_buckets = 0

        if bidirectional:
            num_buckets = num_buckets // 2
            relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
            relative_positions = torch.abs(relative_positions)
        else:
            relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))

        max_exact = num_buckets // 2
        is_small = relative_positions < max_exact

        # log(0) for zero distances is harmless: torch.where selects the
        # exact (is_small) branch for those positions.
        relative_postion_if_large = max_exact + (
            torch.log(relative_positions.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_postion_if_large = torch.min(
            relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        """Return the relative-position bias of shape ``(num_heads, query_length, key_length)``."""
        # Buckets are computed on the CPU and then moved, so bucket boundaries
        # do not depend on the device's ``log`` implementation.
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_positions_bucket(
            relative_position,
            bidirectional=True
        )
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)
        values = values.permute([2, 0, 1])
        return values

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        position_bias: Optional[Tensor] = None
    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (Tensor, optional): additive mask of shape
                `(tgt_len, src_len)`, added to the attention logits of every
                head (e.g. -inf at positions that must not be attended,
                as in causal attention) (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
            position_bias (Tensor, optional): relative-position bias of shape
                `(batch * num_heads, tgt_len, src_len)`. When None and this
                module owns a bias table, it is computed here.

        Returns:
            ``(attn, attn_weights, position_bias)``. ``attn`` has shape
            `(tgt_len, batch, embed_dim)`. ``attn_weights`` is None unless
            weights were requested; it is `(batch, tgt_len, src_len)` when
            averaged, or `(num_heads, batch, tgt_len, src_len)` with
            *need_head_weights*. With ``before_softmax=True`` the tuple is
            ``(attn_weights, v, position_bias)`` instead.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                # Previously ``assert src_len, bsz == value.shape[:2]``, which
                # only checked that src_len is non-zero.
                assert value.shape[:2] == (src_len, bsz)

        if self.has_relative_attention_bias and position_bias is None:
            position_bias = self.compute_bias(tgt_len, src_len)
            position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # Keys/values from previous steps are cached; with static_kv
                # they need not be recomputed.
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        # q is additionally divided by alpha so q @ k^T produces smaller values
        # (presumably to avoid fp16 overflow). The logits are multiplied back
        # by alpha after the row max is subtracted, see below.
        q *= self.scaling
        alpha = 32
        q *= 1 / alpha

        if self.bias_k is not None:
            # NOTE(review): this appends one key/value but does not update
            # src_len, so the later ``k.size(1) == src_len`` check fails when
            # add_bias_kv=True — confirm before relying on this option.
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        # Split heads: (len, bsz, embed) -> (bsz * heads, len, head_dim).
        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.k_head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # Saved states are stored as (bsz, num_heads, seq_len, head_dim).
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask

            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # A 0-dim key_padding_mask is treated as "no mask".
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        # Subtracting the row max and re-multiplying by alpha yields the usual
        # scaled logits shifted by a per-row constant, which softmax ignores.
        attn_weights = (attn_weights - attn_weights.max(dim=-1, keepdim=True)[0]) * alpha
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # Don't attend to padding positions.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v, position_bias

        if position_bias is not None:
            attn_mask_rel_pos = position_bias
            if self.gru_rel_pos == 1:
                # Undo the ``scaling / alpha`` applied to q to gate on the raw query.
                query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) * alpha / self.scaling
                _B, _H, _L, __ = query_layer.size()
                gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
                    _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
                gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
                attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, tgt_len, 1) * position_bias

            attn_mask_rel_pos = attn_mask_rel_pos.view(attn_weights.size())

            attn_weights = attn_weights + attn_mask_rel_pos

        attn_weights_float = F.softmax(
            attn_weights, dim=-1
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        # Merge heads back: (bsz * heads, tgt, head_dim) -> (tgt, bsz, embed).
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # Average attention weights over heads.
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights, position_bias

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        """Combine cached and current key padding masks into one of width ``src_len``.

        Missing parts are filled with zeros (not padded). Concatenated or
        filled masks are returned as float tensors.
        """
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        elif prev_key_padding_mask is not None:
            if src_len > prev_key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - prev_key_padding_mask.size(1)),
                    device=prev_key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [prev_key_padding_mask.float(), filler.float()], dim=1
                )
            else:
                new_key_padding_mask = prev_key_padding_mask.float()
        elif key_padding_mask is not None:
            if src_len > key_padding_mask.size(1):
                filler = torch.zeros(
                    (batch_size, src_len - key_padding_mask.size(1)),
                    device=key_padding_mask.device,
                )
                new_key_padding_mask = torch.cat(
                    [filler.float(), key_padding_mask.float()], dim=1
                )
            else:
                new_key_padding_mask = key_padding_mask.float()
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    def _incremental_state_key(self, key: str) -> str:
        """Namespace ``key`` per module instance inside a shared incremental-state dict."""
        return "{}.{}".format(id(self), key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Return this module's entry for ``key`` in ``incremental_state``, or None.

        ``_get_input_buffer`` relies on this. ``nn.Module`` does not provide it,
        so without it any use of ``incremental_state`` raised AttributeError.
        """
        if incremental_state is None:
            return None
        return incremental_state.get(self._incremental_state_key(key))

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Store ``value`` under this module's ``key`` and return ``incremental_state``."""
        if incremental_state is not None:
            incremental_state[self._incremental_state_key(key)] = value
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        """Return the cached attention state for this module (empty dict if none)."""
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        """Store ``buffer`` as this module's cached attention state."""
        return self.set_incremental_state(incremental_state, "attn_state", buffer)

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        """Hook for sparse-attention variants; the identity here."""
        return attn_weights
|
|
|
|
|
def init_bert_params(module):
    """Re-initialise ``module``'s weights in place, BERT style.

    Meant to be passed to ``nn.Module.apply``. Handles three module kinds and
    leaves every other module untouched:

    * ``nn.Linear``: weight drawn from N(0, 0.02**2); bias (if any) set to 0.
    * ``nn.Embedding``: weight drawn from N(0, 0.02**2); the ``padding_idx``
      row (if set) is zeroed.
    * ``MultiheadAttention``: the ``q_proj``, ``k_proj`` and ``v_proj``
      weights are drawn from N(0, 0.02**2).
    """

    def _draw_normal(tensor):
        # Sample on the CPU and copy back, so the values come from the CPU RNG
        # whichever device the parameter lives on.
        sampled = tensor.cpu().normal_(mean=0.0, std=0.02)
        tensor.copy_(sampled.to(tensor.device))

    if isinstance(module, nn.Linear):
        _draw_normal(module.weight.data)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        _draw_normal(module.weight.data)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, MultiheadAttention):
        for projection in (module.q_proj, module.k_proj, module.v_proj):
            _draw_normal(projection.weight.data)