import math

import torch
import torch.nn as nn

from tencentpretrain.utils.rope import apply_rotary_emb


class MultiHeadedAttention(nn.Module):
    """
    Each head is a self-attention operation.
    Self-attention refers to https://arxiv.org/pdf/1706.03762.pdf
    """

    def __init__(self, hidden_size, heads_num, attention_head_size, dropout, has_bias=True, with_scale=True):
        super(MultiHeadedAttention, self).__init__()
        self.heads_num = heads_num
        self.per_head_size = attention_head_size
        self.with_scale = with_scale
        self.inner_hidden_size = heads_num * attention_head_size

        # Query, key, and value projections: hidden_size -> heads_num * attention_head_size.
        self.linear_layers = nn.ModuleList(
            [nn.Linear(hidden_size, self.inner_hidden_size, bias=has_bias) for _ in range(3)]
        )
        self.dropout = nn.Dropout(dropout)
        self.final_linear = nn.Linear(self.inner_hidden_size, hidden_size, bias=has_bias)

    def forward(self, key, value, query, mask, position_bias=None, has_residual_attention=False,
                prev_attn=None, freqs_cis=None):
        """
        Args:
            key: [batch_size x seq_length x hidden_size]
            value: [batch_size x seq_length x hidden_size]
            query: [batch_size x seq_length x hidden_size]
            mask: [batch_size x 1 x seq_length x seq_length]
            position_bias: [1 x heads_num x seq_length x seq_length]
        Returns:
            output: [batch_size x seq_length x hidden_size]
        """
        batch_size, seq_length, _ = query.size()
        heads_num = self.heads_num
        per_head_size = self.per_head_size

        def shape(x):
            # [batch_size, seq_length, inner_hidden_size] -> [batch_size, heads_num, seq_length, per_head_size]
            return x.contiguous().view(batch_size, seq_length, heads_num, per_head_size).transpose(1, 2)

        def unshape(x):
            # [batch_size, heads_num, seq_length, per_head_size] -> [batch_size, seq_length, inner_hidden_size]
            return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.inner_hidden_size)

        # Project the inputs and split them into heads: [batch_size, heads_num, seq_length, per_head_size].
        query, key, value = [
            l(x).view(batch_size, -1, heads_num, per_head_size).transpose(1, 2)
            for l, x in zip(self.linear_layers, (query, key, value))
        ]

        if freqs_cis is not None:
            # Apply rotary position embeddings to the query and key projections.
            query, key = apply_rotary_emb(query.transpose(1, 2), key.transpose(1, 2), freqs_cis=freqs_cis)

        # Scaled dot-product attention with an optional relative position bias.
        scores = torch.matmul(query, key.transpose(-2, -1))
        if position_bias is not None:
            scores = scores + position_bias
        if self.with_scale:
            scores = scores / math.sqrt(float(per_head_size))
        # Additive mask: masked positions carry large negative values and vanish after the softmax.
        scores = scores + mask.type_as(scores)

        # Residual attention: add the previous layer's attention scores and expose the sum to the next layer.
        prev_attn_out = None
        if has_residual_attention:
            if prev_attn is not None:
                scores += prev_attn
            prev_attn_out = scores

        probs = nn.Softmax(dim=-1)(scores)
        probs = self.dropout(probs)
        output = unshape(torch.matmul(probs, value))
        output = self.final_linear(output)

        return output, prev_attn_out
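
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# Tensor shapes follow the docstring above. The all-zero mask is an assumption
# for this example and means "attend everywhere", since the mask is added to
# the attention scores before the softmax.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, seq_length, hidden_size, heads_num = 2, 8, 64, 4

    attn = MultiHeadedAttention(
        hidden_size=hidden_size,
        heads_num=heads_num,
        attention_head_size=hidden_size // heads_num,
        dropout=0.1,
    )

    x = torch.randn(batch_size, seq_length, hidden_size)
    # Additive mask: 0 keeps a position; a large negative value would hide it.
    mask = torch.zeros(batch_size, 1, seq_length, seq_length)

    # Self-attention: key, value, and query are the same tensor here.
    output, prev_attn_out = attn(x, x, x, mask)
    print(output.shape)  # torch.Size([2, 8, 64]); prev_attn_out is None without residual attention.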