""" |
|
Shared attention helpers |
|
""" |
|
import torch |
|
|
|
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep).
    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
    to (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
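

# A minimal usage sketch (not part of the original helpers): repeat_kv is
# typically used in grouped-query attention to expand the key/value heads so
# they line up with the query heads. The function name _example_repeat_kv and
# all shape values below are made up for illustration.
def _example_repeat_kv() -> None:
    batch, num_q_heads, num_kv_heads, seqlen, head_dim = 2, 8, 2, 16, 64
    keys = torch.randn(batch, num_kv_heads, seqlen, head_dim)
    # Each of the 2 KV heads is repeated 4 times along the head dimension,
    # matching the 8 query heads.
    expanded = repeat_kv(keys, n_rep=num_q_heads // num_kv_heads)
    assert expanded.shape == (batch, num_q_heads, seqlen, head_dim)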
def mask_attention(qk_dot: torch.Tensor, attn_mask: torch.Tensor,
                   mask_value: float = -10000.0) -> torch.Tensor:
    """
    Apply an attention mask (e.g., a padding mask) to the raw attention scores.
    A 4-D mask of shape (batch, num_heads, q_len, k_len) is applied as-is; a
    2-D mask of shape (batch, k_len) is broadcast over heads and query
    positions. Masked positions are filled with mask_value.
    """
    if attn_mask.dim() == 4:
        return qk_dot.masked_fill(~attn_mask.bool(), mask_value)
    return qk_dot.masked_fill(~attn_mask[:, None, None, :].bool(), mask_value)
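

# A minimal usage sketch (not part of the original helpers): applying
# mask_attention with a 2-D padding mask of shape (batch, seqlen), where 1
# marks real tokens and 0 marks padding. The function name
# _example_mask_attention and all values below are made up for illustration.
def _example_mask_attention() -> None:
    batch, num_heads, seqlen = 2, 4, 5
    scores = torch.randn(batch, num_heads, seqlen, seqlen)
    # The second sequence has two padding positions at the end.
    padding_mask = torch.tensor([[1, 1, 1, 1, 1],
                                 [1, 1, 1, 0, 0]])
    masked = mask_attention(scores, padding_mask)
    probs = torch.softmax(masked, dim=-1)
    # Padded key positions receive essentially zero attention weight.
    assert torch.all(probs[1, :, :, 3:] < 1e-3)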