from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator

import json
import math
import os
import sys
import torch
import torch.nn.functional as F

'''
Passing in model, cache and tokenizer is a total hack, because we don't want to have to
reinitialize them (or move all the globals into a shared state module).
'''


class Perplexity:

    def __init__(self, method = "default", model = None, cache = None, tokenizer = None):

        self.dataset_chunks = []

        self.model = model
        self.cache = cache
        self.tokenizer = tokenizer

        self._begin()


    def _begin(self):

        if self.cache is None:
            self.cache = ExLlamaCache(self.model)
        else:
            self.cache.current_seq_len = 0


    def _next_logits(self, input_ids, apply_lora, last_id_only = True):

        # Forward pass, continuing from whatever is already in the cache for the current chunk
        return self.model.forward(input_ids, self.cache, last_id_only, lora = apply_lora)


    def _tokenize(self, text):

        return self.tokenizer.encode(text)


    def load(self, dataset_path, chunk_size, chunk_truncate = None, overlap = 0, minlength = 0, json_key = "text"):

        file_extension = os.path.splitext(dataset_path)[1]

        # JSON/JSONL dataset: one JSON object per line, each example tokenized into its own chunk
        if file_extension == '.jsonl' or file_extension == '.json':

            with open(dataset_path) as f:
                for line in f:
                    example = json.loads(line)[json_key]
                    if len(example) > minlength:
                        chunk = self._tokenize(example)
                        chunk = chunk[:, :chunk_size]
                        if chunk_truncate is not None: chunk = chunk[:, :chunk_truncate]
                        self.dataset_chunks.append(chunk)

        # Raw text dataset: tokenize the whole file, then slice it into (optionally overlapping) chunks
        else:

            with open(dataset_path, encoding="utf-8") as f:
                text = f.read()

            tokens = self._tokenize(text)

            # Overlap must be smaller than the chunk size, or the window would never advance
            if overlap >= chunk_size:
                overlap = chunk_size - 2

            start = 0
            while start < tokens.size(1):
                chunk = tokens[:, start:start + chunk_size]
                start += chunk_size - overlap
                if chunk_truncate is not None: chunk = chunk[:, :chunk_truncate]
                self.dataset_chunks.append(chunk)
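
    # Worked example of the raw-text chunking above (a sketch, not executed anywhere in this
    # module): with chunk_size = 2048 and overlap = 512 the window start advances by
    # 2048 - 512 = 1536 tokens, so consecutive chunks share 512 tokens:
    #
    #   chunk 0: tokens[:,    0 : 2048]
    #   chunk 1: tokens[:, 1536 : 3584]
    #   chunk 2: tokens[:, 3072 : 5120]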


    def test(self, chunk_limit = sys.maxsize, lora = None, tag = "", ppl_token = False):

        if not self.dataset_chunks:
            sys.exit(" xx ERROR: Empty dataset!")

        print(f" -- Testing {min(len(self.dataset_chunks), chunk_limit)} chunks", end="")
        sys.stdout.flush()

        logprob_sum = 0.0
        logprob_count = 0

        chunk_count = 0

        for chunk in self.dataset_chunks:

            self._begin()

            # Score every next-token prediction in the chunk
            input_ids = chunk[:, :-1]
            target_ids = chunk[:, 1:]

            if ppl_token:

                # Debug mode: feed the chunk one token at a time and concatenate the logits
                logits_s = []
                for i in range(input_ids.shape[-1]):
                    logits_t = self._next_logits(input_ids[:, i : i + 1], lora, last_id_only = False)
                    logits_s.append(logits_t)
                logits = torch.cat(logits_s, dim = 1)

            else:

                logits = self._next_logits(input_ids, lora, last_id_only = False)

            log_probs = F.log_softmax(logits, dim=-1)
            token_log_probs = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)

            logprob_sum += token_log_probs.sum().item()
            logprob_count += target_ids.numel()

            if chunk_count % 10 == 0:
                print(".", end = "")
                sys.stdout.flush()

            chunk_count += 1
            if chunk_limit and chunk_count >= chunk_limit:
                break
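
        # Perplexity is the exponential of the negative mean token log-probability:
        #
        #   PPL = exp( -(1 / N) * sum_t log p(x_t | x_<t) )
        #
        # e.g. (illustrative numbers only) a mean log-prob of -1.5 gives PPL = exp(1.5) ≈ 4.48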
        mean_log_prob = logprob_sum / logprob_count
        perplexity = math.exp(-mean_log_prob)

        print("")
        print(f" ** Perplexity{tag}: {perplexity:.4f}")


def add_args(parser):

    parser.add_argument("-ppl", "--perplexity", nargs = '?', const = 'default', metavar = "METHOD", help = "Perplexity benchmark. Optionally specify method: gptq-for-llama, llama.cpp (not yet implemented)")
    parser.add_argument("-ppl_ds", "--perplexity_dataset", metavar = "DATAPATH", type = str, help = "Load dataset for perplexity (JSONL if .jsonl, otherwise parses it as raw text)")
    parser.add_argument("-ppl_cn", "--perplexity_chunk_num", nargs = "?", type = int, help = "Number of chunks for perplexity benchmark", default = 100)
    parser.add_argument("-ppl_cs", "--perplexity_chunk_size", type = int, help = "Size of chunks for perplexity benchmark", default = 2048)
    parser.add_argument("-ppl_ct", "--perplexity_chunk_truncate", type = int, help = "Truncated size of chunks for perplexity benchmark", default = 2048)
    parser.add_argument("-ppl_co", "--perplexity_chunk_overlap", type = int, help = "Chunk overlap", default = 0)
    parser.add_argument("-ppl_cm", "--perplexity_chunk_min", type = int, help = "Minimum chunk length", default = 50)
    parser.add_argument("-ppl_key", "--perplexity_json_key", type = str, help = "Key to extract from JSON dataset, default: 'text'", default = "text")
    parser.add_argument("-ppl_t", "--perplexity_token", action = "store_true", help = "Run perplexity test on individual tokens, for debug purposes (slow)")


def post_parse(args):

    if not args.perplexity: return

    # Preset for the "gptq-for-llama" method: fixed dataset and chunking settings
    if args.perplexity == "gptq-for-llama":
        args.perplexity_dataset = "datasets/wikitext2.txt"
        args.perplexity_chunk_num = 128
        args.perplexity_chunk_size = 2048
        args.perplexity_chunk_truncate = 2048
        args.perplexity_chunk_overlap = 0
        args.perplexity_chunk_min = 0

    if args.perplexity_dataset is None: args.perplexity_dataset = "datasets/wikitext2_val_sample.jsonl"

    print(f" -- Perplexity:")
    print(f" -- - Dataset: {args.perplexity_dataset}")
    print(f" -- - Chunks: {args.perplexity_chunk_num}")
    print(f" -- - Chunk size: {args.perplexity_chunk_size}" + (f" -> {args.perplexity_chunk_truncate}" if args.perplexity_chunk_truncate is not None else ""))
    print(f" -- - Chunk overlap: {args.perplexity_chunk_overlap}")
    print(f" -- - Min. chunk size: {args.perplexity_chunk_min}")
    print(f" -- - Key: {args.perplexity_json_key}")
    if args.perplexity_token: print(" -- - Per-token mode")
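

# Minimal usage sketch (not executed on import; the paths and the model/tokenizer
# initialization below are assumptions for illustration and would normally be handled
# by the main test script that imports this module):
#
#   config = ExLlamaConfig("/path/to/config.json")            # hypothetical path
#   config.model_path = "/path/to/model.safetensors"          # hypothetical path
#   model = ExLlama(config)
#   cache = ExLlamaCache(model)
#   tokenizer = ExLlamaTokenizer("/path/to/tokenizer.model")  # hypothetical path
#
#   ppl = Perplexity("default", model, cache, tokenizer)
#   ppl.load("datasets/wikitext2_val_sample.jsonl", chunk_size = 2048, minlength = 50)
#   ppl.test(chunk_limit = 100, ppl_token = False)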