import torch
import torch.nn.functional as F

def get_last_attn(attn_map):
    # Keep only the attention row for the last query position in each layer,
    # preserving a (batch, heads, 1, seq_len) shape via unsqueeze
    for i, layer in enumerate(attn_map):
        attn_map[i] = layer[:, :, -1, :].unsqueeze(2)
    return attn_map

def sample_token(logits, top_k=None, top_p=None, temperature=1.0):
    # Scale logits by temperature (lower = sharper, higher = flatter distribution)
    logits = logits / temperature
    # Top-k sampling: restrict to the k most likely tokens, then sample from them
    if top_k is not None:
        top_k = min(top_k, logits.size(-1))  # ensure top_k <= vocab size
        values, indices = torch.topk(logits, top_k)
        probs = F.softmax(values, dim=-1)
        next_token_id = indices[torch.multinomial(probs, 1)]
        return next_token_id
    # Note: top_p is accepted but not applied in this snippet
    # Fallback: greedy decoding (argmax over the vocabulary)
    return logits.argmax(dim=-1).squeeze()
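
# The `top_p` argument above is declared but never used. Below is a minimal
# sketch of how nucleus (top-p) sampling could be implemented for the same
# 1-D `logits` tensor of shape (vocab_size,); the function name
# `sample_token_top_p` and its defaults are assumptions for illustration,
# not part of the original snippet.
def sample_token_top_p(logits, top_p=0.9, temperature=1.0):
    # Scale logits by temperature, then convert to probabilities
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)
    # Sort probabilities in descending order and take the cumulative sum
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Mask out tokens outside the nucleus: keep the smallest set of tokens
    # whose cumulative probability exceeds top_p (the top token is always kept)
    outside_nucleus = cumulative - sorted_probs > top_p
    sorted_probs = sorted_probs.masked_fill(outside_nucleus, 0.0)
    # Renormalize and sample a token id from the remaining mass
    sorted_probs = sorted_probs / sorted_probs.sum()
    next_token_id = sorted_indices[torch.multinomial(sorted_probs, 1)]
    return next_token_id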