import math
import time
from typing import Any, Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class LogitsPredictor:
    def __init__(self):
        self.tokenizer = None
        self.model = None

    def setup(self, model_path="./"):
        """Load the model into memory to make running multiple predictions efficient."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        # Separator used to join prefix and target text: the first token id of
        # "\n" that is not BOS (some tokenizers prepend BOS when encoding).
        self.separator = next(
            t for t in self.tokenizer.encode("\n")
            if t != self.tokenizer.bos_token_id)

    def quantiles(self, sorted_probs, ranks):
        """Probability mass ranked strictly above each actual token.

        E.g. if a row of sorted_probs is [0.6, 0.3, 0.1] and the actual
        token has rank 1, its quantile is 0.6.
        """
        quantiles = []
        for i, probs in enumerate(sorted_probs):
            q = torch.sum(probs[:ranks[i]]).item()
            quantiles.append(q)
        return quantiles

    def ranks(self, sorted_indices, token_ids):
        """Rank of each actual token in the sorted distribution (0 = most likely)."""
        matches = sorted_indices == token_ids.unsqueeze(1)
        # argmax over the int mask returns the position of the first True per row.
        return torch.argmax(matches.int(), dim=1)

    def score_tokens(self, logits, token_ids, token_indices, topk):
        # Note: token_indices is currently unused.
        probs = torch.nn.functional.softmax(logits, dim=1)
        # Surprisal in bits, computed from log_softmax so that underflowing
        # probabilities cannot produce log(0).
        surprisals = -torch.nn.functional.log_softmax(logits, dim=1) / math.log(2)
        # Shannon entropy of each position's next-token distribution, and its
        # varentropy (the variance of surprisal under that distribution).
        positional_entropies = torch.sum(probs * surprisals, dim=1).unsqueeze(1)
        positional_varentropies = torch.sum(probs * (positional_entropies - surprisals) ** 2, dim=1)
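        # Worked example (illustrative, not from the source): for a distribution
        # p = [0.5, 0.25, 0.25] the surprisals are [1, 2, 2] bits, the entropy is
        # 0.5*1 + 0.25*2 + 0.25*2 = 1.5 bits, and the varentropy is
        # 0.5*(1 - 1.5)**2 + 2 * 0.25*(2 - 1.5)**2 = 0.25 bits^2.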
        tokens = [self.tokenizer.decode([idx]) for idx in token_ids]

        # Rank, quantile, and top-k alternatives are only computed on request.
        if topk != -1:
            sorted_probs, sorted_indices = torch.sort(probs, 1, descending=True)
            token_ranks = self.ranks(sorted_indices, token_ids)
            token_quantiles = self.quantiles(sorted_probs, token_ranks)
            topk_tensors = torch.topk(probs, topk, dim=1)
            topk_tokens = [[self.tokenizer.decode(elem.item()) for elem in row]
                           for row in topk_tensors.indices]

        # Probability and raw logit assigned to the token that actually occurred.
        probs_of_actual_tokens = torch.gather(probs, 1, token_ids.unsqueeze(1))
        logits_of_actual_tokens = torch.gather(logits, 1, token_ids.unsqueeze(1))

        return [{"token": tokens[i],
                 "probability": probs_of_actual_tokens[i].item(),
                 "logit": logits_of_actual_tokens[i].item(),
                 "positional_entropy": positional_entropies[i].item(),
                 "positional_varentropy": positional_varentropies[i].item(),
                 **({"token_rank": token_ranks[i].item(),
                     "token_quantile": token_quantiles[i],
                     "topk": list(zip(topk_tokens[i], topk_tensors.values[i].tolist()))} if topk != -1 else {})
                 } for i in range(len(token_ids))]

    def predict(self, trg_text: str, prefix_text: str, context_length: int,
                stride: int, topk: int, perf_metadata: bool) -> Dict[str, Any]:
        prediction: List[Dict[str, Any]] = []
        pred_start = time.time()
        if prefix_text:
            prefix_tokens = self.tokenizer.encode(prefix_text.strip())
            trg_tokens = self.tokenizer.encode(trg_text.strip())
            # Join prefix and target with the newline separator rather than a BOS token.
            if trg_tokens[0] == self.tokenizer.bos_token_id:
                trg_tokens[0] = self.separator
            else:
                trg_tokens = [self.separator] + trg_tokens
            input_tokens = torch.tensor(prefix_tokens + trg_tokens).unsqueeze(0)
            prefix_len = len(prefix_tokens)
        else:
            input_tokens = self.tokenizer(trg_text.strip(), return_tensors="pt")["input_ids"]
            prefix_len = 0

        tokenizing_done = time.time()
        prev_end_index = prefix_len
        total_len = len(input_tokens[0])
        timing_data = []
        # Slide a window of up to context_length tokens over the input, advancing
        # stride tokens per step; only the not-yet-scored tail of each window is
        # scored, so every target token is scored exactly once.
        for start_index in range(0, total_len, stride):
            end_index = min(start_index + context_length, total_len)
            print(f"StartIndex: {start_index}\nEndIndex: {end_index}")
            tokens = input_tokens[:, start_index:end_index]
            tokens_len = end_index - start_index

            model_start = time.time()
            with torch.no_grad():
                output = self.model(input_ids=tokens)
            model_done = time.time()
            # The logit at position i predicts the token at position i + 1, so the
            # final window drops its last logit (it would predict past the text).
            logits_start_index = prev_end_index - start_index
            logits_end_index = tokens_len - 1 if end_index == total_len else tokens_len
            trg_logits = output.logits[0, logits_start_index:logits_end_index, :]

            tokens_end_index = end_index if end_index == total_len else end_index + 1
            real_tokens = input_tokens[0, prev_end_index + 1:tokens_end_index]
            real_token_indices = list(range(prefix_len, total_len))[prev_end_index + 1 - prefix_len:tokens_end_index - prefix_len]

            scoring_start = time.time()
            preds = self.score_tokens(trg_logits, real_tokens, real_token_indices, topk)
            scoring_done = time.time()
            prediction.extend(preds)
            time_data = {
                "model_time": model_done - model_start,
                "scoring_time": scoring_done - scoring_start,
                "tokens_len": tokens_len}
            print(time_data)
            timing_data.append(time_data)
            prev_end_index = end_index
            if end_index == total_len:
                break
        pred_done = time.time()
        if perf_metadata:
            return {"tokens": prediction,
                    "perf_metadata": {"total_time": pred_done - pred_start,
                                      "strides": timing_data}}
        return {"tokens": prediction}
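

# A minimal usage sketch (not part of the original module). It assumes a local
# Hugging Face model directory at "./" containing both tokenizer and weights;
# the example strings and window sizes below are illustrative choices only.
if __name__ == "__main__":
    predictor = LogitsPredictor()
    predictor.setup(model_path="./")
    result = predictor.predict(trg_text="The quick brown fox jumps over the lazy dog.",
                               prefix_text="A pangram:",
                               context_length=512,
                               stride=256,
                               topk=5,
                               perf_metadata=True)
    for tok in result["tokens"]:
        print(tok["token"], tok["probability"], tok["positional_entropy"])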