"""
Shared attention helpers
"""
import torch


# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
    The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
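

# A minimal usage sketch of repeat_kv for grouped-query attention: the key/value
# heads are tiled so they line up with the query heads before computing attention
# scores. The shapes below are assumptions chosen purely for illustration.
def _demo_repeat_kv() -> None:
    batch, num_kv_heads, n_rep, seqlen, head_dim = 2, 2, 4, 16, 64
    keys = torch.randn(batch, num_kv_heads, seqlen, head_dim)
    expanded = repeat_kv(keys, n_rep)
    # (2, 2, 16, 64) -> (2, 8, 16, 64): each KV head is repeated n_rep times
    assert expanded.shape == (batch, num_kv_heads * n_rep, seqlen, head_dim)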


def mask_attention(qk_dot: torch.Tensor, attn_mask: torch.Tensor,
                   mask_value: float = -10000) -> torch.Tensor:
    """
    Apply an attention mask (e.g., for padding) to the query-key dot products.
    """
    if len(attn_mask.shape) == 4:  # attn_mask is either (b, h, l, l) or (b, l)
        return qk_dot.masked_fill(~attn_mask.bool(), mask_value)
    else:
        return qk_dot.masked_fill(~attn_mask[:, None, None, :].bool(), mask_value)
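

# A minimal usage sketch of mask_attention with a 2D padding mask of shape
# (batch, seqlen): positions where the mask is 0 receive a large negative score,
# so they get (near-)zero weight after the softmax. Values are assumptions for
# illustration only.
def _demo_mask_attention() -> None:
    batch, num_heads, seqlen = 2, 4, 8
    qk_dot = torch.randn(batch, num_heads, seqlen, seqlen)
    attn_mask = torch.ones(batch, seqlen, dtype=torch.long)
    attn_mask[:, -2:] = 0  # treat the last two positions as padding
    probs = torch.softmax(mask_attention(qk_dot, attn_mask), dim=-1)
    # attention to the padded key positions collapses to ~0
    assert torch.all(probs[..., -2:] < 1e-3)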