# Page header captured with this file: Spaces — Running on Zero.
# File size: 687 bytes; revision 24628d9.
import torch
import torch.nn.functional as F
def get_last_attn(attn_map):
    """Keep only the last query row of every layer's attention map.

    Each ``layer`` is sliced as ``layer[:, :, -1, :]`` and then given back a
    size-1 axis at position 2, so a layer indexed as ``(d0, d1, q, k)``
    becomes ``(d0, d1, 1, k)``. (Presumably ``(batch, heads, query, key)`` —
    confirm against the caller.)

    Args:
        attn_map: Sequence of per-layer tensors with at least 4 dimensions.
            A ``list`` is updated in place, so any other reference to it sees
            the sliced tensors. A ``tuple`` cannot be assigned into; it is
            left untouched and its elements are copied into a new list.

    Returns:
        The updated sequence: the same ``list`` object when a list was given,
        otherwise a new ``list`` built from the tuple.
    """
    if isinstance(attn_map, tuple):
        # Tuples do not support item assignment; work on a list copy instead.
        attn_map = list(attn_map)
    for i, layer in enumerate(attn_map):
        # unsqueeze(2) restores the query axis removed by indexing with -1.
        attn_map[i] = layer[:, :, -1, :].unsqueeze(2)
    return attn_map
def sample_token(logits, top_k=None, top_p=None, temperature=1.0):
    """Pick the next token id from ``logits``.

    With ``top_k`` set, the logits are divided by ``temperature`` and
    restricted to the ``top_k`` largest entries along the last dimension.
    Those entries are turned into a probability distribution with softmax,
    and one id is drawn from it with ``torch.multinomial``. Without
    ``top_k``, the arg-max over the last dimension is returned (greedy
    decoding).

    Args:
        logits: Tensor of unnormalised scores whose last dimension indexes
            the vocabulary. The top-k path supports 1-D ``(vocab,)`` and
            2-D ``(batch, vocab)`` inputs, because ``torch.multinomial``
            only accepts 1-D or 2-D probabilities.
        top_k: Number of highest-scoring candidates to sample from. It is
            clamped to the vocabulary size. ``None`` selects greedy
            decoding.
        top_p: NOTE(review): accepted for API compatibility but currently
            ignored. No nucleus (top-p) filtering is applied.
        temperature: Divisor applied to ``logits`` before sampling; assumed
            to be positive. ``0`` is treated as greedy decoding, because
            dividing by zero would produce inf/NaN scores.

    Returns:
        Top-k path: a tensor of shape ``logits.shape[:-1] + (1,)`` holding
        the sampled vocabulary index. That is ``(1,)`` for 1-D input and
        ``(batch, 1)`` for 2-D input.
        Greedy path: ``logits.argmax(dim=-1).squeeze()``. That is a 0-d
        tensor for 1-D input or a single-row batch.
    """
    if temperature == 0:
        # T -> 0 limit of sampling is greedy; also avoids inf/NaN from x / 0.
        return logits.argmax(dim=-1).squeeze()
    logits = logits / temperature
    if top_k is not None:
        top_k = min(top_k, logits.size(-1))  # topk raises if k > vocab size
        values, indices = torch.topk(logits, top_k)
        probs = F.softmax(values, dim=-1)
        # multinomial yields positions inside the top-k slice. Gathering along
        # the last dim maps them back to vocabulary ids. Plain
        # ``indices[choice]`` would index the batch dim for 2-D logits.
        choice = torch.multinomial(probs, 1)
        return torch.gather(indices, -1, choice)
    return logits.argmax(dim=-1).squeeze()
|