import time
from typing import Any, Dict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class LogitsPredictor:
    def __init__(self):
        self.tokenizer = None
        self.model = None

    def setup(self, model_path="./"):
        """Load the model into memory to make running multiple predictions efficient."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        # First token id of "\n" that is not BOS; used to join prefix and target text.
        self.separator = list(
            filter(lambda x: x != self.tokenizer.bos_token_id, self.tokenizer.encode("\n"))
        )[0]

    def quantiles(self, sorted_probs, ranks):
        """Probability mass of all tokens ranked strictly above each actual token."""
        quantiles = []
        for i, probs in enumerate(sorted_probs):
            q = torch.sum(probs[: ranks[i]]).item()
            quantiles.append(q)
        return quantiles

    def ranks(self, sorted_indices, token_ids):
        """Rank of each actual token in the descending-probability ordering (0 = most probable)."""
        matches = sorted_indices == token_ids.unsqueeze(1)
        return torch.argmax(matches.int(), dim=1)

    # logits batch dimension already removed, logits.size() == (seq_len, vocab_size)
    def score_tokens(self, logits, token_ids, token_indices, topk):
        # token_indices is currently unused; kept for alignment bookkeeping.
        probs = torch.nn.functional.softmax(logits, dim=1)
        surprisals = -torch.log2(probs)
        # Entropy of each position's distribution, and the variance of surprisal around it.
        positional_entropies = torch.sum(probs * surprisals, dim=1).unsqueeze(1)
        positional_varentropies = torch.sum(probs * (positional_entropies - surprisals) ** 2, dim=1)
        tokens = [self.tokenizer.decode([idx]) for idx in token_ids]
        if topk != -1:
            sorted_probs, sorted_indices = torch.sort(probs, 1, descending=True)
            token_ranks = self.ranks(sorted_indices, token_ids)
            token_quantiles = self.quantiles(sorted_probs, token_ranks)
            topk_tensors = torch.topk(probs, topk, dim=1)
            topk_tokens = [
                [self.tokenizer.decode(elem.item()) for elem in row]
                for row in topk_tensors.indices
            ]
        probs_of_actual_tokens = torch.gather(probs, 1, token_ids.unsqueeze(1))
        logits_of_actual_tokens = torch.gather(logits, 1, token_ids.unsqueeze(1))
        return [
            {
                "token": tokens[i],
                "probability": probs_of_actual_tokens[i].item(),
                "logit": logits_of_actual_tokens[i].item(),
                "positional_entropy": positional_entropies[i].item(),
                "positional_varentropy": positional_varentropies[i].item(),
                **(
                    {
                        "token_rank": token_ranks[i].item(),
                        "token_quantile": token_quantiles[i],
                        "topk": list(zip(topk_tokens[i], topk_tensors.values[i].tolist())),
                    }
                    if topk != -1
                    else {}
                ),
            }
            for i in range(len(token_ids))
        ]

    def predict(
        self,
        trg_text: str,
        prefix_text: str,
        context_length: int,
        stride: int,
        topk: int,
        perf_metadata: bool,
    ) -> Dict[str, Any]:
        """Score every target token with a sliding window over the concatenated input."""
        prediction = []
        pred_start = time.time()
        if prefix_text:
            prefix_tokens = self.tokenizer.encode(prefix_text.strip())
            trg_tokens = self.tokenizer.encode(trg_text.strip())
            # Join prefix and target with the newline separator instead of a mid-sequence BOS.
            if trg_tokens[0] == self.tokenizer.bos_token_id:
                trg_tokens[0] = self.separator
            else:
                trg_tokens = [self.separator] + trg_tokens
            input_tokens = torch.tensor(prefix_tokens + trg_tokens).unsqueeze(0)
            prefix_len = len(prefix_tokens)
        else:
            # tokenizer.__call__() vs tokenizer.encode() only matters for alignment functions
            # (which are mostly broken) and attention masks (unused here, but needed for batching).
            input_tokens = self.tokenizer(trg_text.strip(), return_tensors="pt")["input_ids"]
            prefix_len = 0
        tokenizing_done = time.time()  # not yet reported in perf_metadata
        prev_end_index = prefix_len
        total_len = len(input_tokens[0])
        timing_data = []
        for start_index in range(0, total_len, stride):
            end_index = min(start_index + context_length, total_len)
            print(f"StartIndex: {start_index}\nEndIndex: {end_index}")
            tokens = input_tokens[:, start_index:end_index]
            tokens_len = end_index - start_index
            model_start = time.time()
            with torch.no_grad():
                output = self.model(input_ids=tokens)
            model_done = time.time()
            # Score only the positions not already covered by the previous window.
            logits_start_index = prev_end_index - start_index
            logits_end_index = -1 if end_index == total_len else tokens_len
            trg_logits = output.logits[0, logits_start_index:logits_end_index, :]
            # The logit at window position i predicts the token at absolute position
            # start_index + i + 1, so the scored tokens are shifted one to the right.
            tokens_end_index = end_index if end_index == total_len else end_index + 1
            real_tokens = input_tokens[0, prev_end_index + 1:tokens_end_index]
            real_token_indices = list(range(prefix_len, total_len))[
                prev_end_index + 1 - prefix_len:tokens_end_index - prefix_len
            ]
            scoring_start = time.time()
            preds = self.score_tokens(trg_logits, real_tokens, real_token_indices, topk)
            scoring_done = time.time()
            prediction.extend(preds)
            time_data = {
                "model_time": model_done - model_start,
                "scoring_time": scoring_done - scoring_start,
                "tokens_len": tokens_len,
            }
            print(time_data)
            timing_data.append(time_data)
            prev_end_index = end_index
            if end_index == total_len:
                break
        pred_done = time.time()
        res = (
            {
                "tokens": prediction,
                "perf_metadata": {"total_time": pred_done - pred_start, "strides": timing_data},
            }
            if perf_metadata
            else {"tokens": prediction}
        )
        return res
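

# Usage sketch: the model path, texts, and window parameters below are illustrative
# placeholders, not values prescribed by the class. With a 512-token window and a
# 256-token stride, every token is scored exactly once, and each token keeps at
# least 256 tokens of left context once the text outgrows a single window.
if __name__ == "__main__":
    predictor = LogitsPredictor()
    predictor.setup(model_path="./")  # assumes a local Hugging Face model directory
    result = predictor.predict(
        trg_text="The quick brown fox jumps over the lazy dog.",
        prefix_text="A pangram uses every letter of the alphabet.",
        context_length=512,
        stride=256,
        topk=5,
        perf_metadata=True,
    )
    # topk != -1, so rank/quantile/topk fields are present in each entry.
    for token_info in result["tokens"]:
        print(token_info["token"], token_info["probability"], token_info["token_rank"])
    print(result["perf_metadata"]["total_time"])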