# juanma1907's picture
# Upload folder using huggingface_hub
# 462dacf
from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator
import json
import math
import os
import sys
import torch
import torch.nn.functional as F
'''
Passing in model, cache, tokenizer is a total hack because we don't want to have to reinitialize (or move all the globals into a shared state model)
'''
class Perplexity:
    """Perplexity benchmark over a tokenized dataset.

    The dataset is loaded with `load()` into `self.dataset_chunks`, a list of
    2-D token-id tensors indexed as `[batch, position]`, and evaluated with
    `test()`.

    The model, cache and tokenizer are passed in by the caller so that they do
    not have to be re-initialized for the benchmark.
    """

    def __init__(self, method="default", model = None, cache = None, tokenizer = None):
        """Store the shared model/cache/tokenizer and reset the cache.

        Args:
            method: Accepted for interface compatibility; not used by this class.
            model: Object exposing `forward(input_ids, cache, last_id_only, lora=...)`.
            cache: Existing cache to reuse. When None, a new `ExLlamaCache` is
                created for `model`.
            tokenizer: Object exposing `encode(text)`.
        """

        # This needs to be loaded by calling .load()
        self.dataset_chunks = []

        self.model = model
        self.cache = cache
        self.tokenizer = tokenizer

        self._begin()

    def _begin(self):
        """Start a fresh sequence: create the cache if missing, else rewind it."""

        if self.cache is None:
            self.cache = ExLlamaCache(self.model)
        else:
            self.cache.current_seq_len = 0

    def _next_logits(self, input_ids, apply_lora, last_id_only = True):
        """Run one forward pass over `input_ids` using the shared cache.

        Args:
            input_ids: Token ids to feed to the model.
            apply_lora: Passed to the model as its `lora` keyword argument.
            last_id_only: Forwarded to the model unchanged.

        Returns:
            Whatever `self.model.forward` returns (the logits).
        """

        # NOTE: the whole sequence is forwarded in a single call; a variant
        # that split the input into 2048-token slices was disabled here.
        return self.model.forward(input_ids, self.cache, last_id_only, lora = apply_lora)

    def _tokenize(self, text):
        """Encode `text` with the shared tokenizer."""

        return self.tokenizer.encode(text)

    # Load raw dataset from a text file and tokenize into chunks. Each chunk can optionally be truncated to allow for
    # evaluating the same data at different sequence lengths

    def load(self, dataset_path, chunk_size, chunk_truncate = None, overlap = 0, minlength = 0, json_key = "text"):
        """Tokenize a dataset file and append the chunks to `self.dataset_chunks`.

        Args:
            dataset_path: Path to the dataset. Files ending in `.jsonl` or
                `.json` are read as one JSON object per line; anything else is
                read as raw UTF-8 text.
            chunk_size: Maximum number of tokens per chunk.
            chunk_truncate: If not None, each chunk is further cut to this many
                tokens after chunking.
            overlap: Raw text only. Number of tokens shared between consecutive
                windows.
            minlength: JSON only. Examples whose text is not longer than this
                many characters are skipped.
            json_key: JSON only. Key holding the text of each example.
        """

        file_extension = os.path.splitext(dataset_path)[1]

        # JSON format: Returned chunks may be of variable length, with each chunk representing one list item

        if file_extension == '.jsonl' or file_extension == '.json':

            # Read as UTF-8 explicitly, like the raw text branch, instead of
            # depending on the platform's default encoding.
            with open(dataset_path, encoding = "utf-8") as f:
                for line in f:

                    # A blank line is not a JSON document; skip it rather than
                    # letting json.loads() raise.
                    if not line.strip(): continue

                    example = json.loads(line)[json_key]
                    if len(example) > minlength:
                        chunk = self._tokenize(example)
                        chunk = chunk[:, :chunk_size]
                        if chunk_truncate is not None: chunk = chunk[:, :chunk_truncate]
                        self.dataset_chunks.append(chunk)

        # Raw Text: Returned chunks are fixed length windows of the entire tokenized dataset

        else:

            with open(dataset_path, encoding = "utf-8") as f:
                text = f.read()

            tokens = self._tokenize(text)

            # overlap shouldn't be bigger than the context, also need at least one token for predicting last...
            if overlap >= chunk_size:
                overlap = chunk_size - 2

            # We can't use torch.chunk since it wants to split things into equal sized chunks. Instead, let's do our own chunking
            start = 0
            while start < tokens.size(1):
                chunk = tokens[:, start:start + chunk_size]
                start += chunk_size - overlap
                if chunk_truncate is not None: chunk = chunk[:, :chunk_truncate]
                self.dataset_chunks.append(chunk)

    def test(self, chunk_limit = sys.maxsize, lora = None, tag = "", ppl_token = False):
        """Evaluate perplexity over the loaded chunks and print the result.

        For every chunk, the tokens except the last are fed to the model and
        the log-probability assigned to each following token is accumulated.
        Perplexity is `exp(-mean(log_prob))` over all scored tokens.

        Args:
            chunk_limit: Maximum number of chunks to evaluate. None or 0 means
                no limit.
            lora: Passed through to the model's forward call.
            tag: Text inserted after "Perplexity" in the printed result line.
            ppl_token: If True, feed the model one token per forward call
                instead of the whole chunk at once.

        Returns:
            The computed perplexity as a float.

        Exits via `sys.exit` when the dataset is empty or contains no chunk
        with at least two tokens.
        """

        if not self.dataset_chunks:
            sys.exit(" xx ERROR: Empty dataset!")

        # The loop below treats a falsy limit as "no limit"; do the same here so
        # that None does not break min() and 0 does not report "0 chunks".
        total_chunks = len(self.dataset_chunks)
        planned_chunks = min(total_chunks, chunk_limit) if chunk_limit else total_chunks

        print(f" -- Testing {planned_chunks} chunks", end="")
        sys.stdout.flush()

        logprob_sum = 0.0
        logprob_count = 0

        chunk_count = 0

        for chunk in self.dataset_chunks:

            # A chunk needs at least two tokens to form one (input, target)
            # pair. Shorter chunks have nothing to score, so skip them instead
            # of running a forward pass on an empty sequence.
            if chunk.shape[-1] < 2:
                continue

            self._begin()

            # Position i of the input predicts position i + 1 of the chunk
            input_ids = chunk[:, :-1]
            target_ids = chunk[:, 1:]

            if ppl_token:
                # One forward call per token; the cache carries the context
                logits_s = []
                for i in range(input_ids.shape[-1]):
                    logits_t = self._next_logits(input_ids[:, i : i + 1], lora, last_id_only = False)
                    logits_s.append(logits_t)
                logits = torch.cat(logits_s, dim = 1)
            else:
                logits = self._next_logits(input_ids, lora, last_id_only = False)

            log_probs = F.log_softmax(logits, dim=-1)

            # Pick out the log-probability of each actual next token
            token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)

            logprob_sum += token_log_probs.sum().item()
            logprob_count += target_ids.numel()

            # Progress indicator: one dot per ten chunks
            if chunk_count % 10 == 0:
                print(".", end = "")
                sys.stdout.flush()

            chunk_count += 1
            if chunk_limit and chunk_count >= chunk_limit:
                break

        if logprob_count == 0:
            # Every chunk was too short to score; avoid dividing by zero below
            print("")
            sys.exit(" xx ERROR: No chunk with at least two tokens, nothing to evaluate!")

        mean_log_prob = logprob_sum / logprob_count
        perplexity = math.exp(-mean_log_prob)

        print("")
        print(f" ** Perplexity{tag}: {perplexity:.4f}")

        return perplexity
