johnsonhung906
add code for qwen
24628d9
raw
history blame
1.25 kB
import torch
import numpy as np
def process_attn(attention, rng, attn_func):
heatmap = np.zeros((len(attention), attention[0].shape[1]))
for i, attn_layer in enumerate(attention):
attn_layer = attn_layer.to(torch.float32).numpy()
if "sum" in attn_func:
last_token_attn_to_inst = np.sum(attn_layer[0, :, -1, rng[0][0]:rng[0][1]], axis=1)
attn = last_token_attn_to_inst
elif "max" in attn_func:
last_token_attn_to_inst = np.max(attn_layer[0, :, -1, rng[0][0]:rng[0][1]], axis=1)
attn = last_token_attn_to_inst
else: raise NotImplementedError
last_token_attn_to_inst_sum = np.sum(attn_layer[0, :, -1, rng[0][0]:rng[0][1]], axis=1)
last_token_attn_to_data_sum = np.sum(attn_layer[0, :, -1, rng[1][0]:rng[1][1]], axis=1)
if "normalize" in attn_func:
epsilon = 1e-8
heatmap[i, :] = attn / (last_token_attn_to_inst_sum + last_token_attn_to_data_sum + epsilon)
else:
heatmap[i, :] = attn
heatmap = np.nan_to_num(heatmap, nan=0.0)
return heatmap
def calc_attn_score(heatmap, heads):
score = np.mean([heatmap[l, h] for l, h in heads], axis=0)
return score