"""
Shared attention helpers
"""
import torch


# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from:
(batch, num_key_value_heads, seqlen, head_dim) to
(batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def mask_attention(qk_dot: torch.Tensor, attn_mask: torch.Tensor,
                   mask_value: float = -10000) -> torch.Tensor:
    """
    Apply attention mask (e.g., for padding)
    """
    if len(attn_mask.shape) == 4:  # attn_mask either (b, h, l, d) or (b, l)
        return qk_dot.masked_fill(~attn_mask.bool(), mask_value)
    else:
        return qk_dot.masked_fill(~attn_mask[:, None, None, :].bool(), mask_value)
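

# Usage sketch (illustrative, not part of the original module): shows how
# repeat_kv and mask_attention compose in a grouped-query attention step.
# The tensor shapes and the 8-KV-head / 32-query-head split are assumptions
# chosen for the example, not values taken from this repository.
if __name__ == "__main__":
    batch, n_kv_heads, n_heads, seqlen, head_dim = 2, 8, 32, 16, 64
    queries = torch.randn(batch, n_heads, seqlen, head_dim)
    keys = torch.randn(batch, n_kv_heads, seqlen, head_dim)
    # Expand the KV heads so every query head has a matching key head
    keys = repeat_kv(keys, n_heads // n_kv_heads)
    # Scaled dot-product scores: (batch, n_heads, seqlen, seqlen)
    scores = queries @ keys.transpose(-2, -1) / head_dim ** 0.5
    # Padding mask of shape (batch, seqlen); False marks padded key positions
    attn_mask = torch.ones(batch, seqlen, dtype=torch.bool)
    attn_mask[:, -4:] = False
    scores = mask_attention(scores, attn_mask)
    probs = torch.softmax(scores, dim=-1)
    print(probs.shape)  # torch.Size([2, 32, 16, 16])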