# homemade_lo_vi/modules/dot_product_attention.py
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention (section 3.2.1 of "Attention Is All You Need"):

        Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

    Args:
        dim (int): dimension of each head (d_k / d_head), used for scaling.
        dropout_p (float): dropout probability applied to the attention weights.

    Inputs:
        query (batch, num_heads, seq_len, d_head)
        key (batch, num_heads, seq_len, d_head)
        value (batch, num_heads, seq_len, d_head)
        mask (optional): tensor broadcastable to (batch, num_heads, seq_len, seq_len);
            positions where mask == 0 are excluded from attention.

    Outputs:
        context (batch, num_heads, seq_len, d_head): context matrix.
        attn (batch, num_heads, seq_len, seq_len): attention weights for visualization.
    """

    def __init__(self, dim: int, dropout_p: float) -> None:
        super().__init__()
        # Scaling factor sqrt(d_k) from the paper, precomputed once.
        self.sqrt_dim = np.sqrt(dim)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        # (batch, num_heads, seq_len, d_head) @ (batch, num_heads, d_head, seq_len)
        # ==> score: (batch, num_heads, seq_len, seq_len)
        score = torch.matmul(query, key.transpose(-2, -1)) / self.sqrt_dim

        if mask is not None:
            # Use a large negative value (rather than -inf) so fp16 stays finite;
            # masked positions receive near-zero weight after the softmax.
            score.masked_fill_(mask == 0, -1e4)

        # attn: (batch, num_heads, seq_len, seq_len)
        attn = F.softmax(score, dim=-1)
        attn = self.dropout(attn)

        # context: (batch, num_heads, seq_len, d_head)
        context = torch.matmul(attn, value)
        return context, attn
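

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the module above):
    # the shapes below assume a hypothetical 8-head setup with d_head = 64.
    batch, num_heads, seq_len, d_head = 2, 8, 10, 64

    attention = ScaledDotProductAttention(dim=d_head, dropout_p=0.1)

    query = torch.randn(batch, num_heads, seq_len, d_head)
    key = torch.randn(batch, num_heads, seq_len, d_head)
    value = torch.randn(batch, num_heads, seq_len, d_head)

    # Causal (lower-triangular) mask, broadcast over batch and heads;
    # entries equal to 0 are blocked from attending.
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.long))
    causal_mask = causal_mask.view(1, 1, seq_len, seq_len)

    context, attn = attention(query, key, value, mask=causal_mask)
    print(context.shape)  # torch.Size([2, 8, 10, 64])
    print(attn.shape)     # torch.Size([2, 8, 10, 10])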