johnsonhung906
add code for qwen
24628d9
raw
history blame
687 Bytes
import torch
import torch.nn.functional as F
def get_last_attn(attn_map):
    """Reduce each per-layer attention tensor to the row of its last query position.

    For every ``layer`` in ``attn_map`` this keeps ``layer[:, :, -1, :]`` and
    restores the sliced axis with ``unsqueeze(2)``. The result therefore keeps
    the same rank, and axis 2 has size 1.
    # NOTE(review): presumably each layer is (batch, heads, query_len, key_len)
    # as produced with output_attentions=True -- confirm against callers.

    Args:
        attn_map: A sequence of attention tensors, one per layer. A mutable
            sequence (e.g. a list) is updated in place. A tuple, which cannot
            be assigned to, is handled by building a new tuple instead.

    Returns:
        For a tuple input, a new tuple of the sliced tensors. Otherwise
        ``attn_map`` itself, after in-place replacement of its elements.
    """
    if isinstance(attn_map, tuple):
        # Tuples are immutable; item assignment would raise TypeError.
        return tuple(layer[:, :, -1, :].unsqueeze(2) for layer in attn_map)
    for i, layer in enumerate(attn_map):
        # Replacing elements (not resizing) while enumerating is safe.
        attn_map[i] = layer[:, :, -1, :].unsqueeze(2)
    return attn_map
def sample_token(logits, top_k=None, top_p=None, temperature=1.0):
    """Choose the next token id from ``logits``.

    Decoding is greedy (argmax) when neither ``top_k`` nor ``top_p`` is given,
    or when ``temperature == 0``. Otherwise one token is drawn with
    ``torch.multinomial`` from the temperature-scaled distribution, after
    top-k and/or top-p (nucleus) filtering.

    Args:
        logits: Scores over the vocabulary on the last axis, e.g. ``(vocab,)``
            or ``(batch, vocab)``.
        top_k: If set and > 0, sample only among the ``top_k`` highest-scoring
            tokens, clamped to the vocabulary size. ``None`` or ``0`` disables it.
        top_p: If set, keep the smallest prefix of highest-ranked tokens whose
            probability mass reaches ``top_p``. The top token is always kept.
            When combined with ``top_k``, it is applied to the top-k candidates.
        temperature: Divisor applied to ``logits`` before selection. ``0``
            means greedy decoding.

    Returns:
        Greedy: ``argmax(dim=-1).squeeze()`` of the scaled logits (a 0-d
        tensor for 1-D logits or a batch of one).
        Sampled: token ids of shape ``logits.shape[:-1] + (1,)``, e.g. ``(1,)``
        for 1-D logits and ``(batch, 1)`` for 2-D logits.

    Raises:
        ValueError: If ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be non-negative, got {temperature}")
    if top_k is not None and top_k <= 0:
        top_k = None  # 0 conventionally means "no top-k filtering"

    # Scale only for a positive temperature; dividing by 0 would yield inf/NaN.
    if temperature > 0:
        logits = logits / temperature
    if temperature == 0 or (top_k is None and top_p is None):
        return logits.argmax(dim=-1).squeeze()

    if top_k is not None:
        top_k = min(top_k, logits.size(-1))  # Ensure top_k <= vocab size
        # topk returns values sorted in descending order, which top-p relies on.
        values, indices = torch.topk(logits, top_k)
    else:
        values, indices = torch.sort(logits, dim=-1, descending=True)
    probs = F.softmax(values, dim=-1)

    if top_p is not None:
        # Drop a token once the mass of the tokens ranked strictly above it
        # already reaches top_p. Always keep the highest-ranked token.
        mass_before = probs.cumsum(dim=-1) - probs
        remove = mass_before >= top_p
        remove[..., 0] = False
        # multinomial accepts unnormalized non-negative weights.
        probs = probs.masked_fill(remove, 0.0)

    choice = torch.multinomial(probs, 1)
    # choice holds positions within the candidate axis (last axis). Gather
    # maps them back to vocabulary ids row-wise, which plain indexing did not
    # do for batched input.
    return torch.gather(indices, -1, choice)