import math

import torch
import torch.nn as nn


class RelativePositionEmbedding(nn.Module):
    """ Relative Position Embedding
        https://arxiv.org/abs/1910.10683
        https://github.com/bojone/bert4keras/blob/db236eac110a67a587df7660f6a1337d5b2ef07e/bert4keras/layers.py#L663
        https://github.com/huggingface/transformers/blob/master/src/transformers/models/t5/modeling_t5.py#L344
    """
    def __init__(self, heads_num, bidirectional=True, num_buckets=32, max_distance=128):
        super(RelativePositionEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.bidirectional = bidirectional
        self.max_distance = max_distance
        # One learned scalar per (bucket, head) pair, added to the attention scores.
        self.relative_attention_bias = nn.Embedding(self.num_buckets, heads_num)

    def forward(self, encoder_hidden, decoder_hidden):
        """ Compute binned relative position bias.
        Args:
            encoder_hidden: [batch_size x seq_length x emb_size]
            decoder_hidden: [batch_size x seq_length x emb_size]
        Returns:
            position_bias: [1 x heads_num x query_length x key_length]
        """
        query_length = encoder_hidden.size()[1]
        key_length = decoder_hidden.size()[1]

        context_position = torch.arange(query_length, dtype=torch.long)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self.relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance
        )
        relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, heads_num)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, heads_num, query_length, key_length)
        return values

    def relative_position_bucket(self, relative_position, bidirectional, num_buckets, max_distance):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute values of relative_position and larger buckets for larger absolute values. All relative
        positions >= max_distance map to the same bucket. All relative positions <= -max_distance map to the same
        bucket. This should allow for more graceful generalization to longer sequences than the model has been
        trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # Use half of the buckets for positive offsets and half for negative offsets.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # Now relative_position is in the range [0, inf).

        # Half of the buckets are for exact increments in positions.
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
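

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): the tensor names and
# sizes below are illustrative assumptions. It shows (a) how relative offsets
# are mapped to buckets and (b) how the returned bias is typically added to
# the raw attention scores before softmax, as in T5.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    heads_num = 4
    rel_pos_emb = RelativePositionEmbedding(heads_num=heads_num, bidirectional=True)

    # Bucketing: nearby offsets (e.g. +-3) get their own exact buckets, distant
    # offsets (e.g. +-200) saturate at the last bucket of their half, and the
    # sign of the offset selects the lower or upper half of the bucket range.
    offsets = torch.tensor([-200, -50, -3, 0, 3, 50, 200])
    print(rel_pos_emb.relative_position_bucket(
        offsets, bidirectional=True, num_buckets=32, max_distance=128
    ))

    # Bias usage: the bias has shape (1, heads_num, query_length, key_length)
    # and broadcasts over the batch dimension when added to attention scores.
    batch_size, seq_length, emb_size = 2, 8, 16
    hidden = torch.randn(batch_size, seq_length, emb_size)
    position_bias = rel_pos_emb(hidden, hidden)
    scores = torch.randn(batch_size, heads_num, seq_length, seq_length)  # stand-in for Q.K^T / sqrt(d_k)
    probs = torch.softmax(scores + position_bias, dim=-1)
    print(position_bias.shape, probs.shape)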