from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention (section 3.2.1)

    Args:
        dim (int): dimension of d_k (i.e. d_head)
        dropout_p (float): probability of dropout

    Inputs:
        query (batch, num_heads, seq_len, d_head)
        key (batch, num_heads, seq_len, d_head)
        value (batch, num_heads, seq_len, d_head)
        mask: tensor broadcastable to (batch, num_heads, seq_len, seq_len);
            positions where mask == 0 are excluded from attention

    Outputs:
        context (batch, num_heads, seq_len, d_head): context matrix
        attn (batch, num_heads, seq_len, seq_len): attention weights, for visualization
    """

    def __init__(self, dim: int, dropout_p: float) -> None:
        super(ScaledDotProductAttention, self).__init__()
        self.sqrt_dim = np.sqrt(dim)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        # (batch, num_heads, seq_len, d_head) @ (batch, num_heads, d_head, seq_len)
        # ==> score: (batch, num_heads, seq_len, seq_len)
        score = torch.matmul(query, key.transpose(-2, -1)) / self.sqrt_dim

        if mask is not None:
            # Large negative value (rather than -inf) keeps half-precision stable;
            # masked positions get ~0 weight after softmax.
            score.masked_fill_(mask == 0, -1e4)

        attn = F.softmax(score, dim=-1)      # (batch, num_heads, seq_len, seq_len)
        attn = self.dropout(attn)
        context = torch.matmul(attn, value)  # (batch, num_heads, seq_len, d_head)
        return context, attn
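

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example showing the expected tensor shapes, assuming the inputs are
# already split into a multi-head layout of (batch, num_heads, seq_len, d_head).
# The sizes and the causal mask below are arbitrary placeholders for demonstration.
if __name__ == "__main__":
    batch, num_heads, seq_len, d_head = 2, 8, 16, 64

    attention = ScaledDotProductAttention(dim=d_head, dropout_p=0.1)

    query = torch.randn(batch, num_heads, seq_len, d_head)
    key = torch.randn(batch, num_heads, seq_len, d_head)
    value = torch.randn(batch, num_heads, seq_len, d_head)

    # Causal mask: position i may only attend to positions <= i.
    # Shape (1, 1, seq_len, seq_len) broadcasts over batch and heads.
    mask = torch.tril(torch.ones(seq_len, seq_len)).view(1, 1, seq_len, seq_len)

    context, attn = attention(query, key, value, mask=mask)
    print(context.shape)  # torch.Size([2, 8, 16, 64])
    print(attn.shape)     # torch.Size([2, 8, 16, 16])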