|
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from modules.wrapper import Linear
from modules.dot_product_attention import ScaledDotProductAttention
|
|
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention (section 3.2.2 of "Attention Is All You Need").

    The d_model-dimensional queries, keys, and values are projected and split
    into num_heads heads of size d_head = d_model // num_heads, attended to in
    parallel with scaled dot-product attention, then concatenated and projected
    back to d_model.

    Args:
        - d_model (int): dimension of the model
        - num_heads (int): number of attention heads
        - dropout_p (float): dropout probability applied to the attention weights

    Inputs:
        - query (batch, q_len, d_model): queries to attend with
        - key (batch, k_len, d_model): keys to score the queries against
        - value (batch, k_len, d_model): values to aggregate
        - mask (optional): mask broadcastable to (batch, num_heads, q_len, k_len);
          masked positions are excluded from attention

    Output (Tensor, Tensor):
        - context (batch, q_len, d_model): attention-weighted, re-projected values
        - attn (batch, num_heads, q_len, k_len): attention matrix for visualization
    """
|
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout_p: float,
    ) -> None:
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        # Learned projections for queries, keys, and values (W^Q, W^K, W^V).
        self.W_query = Linear(d_model, d_model)
        self.W_key = Linear(d_model, d_model)
        self.W_value = Linear(d_model, d_model)
        # Output projection W^O applied to the concatenated heads.
        self.W_out = Linear(d_model, d_model)

        # Attention scores are computed per head, so scale by sqrt(d_head)
        # (d_k in the paper), not sqrt(d_model).
        self.scaled_dot_attn = ScaledDotProductAttention(self.d_head, dropout_p)
|
    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        batch_size = query.shape[0]

        # Project, split d_model into (num_heads, d_head), and move the head
        # dimension forward: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_head).
        query = self.W_query(query).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        key = self.W_key(key).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        value = self.W_value(value).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)

        # Scaled dot-product attention over all heads in parallel.
        # context: (batch, num_heads, q_len, d_head), attn: (batch, num_heads, q_len, k_len)
        context, attn = self.scaled_dot_attn(query, key, value, mask)

        # Concatenate the heads back into d_model and apply the output projection.
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        context = self.W_out(context)

        return context, attn
|
|
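# A minimal smoke-test sketch (not part of the module). It assumes the repo's
# Linear wrapper behaves like nn.Linear and that ScaledDotProductAttention
# returns (context, attn) with per-head attention weights; the hyperparameters
# below are illustrative only.
if __name__ == "__main__":
    batch, seq_len, d_model, num_heads = 2, 5, 16, 4

    mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout_p=0.1)
    x = torch.randn(batch, seq_len, d_model)

    # Self-attention: the same tensor serves as query, key, and value.
    context, attn = mha(x, x, x)

    print(context.shape)  # expected: torch.Size([2, 5, 16])
    print(attn.shape)     # expected: torch.Size([2, 4, 5, 5])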