import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Any
import time
class LogitsPredictor:
    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.separator = None
    def setup(self, model_path="./"):
        """Load the model into memory to make running multiple predictions efficient."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        # Newline token id used to join prefix and target text in predict().
        self.separator = list(filter(lambda x: x != self.tokenizer.bos_token_id, self.tokenizer.encode("\n")))[0]
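    # Illustrative: if the tokenizer encodes "\n" as [bos_id, nl_id], the separator
    # above resolves to nl_id, the first id that is not BOS (ids are hypothetical).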
    def quantiles(self, sorted_probs, ranks):
        """Cumulative probability mass of all tokens ranked strictly above each actual token."""
        quantiles = []
        for i, probs in enumerate(sorted_probs):
            q = torch.sum(probs[:ranks[i]]).item()
            quantiles.append(q)
        return quantiles
    def ranks(self, sorted_indices, token_ids):
        """Zero-based rank of each actual token in the descending probability order."""
        matches = sorted_indices == token_ids.unsqueeze(1)
        return torch.argmax(matches.int(), dim=1)
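    # Illustrative example for the two helpers above: if one row sorts (descending) to
    # sorted_probs = [0.5, 0.3, 0.2] with sorted_indices = [7, 4, 1] and the actual
    # token id is 1, its rank is 2 and its quantile is 0.5 + 0.3 = 0.8, i.e. the
    # probability mass of all tokens the model preferred to it.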
    # logits arrive with the batch dimension already removed: logits.size() == (seq_len, vocab_size)
    def score_tokens(self, logits, token_ids, token_indices, topk) -> List[Dict[str, Any]]:
        # NOTE: token_indices is currently unused.
        probs = torch.nn.functional.softmax(logits, dim=1)
        # Surprisal in bits, -log2 p; positional entropy is its expectation per position.
        surprisals = -torch.log2(probs)
        positional_entropies = torch.sum(probs * surprisals, dim=1).unsqueeze(1)
        # Varentropy: the variance of the surprisal around the positional entropy.
        positional_varentropies = torch.sum(probs * (positional_entropies - surprisals) ** 2, dim=1)
        tokens = [self.tokenizer.decode([idx]) for idx in token_ids]
        if topk != -1:
            sorted_probs, sorted_indices = torch.sort(probs, 1, descending=True)
            token_ranks = self.ranks(sorted_indices, token_ids)
            token_quantiles = self.quantiles(sorted_probs, token_ranks)
            topk_tensors = torch.topk(probs, topk, dim=1)
            topk_tokens = [[self.tokenizer.decode(elem.item()) for elem in row] for row in topk_tensors.indices]
        probs_of_actual_tokens = torch.gather(probs, 1, token_ids.unsqueeze(1))
        logits_of_actual_tokens = torch.gather(logits, 1, token_ids.unsqueeze(1))
        return [{"token": tokens[i],
                 "probability": probs_of_actual_tokens[i].item(),
                 "logit": logits_of_actual_tokens[i].item(),
                 "positional_entropy": positional_entropies[i].item(),
                 "positional_varentropy": positional_varentropies[i].item(),
                 **({"token_rank": token_ranks[i].item(),
                     "token_quantile": token_quantiles[i],
                     "topk": list(zip(topk_tokens[i], topk_tensors.values[i].tolist()))} if topk != -1 else {})
                 } for i in range(len(token_ids))]
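    # Sanity check for the entropy terms above: at a position with p = [0.5, 0.5]
    # the surprisals are [1.0, 1.0] bits, so positional_entropy = 1.0 and
    # positional_varentropy = 0.0 (the surprisal never deviates from its mean).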
    def predict(self, trg_text: str, prefix_text: str, context_length: int,
                stride: int, topk: int, perf_metadata: bool) -> Dict[str, Any]:
        prediction = []
        pred_start = time.time()
        if prefix_text:
            prefix_tokens = self.tokenizer.encode(prefix_text.strip())
            trg_tokens = self.tokenizer.encode(trg_text.strip())
            # Join prefix and target with the newline separator instead of a second BOS.
            if trg_tokens[0] == self.tokenizer.bos_token_id:
                trg_tokens[0] = self.separator
            else:
                trg_tokens = [self.separator] + trg_tokens
            input_tokens = torch.tensor(prefix_tokens + trg_tokens).unsqueeze(0)
            prefix_len = len(prefix_tokens)
        else:
            # tokenizer.__call__() vs tokenizer.encode() only matters for alignment functions
            # (which are mostly broken) and attention masks (not used here, but needed for batching)
            input_tokens = self.tokenizer(trg_text.strip(), return_tensors="pt")["input_ids"]
            prefix_len = 0
        tokenizing_done = time.time()  # tokenization timing point (currently unreported)
        prev_end_index = prefix_len
        seq_len = len(input_tokens[0])
        timing_data = []
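        # Sliding-window walkthrough (illustrative): with context_length=4, stride=2
        # and 6 input tokens, the windows are [0:4] and [2:6]; the first window
        # scores tokens 1-4 and the second scores only token 5, because
        # logits_start_index skips positions already covered by an earlier stride.
        # Token 0 is never scored, since no logit predicts it.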
        for start_index in range(0, seq_len, stride):
            end_index = min(start_index + context_length, seq_len)
            print(f"StartIndex: {start_index}\nEndIndex: {end_index}")
            tokens = input_tokens[:, start_index:end_index]
            tokens_len = end_index - start_index
            model_start = time.time()
            with torch.no_grad():
                output = self.model(input_ids=tokens)
            model_done = time.time()
            # The logit at position i predicts the token at position i + 1, so score only
            # positions not covered by a previous stride; the final logit of the last
            # window predicts past the end of the text and is dropped.
            logits_start_index = prev_end_index - start_index
            logits_end_index = -1 if end_index == seq_len else tokens_len
            trg_logits = output.logits[0, logits_start_index:logits_end_index, :]
            tokens_end_index = end_index if end_index == seq_len else end_index + 1
            real_tokens = input_tokens[0, prev_end_index+1:tokens_end_index]
            real_token_indices = list(range(prefix_len, seq_len))[prev_end_index+1-prefix_len:tokens_end_index-prefix_len]
            scoring_start = time.time()
            preds = self.score_tokens(trg_logits, real_tokens, real_token_indices, topk)
            scoring_done = time.time()
            prediction.extend(preds)
            time_data = {
                "model_time": model_done - model_start,
                "scoring_time": scoring_done - scoring_start,
                "tokens_len": tokens_len}
            print(time_data)
            timing_data.append(time_data)
            prev_end_index = end_index
            if end_index == seq_len:
                break
        pred_done = time.time()
        res = {"tokens": prediction,
               "perf_metadata": {"total_time": pred_done - pred_start, "strides": timing_data}} if perf_metadata else {"tokens": prediction}
        return res
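# Usage sketch: assumes a local Hugging Face model directory at ./model;
# the argument values below are illustrative, not defaults of this class.
if __name__ == "__main__":
    predictor = LogitsPredictor()
    predictor.setup(model_path="./model")
    result = predictor.predict(trg_text="The quick brown fox jumps over the lazy dog.",
                               prefix_text="",
                               context_length=512,
                               stride=256,
                               topk=5,
                               perf_metadata=True)
    for tok in result["tokens"]:
        print(tok["token"], tok["probability"], tok["positional_entropy"])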