# Scraped page header (Spaces status shown: "Runtime error") -- not part of the program.
import bisect
import re
from collections import defaultdict
from typing import List, Optional, Union

import nltk
import numpy as np
import tiktoken
import torch
from llmlingua import PromptCompressor
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from abs_compressor import AbstractCompressor
# tiktoken encoder used by LongLLMLinguaCompressor.compress to report the
# original/compressed token counts (measured with the gpt-3.5-turbo vocabulary,
# independent of the scoring model's own tokenizer).
_REPORTING_MODEL = "gpt-3.5-turbo"
encoding = tiktoken.encoding_for_model(_REPORTING_MODEL)
class LongLLMLinguaCompressor(AbstractCompressor):
    """Prompt compressor implementing LLMLingua / LongLLMLingua.

    A causal LM scores the prompt and low-scoring material is removed until a
    token budget is met. :meth:`compress` applies up to three filters, in order:

    1. context level  -- rank the documents in ``context`` and keep the best
       ones (:meth:`control_context_budget`);
    2. sentence level -- rank sentences and keep the best ones
       (:meth:`control_sentence_budget`);
    3. token level    -- iteratively drop tokens whose loss falls below an
       estimated threshold (:meth:`iterative_compress_prompt`).
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
        device_map: str = "cuda",
        use_auth_token: Union[bool, str] = False,
        open_api_config: Optional[dict] = None,
    ):
        """Load the scoring model and initialise ranking state.

        Args:
            model_name: Hugging Face id or local path of the causal LM used to
                compute losses.
            device_map: a string containing ``"cuda"`` or ``"cpu"`` moves the
                whole model to that device; any other value is forwarded to
                ``from_pretrained(device_map=...)``.
            use_auth_token: Hugging Face credential, forwarded as ``token``
                (``True`` = cached login, a string = explicit token, falsy =
                let ``huggingface_hub`` use its default lookup).
            open_api_config: settings for the API-backed rankers ("openai",
                "voyageai", "cohere"); defaults to an empty dict.
        """
        self.load_model(model_name, device_map, use_auth_token)
        # Lazily loaded model for the non-LLM rank methods (see get_rank_results).
        self.retrieval_model = None
        self.retrieval_model_name = None
        # Fresh dict per instance: a shared ``{}`` default would leak state.
        self.open_api_config = open_api_config if open_api_config is not None else {}
        # Leading KV-cache positions kept when iterative_compress_prompt trims
        # the cache to stay within max_position_embeddings.
        self.cache_bos_num = 10

    def load_model(
        self,
        model_name: str,
        device_map: str = "cuda",
        use_auth_token: Union[bool, str] = False,
    ):
        """Load config, tokenizer and causal LM for ``model_name``.

        Sets ``self.tokenizer``, ``self.model``, ``self.device``,
        ``self.context_idxs`` and ``self.max_position_embeddings``.
        """
        # Forward the caller's credential. ``None`` lets huggingface_hub use its
        # default lookup (e.g. HF_TOKEN). Passing ``use_auth_token`` together
        # with ``token`` is rejected by transformers, and a hard-coded
        # placeholder token is always rejected by the Hub.
        token = use_auth_token if use_auth_token else None
        # Config and tokenizer must match the requested model; loading them from
        # a fixed repo gives wrong vocab/shape for any other model_name.
        config = AutoConfig.from_pretrained(
            model_name, token=token, trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, token=token, trust_remote_code=True
        )
        tokenizer.padding_side = "left"
        # ``is not None``: a pad id of 0 is valid and must not be discarded.
        tokenizer.pad_token_id = (
            config.pad_token_id
            if config.pad_token_id is not None
            else tokenizer.eos_token_id
        )
        on_single_device = "cuda" in device_map or "cpu" in device_map
        self.device = device_map if on_single_device else "cuda"
        if on_single_device:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="auto" if device_map == "cuda" else torch.float32,
                config=config,
                ignore_mismatched_sizes=True,
                trust_remote_code=True,
                token=token,
            ).to(device_map)
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map=device_map,
                torch_dtype="auto",
                pad_token_id=tokenizer.pad_token_id,
                offload_folder="/tmp/offload",
                offload_state_dict=True,
                cache_dir="/tmp/cache",
                trust_remote_code=True,
                token=token,
            )
        self.tokenizer = tokenizer
        self.model = model
        # One ranked list of context indices appended per control_context_budget call.
        self.context_idxs = []
        self.max_position_embeddings = config.max_position_embeddings

    def get_ppl(
        self,
        text: str,
        granularity: str = "sentence",
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        return_kv=False,
        end=None,
        condition_mode: str = "none",
        condition_pos_id: int = 0,
    ):
        """Return the LM's cross-entropy loss on ``text`` (or on given ids).

        If ``input_ids`` is None, ``text`` is tokenised here. Otherwise the
        given ``input_ids``/``attention_mask`` (2-D, as returned by the
        tokenizer with ``return_tensors="pt"``) are scored, continuing from
        ``past_key_values`` when provided. Only positions
        ``[past_length, end)`` are fed to the model.

        Args:
            granularity: "sentence" returns the mean loss; anything else
                returns the per-token loss vector.
            return_kv: also return the model's updated KV cache.
            end: exclusive end position; defaults to the full input and is
                clipped to ``past_length + max_position_embeddings``.
            condition_mode: "before" keeps ``loss[:condition_pos_id]``,
                "after" keeps ``loss[condition_pos_id:]``, else keeps all.

        Returns:
            The loss tensor, or ``(loss, past_key_values)`` if ``return_kv``.
        """
        if input_ids is None:
            tokenized_text = self.tokenizer(text, return_tensors="pt")
            input_ids = tokenized_text["input_ids"].to(self.device)
            attention_mask = tokenized_text["attention_mask"].to(self.device)
        past_length = (
            past_key_values[0][0].shape[2] if past_key_values is not None else 0
        )
        if end is None:
            end = input_ids.shape[1]
        end = min(end, past_length + self.max_position_embeddings)
        with torch.no_grad():
            response = self.model(
                input_ids[:, past_length:end],
                attention_mask=attention_mask[:, :end],
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = response.past_key_values

        # Logits at position i predict token i + 1, hence the one-step shift.
        shift_logits = response.logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., past_length + 1 : end].contiguous()
        # Flatten and drop padded positions.
        active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
        active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
        active_labels = shift_labels.view(-1)[active]
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(active_logits, active_labels)
        if condition_mode == "before":
            loss = loss[:condition_pos_id]
        elif condition_mode == "after":
            loss = loss[condition_pos_id:]
        res = loss.mean() if granularity == "sentence" else loss
        return (res, past_key_values) if return_kv else res

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`compress`."""
        return self.compress(*args, **kwargs)

    def compress(
        self,
        context: List[str],
        instruction: str = "",
        question: str = " ",
        ratio: float = 0.5,
        target_token: float = -1,
        iterative_size: int = 200,
        force_context_ids: Optional[List[int]] = None,
        force_context_number: Optional[int] = None,
        use_sentence_level_filter: bool = False,
        use_context_level_filter: bool = True,
        use_token_level_filter: bool = True,
        keep_split: bool = False,
        keep_first_sentence: int = 0,
        keep_last_sentence: int = 0,
        keep_sentence_number: int = 0,
        high_priority_bonus: int = 100,
        context_budget: str = "+100",
        token_budget_ratio: float = 1.4,
        condition_in_question: str = "none",
        reorder_context: str = "original",
        dynamic_context_compression_ratio: float = 0.0,
        condition_compare: bool = False,
        add_instruction: bool = False,
        rank_method: str = "longllmlingua",
        concate_question: bool = True,
    ):
        """Compress ``instruction`` + ``context`` + ``question`` into a shorter prompt.

        Args:
            context: documents to compress (a single string is accepted).
            instruction: prepended verbatim to the result.
            question: appended verbatim when ``concate_question`` is True; also
                used for ranking and conditioning.
            ratio: fraction of tokens to remove when ``target_token`` is -1.
            target_token: token budget for the compressed context; -1 derives
                it from ``ratio``.
            iterative_size: window size (tokens) of token-level compression.
            force_context_ids: keep exactly these context indices (skips ranking).
            force_context_number: stop selecting contexts once this many are kept.
            context_budget: arithmetic suffix applied to the budget during
                context selection (e.g. "+100"); evaluated as Python, trusted
                input only.
            condition_in_question: "none", "before" or "after" (optionally
                suffixed "_condition") -- where the question is placed when
                scoring.
            condition_compare: token filter compares losses with and without
                the question prefix.
            rank_method: ranker used by the context/sentence filters (see
                :meth:`get_rank_results`).

        Returns:
            dict with ``compressed_prompt``, ``origin_tokens``,
            ``compressed_tokens`` (both counted with the module-level tiktoken
            ``encoding``) and ``ratio`` = compressed / original tokens
            (0.0 when the original is empty).
        """
        if isinstance(context, str):
            context = [context]
        assert not (
            rank_method == "longllmlingua" and not question
        ), "In the LongLLMLingua, it is necessary to set a question."
        if condition_compare and "_condition" not in condition_in_question:
            condition_in_question += "_condition"
        if rank_method == "longllmlingua":
            if condition_in_question == "none":
                condition_in_question = "after"
        elif rank_method == "llmlingua":
            condition_in_question = (
                "none"
                if "_condition" not in condition_in_question
                else "none_condition"
            )
        origin_tokens = len(
            encoding.encode("\n\n".join([instruction] + context + [question]).strip())
        )
        context_tokens_length = [self.get_token_length(c) for c in context]
        instruction_tokens_length = self.get_token_length(instruction)
        question_tokens_length = self.get_token_length(question)
        if target_token == -1:
            total_tokens = (
                instruction_tokens_length
                + question_tokens_length
                + sum(context_tokens_length)
            )
            target_token = (
                total_tokens * (1 - ratio)
                - instruction_tokens_length
                - (question_tokens_length if concate_question else 0)
            )
        condition_flag = "_condition" in condition_in_question
        condition_in_question = condition_in_question.replace("_condition", "")

        if len(context) > 1 and use_context_level_filter:
            context, dynamic_ratio = self.control_context_budget(
                context,
                context_tokens_length,
                target_token,
                force_context_ids,
                force_context_number,
                question,
                condition_in_question,
                reorder_context=reorder_context,
                dynamic_context_compression_ratio=dynamic_context_compression_ratio,
                rank_method=rank_method,
                context_budget=context_budget,
            )
        else:
            dynamic_ratio = [0.0] * len(context)

        if use_sentence_level_filter:
            context = self.control_sentence_budget(
                context,
                target_token,
                keep_first_sentence=keep_first_sentence,
                keep_last_sentence=keep_last_sentence,
                keep_sentence_number=keep_sentence_number,
                high_priority_bonus=high_priority_bonus,
                token_budget_ratio=token_budget_ratio,
                question=question,
                condition_in_question=condition_in_question,
                rank_method=rank_method,
            )

        if use_token_level_filter:
            # The question (and optionally the instruction) is prefixed only to
            # condition token scoring; iterative_compress_prompt strips the
            # first ``start`` tokens again. Without the token filter nothing
            # strips it, so the prefix must not be added there.
            if condition_flag:
                prefix = question + "\n\n" + instruction if add_instruction else question
                context = [prefix] + context
                start = self.get_token_length(prefix) + 2
            else:
                start = 0
            context = self.iterative_compress_prompt(
                context,
                target_token,
                iterative_size=iterative_size,
                keep_split=keep_split,
                start=start,
                dynamic_ratio=dynamic_ratio,
                condition_compare=condition_compare,
            )
            compressed_prompt = (
                self.tokenizer.batch_decode(context[0])[0]
                .replace("<s> ", "")
                .replace("<s>", "")
            )
        else:
            compressed_prompt = "\n\n".join(context)

        if instruction:
            compressed_prompt = instruction + "\n\n" + compressed_prompt
        if question and concate_question:
            compressed_prompt = compressed_prompt + "\n\n" + question

        compressed_tokens = len(encoding.encode(compressed_prompt))
        return {
            "compressed_prompt": compressed_prompt,
            "origin_tokens": origin_tokens,
            "compressed_tokens": compressed_tokens,
            # Fraction of the original token count that remains.
            "ratio": compressed_tokens / origin_tokens if origin_tokens else 0.0,
        }

    def get_token_length(self, text: str, add_special_tokens: bool = True):
        """Return the number of scoring-tokenizer ids for ``text``."""
        return len(
            self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids
        )

    def get_condition_ppl(
        self,
        text: str,
        question: str,
        condition_in_question: str = "none",
        granularity: str = "sentence",
    ):
        """Score ``text`` with optional question conditioning.

        - "none":   loss of ``text`` alone.
        - "before": loss of the ``text`` tokens in ``question + text``.
        - "after":  loss of the ``question`` tokens in ``text + question``.
        Token boundaries come from tokenising the parts separately, so the split
        is approximate when the concatenation tokenises differently.

        Raises:
            ValueError: for any other ``condition_in_question``.
        """
        if condition_in_question == "none":
            return self.get_ppl(text, granularity=granularity)
        if condition_in_question == "before":
            return self.get_ppl(
                question + text,
                granularity=granularity,
                condition_mode="after",
                condition_pos_id=self.get_token_length(question) - 1,
            )
        if condition_in_question == "after":
            return self.get_ppl(
                text + question,
                granularity=granularity,
                condition_mode="after",
                condition_pos_id=self.get_token_length(text) - 1,
            )
        raise ValueError(
            f"Unknown condition_in_question {condition_in_question!r}; "
            "expected 'none', 'before' or 'after'."
        )

    def get_dynamic_compression_ratio(
        self,
        context: list,
        target_token: float,
        iterative_size: int,
        dynamic_ratio: list,
        start: int,
    ):
        """Plan the keep ratio of each token-level compression window.

        The joined context is cut into windows of ``iterative_size`` tokens.
        The result has one list per window of ``(n_tokens, ratio)`` pairs, one
        pair per context segment overlapping the window. ``ratio`` is the
        global rate ``target_token / total_tokens`` plus the segment's
        ``dynamic_ratio`` entry, clamped to [0, 1]. When ``start`` is non-zero
        the first context entry (the conditioning prefix) is excluded.
        """

        def get_ratio(base: float, delta: float):
            return max(min(1, base + delta), 0)

        # +2 per segment: presumably the "\n\n" separator used when joining.
        context_length = [self.get_token_length(ii, False) + 2 for ii in context]
        if start:
            context_length = context_length[1:]
        tau = target_token / (sum(context_length) + 1)
        # ``last`` starts at 1, presumably to account for the BOS token.
        res, last, last_target = [], 1, []
        for idx, length in enumerate(context_length):
            seg_ratio = get_ratio(tau, dynamic_ratio[idx])
            if last + length >= iterative_size:
                # Segment fills the current window: close it, emit any full
                # windows covered by the segment, carry the remainder over.
                last_target.append((iterative_size - last, seg_ratio))
                res.append(last_target)
                last = last + length - iterative_size
                if last > iterative_size:
                    k = last // iterative_size
                    res.extend([[(iterative_size, seg_ratio)]] * k)
                    last -= k * iterative_size
                last_target = [(last, seg_ratio)] if last else []
            else:
                last += length
                last_target.append((length, seg_ratio))
        if last_target:
            res.append(last_target)
        return res

    def control_context_budget(
        self,
        context: List[str],
        context_tokens_length: List[int],
        target_token: float,
        force_context_ids: Optional[List[int]] = None,
        force_context_number: Optional[int] = None,
        question: str = "",
        condition_in_question: str = "none",
        reorder_context: str = "original",
        dynamic_context_compression_ratio: float = 0.0,
        rank_method: str = "longllmlingua",
        context_budget: str = "+100",
    ):
        """Select whole contexts, best-ranked first, until the budget is spent.

        Returns:
            ``(selected_contexts, dynamic_ratio)`` with one ratio offset per
            selected context (all 0.0 unless
            ``dynamic_context_compression_ratio`` > 0).
        """
        if force_context_ids is not None:
            # Same (contexts, ratios) shape as the normal path; compress()
            # unpacks two values.
            forced = [context[ii] for ii in force_context_ids]
            return forced, [0.0] * len(forced)
        demostrations_sort = self.get_rank_results(
            context,
            question,
            rank_method,
            condition_in_question,
            context_tokens_length,
        )
        if target_token < 0:
            target_token = 100
        # NOTE(security): ``context_budget`` is executed as a Python expression
        # suffix (e.g. "+100", "*1.5"). Builtins are removed to limit damage,
        # but never pass untrusted input here.
        target_token = eval(
            "target_token" + context_budget,
            {"__builtins__": {}},
            {"target_token": target_token},
        )
        used = []
        self.context_idxs.append([x for x, _ in demostrations_sort])
        for idx, _ in demostrations_sort:
            if idx >= len(context_tokens_length):
                continue
            target_token -= context_tokens_length[idx]
            if idx not in used:
                used.append(idx)
            # Stop when the budget is exhausted or enough contexts are kept.
            if target_token < 0 or (
                force_context_number is not None
                and len(used) >= force_context_number
            ):
                break
        original_used = used
        if reorder_context == "original":
            used = sorted(used)
        elif reorder_context == "two_stage":
            # Alternate the ranked contexts between the front and the back.
            used = used[::2] + used[1::2][::-1]
        if dynamic_context_compression_ratio > 0:
            N = len(used)
            # NOTE(review): any non-empty string (including "none") is truthy,
            # so this re-ranking always runs; confirm that is intended.
            if condition_in_question:
                rank = [
                    i
                    for i, _ in self.get_rank_results(
                        context,
                        question,
                        "longllmlingua",
                        "after",
                        context_tokens_length,
                    )
                ]
                position = {ctx_idx: pos for pos, ctx_idx in enumerate(rank)}
                used = sorted(used, key=lambda x: position[x])
            # Evenly spaced offsets in [-r, r]; the best-ranked context gets +r.
            dynamic_ratio = [
                i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0
                for i in range(-(N - 1), N, 2)
            ][::-1]
            dynamic_ratio_map = dict(zip(original_used, dynamic_ratio))
            dynamic_ratio = [dynamic_ratio_map[i] for i in used]
        else:
            dynamic_ratio = [0.0] * len(used)
        res = [context[idx] for idx in used if idx < len(context)]
        return res, dynamic_ratio

    def control_sentence_budget(
        self,
        context: List[str],
        target_token: float,
        keep_first_sentence: int = 0,
        keep_last_sentence: int = 0,
        keep_sentence_number: int = 0,
        high_priority_bonus: int = 100,
        token_budget_ratio: float = 1.4,
        question: str = "",
        condition_in_question: str = "none",
        rank_method: str = "longllmlingua",
    ):
        """Keep the best-ranked sentences of every context within the budget.

        The budget is ``target_token * token_budget_ratio`` (100 when
        ``target_token`` is negative). Returns one string per input context
        whose kept sentences are joined by newlines.
        """

        def keep_sentence(dem_idx: int, sent_keep: int):
            idxs = sorted(dem_g[dem_idx], key=lambda x: sentence_ppl[x])[:sent_keep]
            for idx in idxs:
                sentence_ppl[idx] += high_priority_bonus

        sentences = [nltk.sent_tokenize(c) for c in context]
        # dem_g: context index -> set of global sentence indices.
        dem_g, idx = defaultdict(set), 0
        for idx_d, s in enumerate(sentences):
            for _ in s:
                dem_g[idx_d].add(idx)
                idx += 1
        context_sentences = [s for ii in sentences for s in ii]
        N = len(context_sentences)
        if N == 1:
            return context
        sentence_tokens_length = [
            self.get_token_length(sentence) for sentence in context_sentences
        ]
        if rank_method == "longllmlingua":
            sentence_ppl = [
                self.get_condition_ppl(sentence, question, condition_in_question)
                .cpu()
                .numpy()
                .item()
                for sentence in context_sentences
            ]
            # NOTE(review): the bonus raises the score, which prioritises a
            # sentence only for the descending sort used with "none"; with
            # conditioning the sort is ascending -- confirm intent.
            if keep_first_sentence:
                sentence_ppl[:keep_first_sentence] = [
                    ii + high_priority_bonus
                    for ii in sentence_ppl[:keep_first_sentence]
                ]
            if keep_last_sentence:
                sentence_ppl[-keep_last_sentence:] = [
                    ii + high_priority_bonus
                    for ii in sentence_ppl[-keep_last_sentence:]
                ]
            if keep_sentence_number:
                for dem_idx in range(len(sentences)):
                    keep_sentence(dem_idx, keep_sentence_number)
            sort_direct = -1 if condition_in_question == "none" else 1
            sent_sort = sorted(
                enumerate(sentence_ppl), key=lambda x: sort_direct * x[1]
            )
        else:
            sent_sort = self.get_rank_results(
                context_sentences,
                question,
                rank_method,
                condition_in_question,
                [0] * len(context_sentences),
            )

        sentence_flags = [False] * N
        if target_token < 0:
            target_token = 100
        target_token *= token_budget_ratio
        for idx, _ in sent_sort:
            target_token -= sentence_tokens_length[idx]
            sentence_flags[idx] = True
            if target_token < 0:
                break
        res, offset = [], 0
        for s in sentences:
            kept = [sent for ii, sent in enumerate(s) if sentence_flags[offset + ii]]
            res.append("\n".join(kept))
            offset += len(s)
        return res

    def get_compressed_input(
        self,
        loss,
        input_ids,
        attention_mask,
        end=200,
        iterative_size=200,
        threshold=0.5,
        keep_flag=None,
        split_token_id: int = 13,
        start: int = 0,
        self_loss=None,
        self_input_ids=None,
        self_attention_mask=None,
    ):
        """Drop low-loss tokens inside the window ``[end - iterative_size, end)``.

        A token is kept when its loss exceeds ``threshold``. With
        ``self_loss``, it is kept when ``self_loss - loss`` exceeds it, and the
        first ``start`` positions are kept (assuming positive losses). Tokens
        outside the window and tokens with ``keep_flag == 1`` are always kept.
        Within the window, repeated ``split_token_id`` tokens not protected by
        ``keep_flag`` are collapsed.

        Returns:
            ``(input_ids, attention_mask, keep_flag, end, loss, self_loss,
            self_input_ids, self_attention_mask)`` after filtering; ``end`` is
            shifted left by the number of tokens removed before it.
        """
        # ``loss`` has one entry fewer than the inputs (shifted prediction);
        # the trailing ``loss[:1] > 0`` entry pads the mask to input length.
        if self_loss is not None:
            need_idx = torch.concat(
                [
                    loss[:start] > 0,
                    self_loss[: loss[start:].shape[0]] - loss[start:] > threshold,
                    loss[:1] > 0,
                ]
            )
        else:
            need_idx = torch.concat([loss > threshold, loss[:1] > 0])
        need_idx[end:] = 1
        need_idx[: end - iterative_size] = 1
        loss = loss[need_idx[:-1]]
        if self_loss is not None:
            if need_idx.shape[0] < self_loss.shape[0] + start + 1:
                need_idx = torch.cat(
                    [
                        need_idx,
                        torch.ones(
                            self_loss.shape[0] - need_idx.shape[0] + start + 1,
                            dtype=torch.bool,
                        ).to(need_idx.device),
                    ]
                )
            self_loss = self_loss[need_idx[start:-1]]

        # Align the mask length with the input sequence.
        if need_idx.shape[0] < input_ids.shape[1]:
            need_idx = torch.cat(
                [
                    need_idx,
                    torch.ones(
                        input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool
                    ).to(need_idx.device),
                ]
            )
        elif need_idx.shape[0] > input_ids.shape[1]:
            need_idx = need_idx[: input_ids.shape[1]]

        if keep_flag is not None:
            need_idx[keep_flag == 1] = 1
        last = -1
        if keep_flag is not None:
            # Collapse runs of split tokens inside the window.
            for ii in range(end - iterative_size, end):
                if need_idx[ii] != 1:
                    continue
                now = input_ids[0][ii].detach().cpu().item()
                if (
                    now == split_token_id
                    and last == split_token_id
                    and keep_flag[ii].detach().cpu().item() == 0
                ):
                    need_idx[ii] = 0
                else:
                    last = now
        compressed_input_ids = input_ids[attention_mask == 1][need_idx].unsqueeze(0)
        compressed_attention_mask = attention_mask[attention_mask == 1][
            need_idx
        ].unsqueeze(0)

        if self_loss is not None:
            self_compressed_input_ids = self_input_ids[self_attention_mask == 1][
                need_idx[start:]
            ].unsqueeze(0)
            self_compressed_attention_mask = self_attention_mask[
                self_attention_mask == 1
            ][need_idx[start:]].unsqueeze(0)
        else:
            self_compressed_input_ids, self_compressed_attention_mask = None, None
        if keep_flag is not None:
            if len(keep_flag) > len(need_idx):
                keep_flag = torch.cat(
                    [
                        keep_flag[:start],
                        keep_flag[start : len(need_idx) + start][need_idx],
                        keep_flag[start + len(need_idx) :],
                    ]
                )
            else:
                keep_flag = keep_flag[need_idx]
        end -= (need_idx[:end] == 0).sum()
        return (
            compressed_input_ids,
            compressed_attention_mask,
            keep_flag,
            end,
            loss,
            self_loss,
            self_compressed_input_ids,
            self_compressed_attention_mask,
        )

    def get_estimate_threshold_base_distribution(
        self, ppl, ratio: float, condition_flag: bool = False
    ):
        """Return the loss value at rank ``len * ratio`` of ``ppl``.

        Entries equal to 10000 are ignored. The values are sorted descending
        unless ``condition_flag`` is True.
        """
        ppl = ppl[ppl != 10000]
        target_token = max(0, min(len(ppl) - 1, int(len(ppl) * ratio) - 1))
        return (
            ppl.sort(descending=not condition_flag)
            .values[target_token]
            .detach()
            .cpu()
            .item()
        )

    def iterative_compress_prompt(
        self,
        context: List[str],
        target_token: float,
        iterative_size: int = 200,
        keep_split: bool = False,
        split_token_id: int = 13,
        start: int = 0,
        dynamic_ratio: list = None,
        condition_compare: bool = False,
    ):
        """Token-level compression of the joined ``context``, window by window.

        Each window of ``iterative_size`` tokens is scored with the LM, reusing
        the KV cache of the already-compressed prefix. Tokens are then dropped
        with thresholds derived from :meth:`get_dynamic_compression_ratio`.
        When the sequence exceeds ``max_position_embeddings``, the oldest
        tokens (after the first ``cache_bos_num``) are evicted from the cache.

        Returns:
            ``(input_ids, attention_mask)`` of the compressed sequence with the
            first ``start`` (conditioning prefix) tokens removed.
        """
        iterative_ratios = self.get_dynamic_compression_ratio(
            context, target_token, iterative_size, dynamic_ratio, start
        )
        context = "\n\n".join(context)
        tokenized_text = self.tokenizer(context, return_tensors="pt")
        input_ids = tokenized_text["input_ids"].to(self.device)
        attention_mask = tokenized_text["attention_mask"].to(self.device)

        compressed_input_ids, compressed_attention_mask = input_ids, attention_mask
        if condition_compare:
            # Same sequence without the conditioning prefix.
            self_input_ids, self_attention_mask = (
                input_ids[:, start:],
                attention_mask[:, start:],
            )
            self_compressed_input_ids, self_compressed_attention_mask = (
                self_input_ids,
                self_attention_mask,
            )

        end = min(iterative_size + start, compressed_input_ids.shape[1])
        keep_flag = None
        if keep_split:
            # Protect every token that is part of a run of split tokens.
            input_ids_numpy = input_ids.cpu().detach().numpy()[0]
            N = len(input_ids_numpy)
            keep_flag = [
                int(
                    (
                        ii > 0
                        and input_ids_numpy[ii] == split_token_id
                        and input_ids_numpy[ii - 1] == split_token_id
                    )
                    or (
                        ii < N - 1
                        and input_ids_numpy[ii] == split_token_id
                        and input_ids_numpy[ii + 1] == split_token_id
                    )
                )
                for ii in range(N)
            ]
            keep_flag = torch.tensor(keep_flag).to(self.device)
        past_key_values, past_loss, ready_end = None, None, 0
        self_past_key_values, self_past_loss, self_ready_end = None, None, 0
        pop_compressed_input_ids, pop_self_compressed_input_ids = None, None
        idx = 0
        while end <= compressed_input_ids.shape[1]:
            if end > self.max_position_embeddings and past_key_values is not None:
                # KV-cache compression: evict ``e`` positions after the first
                # ``s`` and stash the evicted ids to re-attach at the end.
                e, s = end - self.max_position_embeddings, self.cache_bos_num
                if pop_compressed_input_ids is None:
                    pop_compressed_input_ids = compressed_input_ids[:, :e]
                else:
                    pop_compressed_input_ids = torch.cat(
                        [pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1
                    )
                compressed_input_ids = compressed_input_ids[:, e:]
                compressed_attention_mask = compressed_attention_mask[:, e:]
                past_key_values = [
                    [
                        torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
                        torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
                    ]
                    for k, v in past_key_values
                ]
                end, ready_end = end - e, ready_end - e
                if condition_compare:
                    self_ready_end -= e
                    if pop_self_compressed_input_ids is None:
                        pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
                    else:
                        pop_self_compressed_input_ids = torch.cat(
                            [
                                pop_self_compressed_input_ids,
                                self_compressed_input_ids[:, :e],
                            ],
                            dim=-1,
                        )
                    self_compressed_input_ids = self_compressed_input_ids[:, e:]
                    self_compressed_attention_mask = self_compressed_attention_mask[
                        :, e:
                    ]
                    self_past_key_values = [
                        [
                            torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
                            torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
                        ]
                        for k, v in self_past_key_values
                    ]

            loss, past_key_values = self.get_ppl(
                "",
                "token",
                compressed_input_ids,
                compressed_attention_mask,
                past_key_values=past_key_values,
                return_kv=True,
                end=end if idx else None,
            )
            if past_loss is not None:
                if end - 1 > len(past_loss):
                    past_loss = torch.cat(
                        [past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]]
                    )
                past_loss[ready_end : end - 1] = loss
                loss = past_loss
            else:
                past_loss = loss
            if idx:
                # Keep only the cache of the prefix that will not be re-scored.
                past_key_values = [
                    [k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]]
                    for k, v in past_key_values
                ]
            else:
                past_key_values = None

            if condition_compare:
                self_loss, self_past_key_values = self.get_ppl(
                    "",
                    "token",
                    self_compressed_input_ids,
                    self_compressed_attention_mask,
                    past_key_values=self_past_key_values,
                    return_kv=True,
                    end=end - start if idx else None,
                )
                if self_past_loss is not None:
                    if end - start - 1 > len(self_past_loss):
                        self_past_loss = torch.cat(
                            [
                                self_past_loss,
                                torch.zeros_like(self_loss)[
                                    : end - 1 - start - len(self_past_loss)
                                ],
                            ]
                        )
                    self_past_loss[self_ready_end : end - start - 1] = self_loss
                    self_loss = self_past_loss
                else:
                    self_past_loss = self_loss
                if idx:
                    self_past_key_values = [
                        [
                            k[:, :, : end - iterative_size - start],
                            v[:, :, : end - iterative_size - start],
                        ]
                        for k, v in self_past_key_values
                    ]
                else:
                    self_past_key_values = None

                self_ready_end = (
                    end - start - iterative_size if not (start and idx == 0) else 0
                )
            ready_end = end - iterative_size if not (start and idx == 0) else 0

            for delta_end, ratio in iterative_ratios[idx]:
                loss = past_loss
                if condition_compare:
                    self_loss = self_past_loss
                    threshold = self.get_estimate_threshold_base_distribution(
                        self_loss[: loss[start:].shape[0]] - loss[start:], ratio, False
                    )
                else:
                    threshold = self.get_estimate_threshold_base_distribution(
                        loss, ratio, False
                    )

                (
                    compressed_input_ids,
                    compressed_attention_mask,
                    keep_flag,
                    end,
                    past_loss,
                    self_past_loss,
                    self_compressed_input_ids,
                    self_compressed_attention_mask,
                ) = self.get_compressed_input(
                    loss,
                    compressed_input_ids,
                    compressed_attention_mask,
                    end - iterative_size + delta_end,
                    iterative_size=delta_end,
                    threshold=threshold,
                    keep_flag=keep_flag,
                    split_token_id=split_token_id,
                    start=start,
                    self_loss=self_loss if condition_compare else None,
                    self_input_ids=self_compressed_input_ids
                    if condition_compare
                    else None,
                    self_attention_mask=self_compressed_attention_mask
                    if condition_compare
                    else None,
                )
            end += iterative_size
            idx += 1
        if pop_compressed_input_ids is not None:
            compressed_input_ids = torch.cat(
                [pop_compressed_input_ids, compressed_input_ids], dim=-1
            )
        return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]

    def recover(
        self,
        original_prompt: str,
        compressed_prompt: str,
        response: str,
    ):
        """Replace spans of ``response`` copied from ``compressed_prompt`` with original text.

        Maximal runs of space-separated response words that occur in the
        compressed prompt are located in the tokenised ``original_prompt``.
        Each run is matched by the tightest window containing the most of its
        tokens, with gaps of at most 10 tokens. The run is replaced by the
        decoded original window; unmatched words are kept as-is.
        """

        def match_from_compressed(response_word):
            response_input_ids = self.tokenizer(
                response_word, add_special_tokens=False
            )["input_ids"]
            if not response_input_ids:
                # E.g. the empty "word" produced by consecutive spaces.
                return response_word
            response_set, response_c = set(response_input_ids), defaultdict(list)
            for idx in range(M):
                if original_input_ids[idx] in response_set:
                    response_c[original_input_ids[idx]].append(idx)
            res, res_min, res_c = None, float("inf"), 1
            n = len(response_input_ids)
            for l in response_c[response_input_ids[0]]:
                y, c = l, 1
                for x in range(1, n):
                    positions = response_c[response_input_ids[x]]
                    idx = bisect.bisect_right(positions, y)
                    if idx >= len(positions) or positions[idx] - y > 10:
                        continue
                    c += 1
                    y = positions[idx]
                if c > res_c:
                    res_c = c
                    res_min = y - l + 1
                    res = (l, y + 1)
                elif c == res_c and y - l + 1 < res_min:
                    res_min = y - l + 1
                    res = (l, y + 1)

            if res is None:
                return response_word
            return self.tokenizer.decode(original_input_ids[res[0] : res[1]])

        response_words = response.split(" ")
        original_input_ids = self.tokenizer(original_prompt, add_special_tokens=False)[
            "input_ids"
        ]
        N, M = len(response_words), len(original_input_ids)
        recovered_response_words = []
        l = 0
        while l < N:
            if response_words[l] not in compressed_prompt:
                recovered_response_words.append(response_words[l])
                l += 1
                continue
            r = l
            while (
                r + 1 < N and " ".join(response_words[l : r + 2]) in compressed_prompt
            ):
                r += 1
            match_words = match_from_compressed(" ".join(response_words[l : r + 1]))
            recovered_response_words.append(match_words)
            l = r + 1
        return " ".join(recovered_response_words)

    def get_rank_results(
        self,
        context: list,
        question: str,
        rank_method: str,
        condition_in_question: str,
        context_tokens_length: list,
    ):
        """Rank ``context`` items by relevance to ``question``.

        Returns:
            A list of ``(index, score)`` pairs, most relevant first. Only the
            LLM rankers ("longllmlingua"/"llmlingua") fill ``score``; the
            others return 0.

        Raises:
            ValueError: for an unknown ``rank_method``.

        ``context_tokens_length`` is accepted for interface compatibility; no
        ranker currently uses it.
        """

        def get_distance_bm25(corpus, query):
            from rank_bm25 import BM25Okapi

            tokenized_corpus = [doc.split(" ") for doc in corpus]
            bm25 = BM25Okapi(tokenized_corpus)
            tokenized_query = query.split(" ")
            doc_scores = bm25.get_scores(tokenized_query)
            return [(ii, 0) for ii in (-doc_scores).argsort()]

        def get_distance_gzip(corpus, query):
            import gzip

            def get_score(x, y):
                # Normalised compression distance (lower = more similar).
                cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode()))
                cxy = len(gzip.compress(f"{x} {y}".encode()))
                return (cxy - min(cx, cy)) / max(cx, cy)

            doc_scores = [get_score(doc, query) for doc in corpus]
            return [(ii, 0) for ii in np.argsort(doc_scores)]

        def get_distance_sentbert(corpus, query):
            from sentence_transformers import SentenceTransformer, util

            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                self.retrieval_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
                self.retrieval_model_name = rank_method
            doc_embeds = self.retrieval_model.encode(corpus)
            query = self.retrieval_model.encode(query)
            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
            return [(ii, 0) for ii in np.argsort(doc_scores)]

        def get_distance_openai(corpus, query):
            import openai
            from sentence_transformers import util

            openai.api_key = self.open_api_config.get("api_key", "")
            openai.api_base = self.open_api_config.get(
                "api_base", "https://api.openai.com/v1"
            )
            openai.api_type = self.open_api_config.get("api_type", "open_ai")
            openai.api_version = self.open_api_config.get("api_version", "2023-05-15")
            engine = self.open_api_config.get("engine", "text-embedding-ada-002")

            def get_embed(text):
                # Embedding responses carry their vectors under "data".
                return openai.Embedding.create(
                    input=[text.replace("\n", " ")], engine=engine
                )["data"][0]["embedding"]

            doc_embeds = [get_embed(i) for i in corpus]
            query = get_embed(query)
            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
            return [(ii, 0) for ii in np.argsort(doc_scores)]

        def get_distance_sentbert_bge(corpus, query):
            from sentence_transformers import SentenceTransformer, util

            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                self.retrieval_model = SentenceTransformer("BAAI/bge-large-en-v1.5")
                self.retrieval_model_name = rank_method
            doc_embeds = self.retrieval_model.encode(
                list(corpus), normalize_embeddings=True
            )
            query = self.retrieval_model.encode(query, normalize_embeddings=True)
            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
            return [(ii, 0) for ii in np.argsort(doc_scores)]

        def get_distance_bge_ranker(corpus, query):
            from transformers import AutoModelForSequenceClassification, AutoTokenizer

            pairs = [[i, query] for i in corpus]
            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large")
                model = (
                    AutoModelForSequenceClassification.from_pretrained(
                        "BAAI/bge-reranker-large"
                    )
                    .eval()
                    .to(self.device)
                )
                self.retrieval_model = [tokenizer, model]
                self.retrieval_model_name = rank_method
            with torch.no_grad():
                inputs = self.retrieval_model[0](
                    pairs,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512,
                ).to(self.device)
                scores = (
                    self.retrieval_model[1](**inputs, return_dict=True)
                    .logits.view(-1)
                    .float()
                )
            # Convert to numpy so indices are plain integers, not 0-d tensors.
            return [(ii, 0) for ii in np.argsort(-scores.cpu().numpy())]

        def get_distance_bge_llmembedder(corpus, query):
            from transformers import AutoModel, AutoTokenizer

            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                tokenizer = AutoTokenizer.from_pretrained("BAAI/llm-embedder")
                model = (
                    AutoModel.from_pretrained("BAAI/llm-embedder")
                    .eval()
                    .to(self.device)
                )
                self.retrieval_model = [tokenizer, model]
                self.retrieval_model_name = rank_method

            instruction_qa_query = (
                "Represent this query for retrieving relevant documents: "
            )
            instruction_qa_key = "Represent this document for retrieval: "
            queries = [instruction_qa_query + query for _ in corpus]
            keys = [instruction_qa_key + key for key in corpus]
            with torch.no_grad():
                query_inputs = self.retrieval_model[0](
                    queries,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512,
                ).to(self.device)
                key_inputs = self.retrieval_model[0](
                    keys,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512,
                ).to(self.device)
                query_outputs = self.retrieval_model[1](**query_inputs)
                key_outputs = self.retrieval_model[1](**key_inputs)
                # CLS pooling
                query_embeddings = query_outputs.last_hidden_state[:, 0]
                key_embeddings = key_outputs.last_hidden_state[:, 0]
            # Normalize
            query_embeddings = torch.nn.functional.normalize(
                query_embeddings, p=2, dim=1
            )
            key_embeddings = torch.nn.functional.normalize(key_embeddings, p=2, dim=1)
            similarity = query_embeddings @ key_embeddings.T
            # All query rows are identical, so row 0 ranks every key.
            return [(ii, 0) for ii in np.argsort(-similarity[0].cpu().numpy())]

        def get_distance_jinza(corpus, query):
            from numpy.linalg import norm
            from transformers import AutoModel

            def cos_sim(a, b):
                # Per-document norms (axis=1); a plain norm(a) of the matrix
                # would be the Frobenius norm, i.e. not a cosine.
                return (a @ b.T) / (norm(a, axis=1) * norm(b))

            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                model = (
                    AutoModel.from_pretrained(
                        "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
                    )
                    .eval()
                    .to(self.device)
                )
                self.retrieval_model = model
                self.retrieval_model_name = rank_method

            doc_embeds = self.retrieval_model.encode(corpus)
            query = self.retrieval_model.encode(query)
            doc_scores = cos_sim(doc_embeds, query)
            return [(ii, 0) for ii in np.argsort(-doc_scores)]

        def get_distance_voyageai(corpus, query):
            import voyageai
            from sentence_transformers import util

            voyageai.api_key = self.open_api_config.get("voyageai_api_key", "")

            def get_embed(text):
                return voyageai.get_embedding(text, model="voyage-01")

            doc_embeds = [get_embed(i) for i in corpus]
            query = get_embed(query)
            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
            return [(ii, 0) for ii in np.argsort(doc_scores)]

        def get_distance_cohere(corpus, query):
            import cohere

            api_key = self.open_api_config.get("cohere_api_key", "")
            co = cohere.Client(api_key)
            results = co.rerank(
                model="rerank-english-v2.0", query=query, documents=corpus, top_n=20
            )
            c_map = {jj: ii for ii, jj in enumerate(corpus)}
            doc_rank = [c_map[ii.document["text"]] for ii in results]
            return [(ii, 0) for ii in doc_rank]

        def get_distance_longllmlingua(corpus, query):
            context_ppl = [
                self.get_condition_ppl(
                    d,
                    query
                    + " We can get the answer to this question in the given documents.",
                    condition_in_question,
                )
                for d in corpus
            ]
            # Without conditioning, high loss = informative (descending);
            # with conditioning, low question loss = relevant (ascending).
            sort_direct = -1 if condition_in_question == "none" else 1
            return sorted(enumerate(context_ppl), key=lambda x: sort_direct * x[1])

        methods = {
            "bm25": get_distance_bm25,
            "gzip": get_distance_gzip,
            "sentbert": get_distance_sentbert,
            "openai": get_distance_openai,
            "longllmlingua": get_distance_longllmlingua,
            "llmlingua": get_distance_longllmlingua,
            "bge": get_distance_sentbert_bge,
            "bge_reranker": get_distance_bge_ranker,
            "bge_llmembedder": get_distance_bge_llmembedder,
            "jinza": get_distance_jinza,
            "voyageai": get_distance_voyageai,
            "cohere": get_distance_cohere,
        }
        if rank_method not in methods:
            raise ValueError(
                f"Unknown rank_method {rank_method!r}; expected one of {sorted(methods)}"
            )
        return methods[rank_method](context, question)