"""
Shared attention helpers
"""
import torch


# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep).
    The hidden states go from
        (batch, num_key_value_heads, seqlen, head_dim) to
        (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
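

# Illustrative sanity check (a sketch added for clarity; the shapes below are
# assumed, not taken from any particular model config): repeat_kv should match
# torch.repeat_interleave along the head dimension.
def _check_repeat_kv() -> None:
    kv = torch.randn(2, 2, 5, 16)  # (batch, num_key_value_heads, seqlen, head_dim)
    assert torch.equal(repeat_kv(kv, 4), torch.repeat_interleave(kv, repeats=4, dim=1))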


def mask_attention(qk_dot: torch.Tensor, attn_mask: torch.Tensor,
                   mask_value: float = -10000) -> torch.Tensor:
    """
    Apply an attention mask (e.g., for padding): positions where attn_mask is 0
    are filled with mask_value before the softmax.
    """
    if len(attn_mask.shape) == 4:  # attn_mask is either (b, h, q_len, k_len) or (b, k_len)
        return qk_dot.masked_fill(~attn_mask.bool(), mask_value)
    else:
        return qk_dot.masked_fill(~attn_mask[:, None, None, :].bool(), mask_value)
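

# Minimal usage sketch (shapes are assumed for illustration, not taken from the
# original source): mask out padded key positions with a (batch, seqlen) padding
# mask, where 0 marks padding.
if __name__ == "__main__":
    _check_repeat_kv()

    batch, num_heads, seqlen = 2, 8, 5
    qk_dot = torch.randn(batch, num_heads, seqlen, seqlen)
    attn_mask = torch.ones(batch, seqlen)
    attn_mask[1, -2:] = 0  # last two key positions of the second sequence are padding
    masked = mask_attention(qk_dot, attn_mask)
    assert (masked[1, :, :, -2:] == -10000).all()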