import math

import torch
import torch.nn as nn


class RelativePositionEmbedding(nn.Module):
    """ Relative Position Embedding
        https://arxiv.org/abs/1910.10683
        https://github.com/bojone/bert4keras/blob/db236eac110a67a587df7660f6a1337d5b2ef07e/bert4keras/layers.py#L663
        https://github.com/huggingface/transformers/blob/master/src/transformers/models/t5/modeling_t5.py#L344
    """

    def __init__(self, heads_num, bidirectional=True, num_buckets=32, max_distance=128):
        super(RelativePositionEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.bidirectional = bidirectional
        self.max_distance = max_distance
        # One learnable scalar bias per (bucket, attention head) pair.
        self.relative_attention_bias = nn.Embedding(self.num_buckets, heads_num)
    def forward(self, encoder_hidden, decoder_hidden):
        """
        Compute binned relative position bias.
        Args:
            encoder_hidden: [batch_size x seq_length x emb_size]
            decoder_hidden: [batch_size x seq_length x emb_size]
        Returns:
            position_bias: [1 x heads_num x query_length x key_length]
        """
        query_length = encoder_hidden.size()[1]
        key_length = decoder_hidden.size()[1]
        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self.relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance
        )
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values
    def relative_position_bucket(self, relative_position, bidirectional, num_buckets, max_distance):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same
        bucket. This should allow for more graceful generalization to longer sequences than the model has been
        trained on.
        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # Reserve half of the buckets for positive offsets and half for negative offsets.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # Now relative_position is in the range [0, inf).
        # Half of the buckets are for exact increments in positions.
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )
        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
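

# Minimal usage sketch (an addition, not part of the original module; the sizes below are
# illustrative assumptions). It shows how the bias produced by RelativePositionEmbedding
# would typically be obtained for self-attention; the returned tensor is normally added to
# the raw attention scores before the softmax in each attention head.
if __name__ == "__main__":
    batch_size, seq_length, emb_size, heads_num = 2, 10, 64, 8  # assumed example sizes
    rel_pos_emb = RelativePositionEmbedding(heads_num=heads_num, bidirectional=True)
    hidden = torch.randn(batch_size, seq_length, emb_size)
    # Self-attention case: queries and keys come from the same hidden states.
    position_bias = rel_pos_emb(hidden, hidden)
    print(position_bias.shape)  # torch.Size([1, 8, 10, 10])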