def add_args(parser):
    """Register the perplexity benchmark command-line options on `parser`.

    Args:
        parser: An argparse-style parser exposing `add_argument`.
    """

    register = parser.add_argument

    # Master switch; a bare "-ppl" selects the "default" method
    register(
        "-ppl", "--perplexity",
        nargs = '?',
        const = 'default',
        metavar = "METHOD",
        help = "Perplexity benchmark. Optionally specify method: gptq-for-llama, llama.cpp (not yet implemented)",
    )

    # Dataset location
    register(
        "-ppl_ds", "--perplexity_dataset",
        metavar = "DATAPATH",
        type = str,
        help = "Load dataset for perplexity (JSONL if .jsonl, otherwise parses it as raw text)",
    )

    # Chunking controls
    register(
        "-ppl_cn", "--perplexity_chunk_num",
        nargs = "?",
        type = int,
        help = "Number of chunks for perplexity benchmark",
        default = 100,
    )
    register(
        "-ppl_cs", "--perplexity_chunk_size",
        type = int,
        help = "Size of chunks for perplexity benchmark",
        default = 2048,
    )
    register(
        "-ppl_ct", "--perplexity_chunk_truncate",
        type = int,
        help = "Truncated size of chunks for perplexity benchmark",
        default = 2048,
    )
    register(
        "-ppl_co", "--perplexity_chunk_overlap",
        type = int,
        help = "Chunk overlap",
        default = 0,
    )
    register(
        "-ppl_cm", "--perplexity_chunk_min",
        type = int,
        help = "Minimum chunk length",
        default = 50,
    )

    # JSON dataset key and debug mode
    register(
        "-ppl_key", "--perplexity_json_key",
        type = str,
        help = "Key to extract from JSON dataset, default: 'text'",
        default = "text",
    )
    register(
        "-ppl_t", "--perplexity_token",
        action = "store_true",
        help = "Run perplexity test on individual tokens, for debug purposes (slow)",
    )
def post_parse(args):
    """Finalize the parsed perplexity options and print a summary of them.

    Does nothing when the perplexity benchmark was not requested. Otherwise
    applies the "gptq-for-llama" preset if that method was selected, fills in
    the default dataset path when none was given, and prints the effective
    settings.

    Args:
        args: Parsed argument namespace carrying the `perplexity*` attributes;
            it is modified in place.
    """

    if not args.perplexity: return

    # GPTQ-for-LLaMa equivalent

    if args.perplexity == "gptq-for-llama":
        args.perplexity_dataset = "datasets/wikitext2.txt"
        args.perplexity_chunk_num = 128
        args.perplexity_chunk_size = 2048
        args.perplexity_chunk_truncate = 2048
        args.perplexity_chunk_overlap = 0
        args.perplexity_chunk_min = 0

    # Default dataset for legacy method

    if args.perplexity_dataset is None: args.perplexity_dataset = "datasets/wikitext2_val_sample.jsonl"

    print(" -- Perplexity:")
    print(f" -- - Dataset: {args.perplexity_dataset}")
    print(f" -- - Chunks: {args.perplexity_chunk_num}")
    print(f" -- - Chunk size: {args.perplexity_chunk_size}" + (f" -> {args.perplexity_chunk_truncate}" if args.perplexity_chunk_truncate is not None else ""))
    print(f" -- - Chunk overlap: {args.perplexity_chunk_overlap}")
    print(f" -- - Min. chunk size: {args.perplexity_chunk_min}")
    print(f" -- - Key: {args.perplexity_json_key}")

    # The original string had the f prefix inside the quotes, so a stray "f"
    # was printed at the start of this line.
    if args.perplexity_token: print(" -- - Per-token mode")