import torch
import torch.nn.functional as F


def get_last_attn(attn_map):
    """Keep only the last query position of each layer's attention map,
    e.g. going from (batch, heads, query_len, key_len) to
    (batch, heads, 1, key_len). Modifies ``attn_map`` in place."""
    for i, layer in enumerate(attn_map):
        attn_map[i] = layer[:, :, -1, :].unsqueeze(2)

    return attn_map


def sample_token(logits, top_k=None, top_p=None, temperature=1.0):
    # Scale the 1-D logits by temperature (values < 1.0 sharpen the distribution)
    logits = logits / temperature

    # Top-k sampling: sample from the k highest-scoring tokens
    if top_k is not None:
        top_k = min(top_k, logits.size(-1))  # Ensure top_k <= vocab size
        values, indices = torch.topk(logits, top_k)
        probs = F.softmax(values, dim=-1)
        next_token_id = indices[torch.multinomial(probs, 1)]

        return next_token_id

    # Top-p (nucleus) sampling: sample from the smallest set of tokens
    # whose cumulative probability exceeds top_p
    if top_p is not None:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        probs = F.softmax(sorted_logits, dim=-1)
        # Mask tokens once the mass before them exceeds top_p; the
        # one-position shift guarantees at least one token survives
        sorted_logits[torch.cumsum(probs, dim=-1) - probs > top_p] = float("-inf")
        probs = F.softmax(sorted_logits, dim=-1)
        next_token_id = sorted_indices[torch.multinomial(probs, 1)]

        return next_token_id

    # Fall back to greedy decoding: take the single highest-logit token
    return logits.argmax(dim=-1).squeeze()
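

# Hedged usage sketch: the random logits below are a stand-in for a real
# model's last-position logits (e.g. outputs.logits[0, -1, :] from a causal
# LM); the vocab size and sampling settings are illustrative assumptions,
# not part of this module.
if __name__ == "__main__":
    dummy_logits = torch.randn(32000)  # hypothetical vocab size
    print(sample_token(dummy_logits, top_k=50, temperature=0.8).item())  # top-k
    print(sample_token(dummy_logits, top_p=0.9).item())  # nucleus
    print(sample_token(dummy_logits).item())  # greedy argmax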