import os
import torch
from .model import Model
from .utils import sample_token, get_last_attn
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn.functional as F
device = 'cuda' if torch.cuda.is_available() else 'cpu'
token = os.getenv("HF_TOKEN")
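

# AttentionModel wraps a Hugging Face causal LM so that text is generated one
# token at a time while the attention maps, sampled tokens, and their
# probabilities are collected at every decoding step.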
class AttentionModel(Model):

    def __init__(self, config):
        super().__init__(config)
        self.name = config["model_info"]["name"]
        self.max_output_tokens = int(config["params"]["max_output_tokens"])

        model_id = config["model_info"]["model_id"]
        self.tokenizer = AutoTokenizer.from_pretrained(model_id,
                                                       use_auth_token=token)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map=device,
            attn_implementation="eager",
            use_auth_token=token
        ).eval()
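
        # "all" selects every (layer, head) pair exposed by the model's
        # attention maps; otherwise the config provides an explicit list.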
        if config["params"]["important_heads"] == "all":
            attn_size = self.get_map_dim()
            self.important_heads = [[i, j] for i in range(
                attn_size[0]) for j in range(attn_size[1])]
        else:
            self.important_heads = config["params"]["important_heads"]

        self.top_k = 50
        self.top_p = None
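
    # Probe the model with a tiny dummy prompt to learn the attention-map
    # dimensions: (number of layers, number of heads).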
    def get_map_dim(self):
        _, _, attention_maps, _, _, _ = self.inference("print hi", "")
        attention_map = attention_maps[0]
        return len(attention_map), attention_map[0].shape[1]
    # def query(self, msg, return_type="normal", max_output_tokens=None):
    #     text_split = msg.split('\nText: ')
    #     instruction, data = text_split[0], text_split[1]
    #     response, output_tokens, attention_maps, tokens, input_range, generated_probs = self.inference(
    #         instruction, data, max_output_tokens=max_output_tokens)
    #     if return_type == "attention":
    #         return response, output_tokens, attention_maps, tokens, input_range, generated_probs
    #     else:
    #         return response
    def inference(self, instruction, data, max_output_tokens=None):
        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": "\nText: " + data}
        ]

        # Build the chat-template prompt as plain text; tokenization happens below.
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        instruction_len = len(self.tokenizer.encode(instruction))
        data_len = len(self.tokenizer.encode(data))

        model_inputs = self.tokenizer(
            [text], return_tensors="pt").to(self.model.device)
        input_tokens = self.tokenizer.convert_ids_to_tokens(
            model_inputs['input_ids'][0])
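
        # Model-specific token offsets locating the instruction span and the
        # data span inside the chat-templated prompt; each entry is a
        # (start, end) range over the input tokens, with negative indices
        # counting from the end.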
if "qwen" in self.name:
data_range = ((3, 3+instruction_len), (-5-data_len, -5))
elif "phi3" in self.name:
data_range = ((1, 1+instruction_len), (-2-data_len, -2))
elif "llama3-8b" in self.name:
data_range = ((5, 5+instruction_len), (-5-data_len, -5))
elif "mistral-7b" in self.name:
data_range = ((3, 3+instruction_len), (-1-data_len, -1))
elif "granite3-8b" in self.name:
data_range = ((3, 3+instruction_len), (-5-data_len, -5))
else:
raise NotImplementedError
        generated_tokens = []
        generated_probs = []
        input_ids = model_inputs.input_ids
        attention_mask = model_inputs.attention_mask
        attention_maps = []

        if max_output_tokens is not None:
            n_tokens = max_output_tokens
        else:
            n_tokens = self.max_output_tokens
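
        # Autoregressive decoding: sample one token at a time, record its
        # probability, and collect the per-step attention maps (moved to CPU
        # in half precision to limit GPU memory use).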
        with torch.no_grad():
            for i in range(n_tokens):
                output = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_attentions=True
                )

                logits = output.logits[:, -1, :]
                probs = F.softmax(logits, dim=-1)
                # next_token_id = logits.argmax(dim=-1).squeeze()
                next_token_id = sample_token(
                    logits[0], top_k=self.top_k, top_p=self.top_p, temperature=1.0)[0]

                generated_probs.append(probs[0, next_token_id.item()].item())
                generated_tokens.append(next_token_id.item())

                if next_token_id.item() == self.tokenizer.eos_token_id:
                    break

                input_ids = torch.cat(
                    (input_ids, next_token_id.unsqueeze(0).unsqueeze(0)), dim=-1)
                attention_mask = torch.cat(
                    (attention_mask, torch.tensor([[1]], device=input_ids.device)), dim=-1)

                attention_map = [attention.detach().cpu().half()
                                 for attention in output['attentions']]
                attention_map = [torch.nan_to_num(
                    attention, nan=0.0) for attention in attention_map]
                attention_map = get_last_attn(attention_map)
                attention_maps.append(attention_map)
        output_tokens = [self.tokenizer.decode(
            tok, skip_special_tokens=True) for tok in generated_tokens]
        generated_text = self.tokenizer.decode(
            generated_tokens, skip_special_tokens=True)

        return generated_text, output_tokens, attention_maps, input_tokens, data_range, generated_probs
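

# Example usage (sketch only; the config keys mirror those read in __init__,
# while the model name and model_id below are illustrative assumptions):
#
#     config = {
#         "model_info": {"name": "llama3-8b",
#                        "model_id": "meta-llama/Meta-Llama-3-8B-Instruct"},
#         "params": {"max_output_tokens": 128, "important_heads": "all"},
#     }
#     model = AttentionModel(config)
#     text, out_tokens, attn_maps, in_tokens, data_range, probs = model.inference(
#         "Summarize the following text.", "The quick brown fox jumps over the lazy dog.")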