from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from modules.wrapper import Linear
from modules.dot_product_attention import ScaledDotProductAttention


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention (section 3.2.2 of "Attention Is All You Need")

    Args:
        - d_model (int): dimension of the model
        - num_heads (int): number of attention heads
        - dropout_p (float): dropout probability

    Inputs:
        - query (batch, seq_len, d_model): input to the query projection
        - key (batch, seq_len, d_model): input to the key projection
        - value (batch, seq_len, d_model): input to the value projection
        - mask (optional): mask broadcastable to the attention scores
          (batch, num_heads, seq_len, seq_len); positions to be masked out

    Output: (Tensor, Tensor):
        - context (batch, seq_len, d_model): attended output
        - attn (batch, num_heads, seq_len, seq_len): attention matrix for visualization
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout_p: float,
    ) -> None:
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be 0"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.W_query = Linear(d_model, d_model)
        self.W_key = Linear(d_model, d_model)
        self.W_value = Linear(d_model, d_model)
        # self.W_output = Linear(d_model, d_model)

        self.scaled_dot_attn = ScaledDotProductAttention(d_model, dropout_p)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        batch_size = query.shape[0]

        # original:       (batch, seq_len, d_model)
        # --projection--> (batch, seq_len, d_model)
        # --view-->       (batch, seq_len, num_heads, d_head)
        # --transpose-->  (batch, num_heads, seq_len, d_head)
        query = self.W_query(query).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        key = self.W_key(key).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        value = self.W_value(value).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)

        context, attn = self.scaled_dot_attn(query, key, value, mask)

        # context:       (batch, num_heads, seq_len, d_head)
        # --transpose--> (batch, seq_len, num_heads, d_head)
        # --view-->      (batch, seq_len, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # context = self.W_output(context)

        return context, attn
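

# A minimal usage sketch (not part of the original module): it assumes the
# project-local `Linear` and `ScaledDotProductAttention` imports above resolve,
# and simply runs random tensors through the layer as a shape smoke test
# against the sizes documented in the docstring.
if __name__ == "__main__":
    batch, seq_len, d_model, num_heads = 2, 10, 512, 8
    mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout_p=0.1)

    # Self-attention case: query, key, and value are the same tensor.
    x = torch.randn(batch, seq_len, d_model)
    context, attn = mha(x, x, x, mask=None)

    print(context.shape)  # expected: (batch, seq_len, d_model) -> torch.Size([2, 10, 512])
    print(attn.shape)     # expected: (batch, num_heads, seq_len, seq_len) -> torch.Size([2, 8, 10, 10])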