# homemade_lo_vi/modules/multi_head_attention.py
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from modules.wrapper import Linear
from modules.dot_product_attention import ScaledDotProductAttention


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention (section 3.2.2 of "Attention Is All You Need").

    Args:
        d_model (int): dimension of the model
        num_heads (int): number of attention heads
        dropout_p (float): dropout probability

    Inputs:
        query (batch, seq_len, d_model): queries to be projected and attended with
        key (batch, seq_len, d_model): keys to be projected and matched against the queries
        value (batch, seq_len, d_model): values to be aggregated by the attention weights
        mask (optional): positions to be excluded from the attention scores

    Outputs (Tensor, Tensor):
        context (batch, seq_len, d_model): attention-weighted representation
        attn (batch, num_heads, seq_len, seq_len): attention matrix for visualization
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout_p: float,
    ) -> None:
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        # Single d_model x d_model projections; the result is reshaped into heads in forward().
        self.W_query = Linear(d_model, d_model)
        self.W_key = Linear(d_model, d_model)
        self.W_value = Linear(d_model, d_model)
        # self.W_output = Linear(d_model, d_model)

        self.scaled_dot_attn = ScaledDotProductAttention(d_model, dropout_p)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        batch_size = query.shape[0]

        # original: (batch, seq_len, d_model)
        # --forward--> (batch, seq_len, d_model)
        # --view--> (batch, seq_len, num_heads, d_head)
        # --transpose--> (batch, num_heads, seq_len, d_head)
        query = self.W_query(query).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        key = self.W_key(key).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
        value = self.W_value(value).view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)

        context, attn = self.scaled_dot_attn(query, key, value, mask)

        # (batch, num_heads, seq_len, d_head)
        # --transpose--> (batch, seq_len, num_heads, d_head)
        # --view--> (batch, seq_len, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        # context = self.W_output(context)

        return context, attn
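

# A minimal usage sketch, not part of the original module: it assumes the sibling
# modules imported above (Linear, ScaledDotProductAttention) are importable from this
# repo and that ScaledDotProductAttention returns (context, attn) as forward() expects.
# The sizes below are arbitrary example values.
if __name__ == "__main__":
    batch_size, seq_len, d_model, num_heads = 2, 10, 512, 8

    mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout_p=0.1)

    x = torch.randn(batch_size, seq_len, d_model)
    # Self-attention: query, key, and value are all the same tensor.
    context, attn = mha(x, x, x, mask=None)

    print(context.shape)  # (batch, seq_len, d_model) -> torch.Size([2, 10, 512])
    print(attn.shape)     # (batch, num_heads, seq_len, seq_len), assuming the head dim is kept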