# VISOR-GPT/train/tencentpretrain/layers/relative_position_embedding.py
import math
import torch
import torch.nn as nn


class RelativePositionEmbedding(nn.Module):
""" Relative Position Embedding
https://arxiv.org/abs/1910.10683
https://github.com/bojone/bert4keras/blob/db236eac110a67a587df7660f6a1337d5b2ef07e/bert4keras/layers.py#L663
https://github.com/huggingface/transformers/blob/master/src/transformers/models/t5/modeling_t5.py#L344
"""
    def __init__(self, heads_num, bidirectional=True, num_buckets=32, max_distance=128):
super(RelativePositionEmbedding, self).__init__()
self.num_buckets = num_buckets
self.bidirectional = bidirectional
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(self.num_buckets, heads_num)

    def forward(self, encoder_hidden, decoder_hidden):
        """
        Compute binned relative position bias. Only the sequence lengths of the two
        inputs are used; the hidden states themselves are not read.
        Args:
            encoder_hidden: [batch_size x query_length x emb_size]
            decoder_hidden: [batch_size x key_length x emb_size]
        Returns:
            position_bias: [1 x heads_num x query_length x key_length]
        """
query_length = encoder_hidden.size()[1]
key_length = decoder_hidden.size()[1]
context_position = torch.arange(query_length, dtype=torch.long)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self.relative_position_bucket(
relative_position, # shape (query_length, key_length)
bidirectional=self.bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance
)
relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values

    def relative_position_bucket(self, relative_position, bidirectional, num_buckets, max_distance):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same
        bucket. This should allow for more graceful generalization to longer sequences than the model has been
        trained on.
        Args:
            relative_position: an integer Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing integer values in the range [0, num_buckets)
        """
relative_buckets = 0
if bidirectional:
num_buckets //= 2
relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
relative_position = torch.abs(relative_position)
else:
relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
# now relative_position is in the range [0, inf)
# half of the buckets are for exact increments in positions
max_exact = num_buckets // 2
is_small = relative_position < max_exact
# The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )
        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
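

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module): shows how
    # the returned bias is typically added to raw attention scores before softmax, as in T5.
    # The batch size, head count, and sequence length below are arbitrary assumptions.
    batch_size, heads_num, emb_size, seq_length = 2, 12, 768, 16
    rel_pos_emb = RelativePositionEmbedding(heads_num=heads_num, bidirectional=True)

    hidden = torch.randn(batch_size, seq_length, emb_size)
    # Self-attention case: queries and keys come from the same hidden states.
    position_bias = rel_pos_emb(hidden, hidden)  # [1 x heads_num x seq_length x seq_length]

    scores = torch.randn(batch_size, heads_num, seq_length, seq_length)  # stand-in for QK^T / sqrt(d_k)
    scores = scores + position_bias  # the bias broadcasts over the batch dimension
    probs = torch.softmax(scores, dim=-1)
    print(position_bias.shape, probs.shape)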