# Prompt-Compression-Toolbox / longlingua_compressor.py
# Author: JerryLiJinyi — commit c6a14bf ("Upload 4 files").
from llmlingua import PromptCompressor
import bisect
from collections import defaultdict
from typing import List
import numpy as np
import torch
import nltk
import tiktoken
import re
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from abs_compressor import AbstractCompressor
# Tokenizer used by LongLLMLinguaCompressor.compress() to count original and
# compressed prompt tokens. Counts are in gpt-3.5-turbo tokens, independent of
# the scoring model's own tokenizer.
encoding = tiktoken.encoding_for_model(model_name="gpt-3.5-turbo")
class LongLLMLinguaCompressor(AbstractCompressor):
    """Prompt compressor implementing LLMLingua / LongLLMLingua.

    A causal LM scores text by next-token cross-entropy. Low-information
    content is then dropped at up to three granularities: whole contexts
    (coarse ranking), sentences, and individual tokens (iterative, window by
    window).
    """

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
        device_map: str = "cuda",
        use_auth_token: bool = False,
        open_api_config: dict = None,
    ):
        """Load the scoring model and initialise retrieval-model state.

        Args:
            model_name: Hugging Face model id used for perplexity scoring.
            device_map: A device string containing "cuda" or "cpu" (used as
                the tensor device), or any other value, which is forwarded to
                ``from_pretrained(device_map=...)``.
            use_auth_token: Hub credential forwarded as ``token`` to
                ``from_pretrained``. ``True`` or a token string is passed
                through; falsy values pass ``token=None`` (library default).
            open_api_config: Keys/endpoints for the API-based rankers
                ("openai", "voyageai", "cohere"). Defaults to an empty dict.
        """
        self.load_model(model_name, device_map, use_auth_token)
        # Lazily-loaded model for the retrieval-based rank methods, and the
        # rank_method it was loaded for (reloaded when the method changes).
        self.retrieval_model = None
        self.retrieval_model_name = None
        # ``None`` default avoids sharing one mutable dict between instances.
        self.open_api_config = open_api_config if open_api_config is not None else {}
        # Leading KV-cache positions that survive cache eviction in
        # iterative_compress_prompt.
        self.cache_bos_num = 10

    def load_model(
        self, model_name: str, device_map: str = "cuda", use_auth_token: bool = False
    ):
        """Load config, tokenizer and causal LM for ``model_name``.

        Sets ``self.tokenizer``, ``self.model``, ``self.device``,
        ``self.context_idxs`` and ``self.max_position_embeddings``.
        """
        # True / token string are passed through; anything falsy becomes None.
        token = use_auth_token if use_auth_token else None
        # Config and tokenizer must come from the same checkpoint as the model
        # (they were previously hard-coded to Llama-2 regardless of model_name).
        config = AutoConfig.from_pretrained(
            model_name, trust_remote_code=True, token=token
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True, token=token
        )
        tokenizer.padding_side = "left"
        tokenizer.pad_token_id = (
            config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id
        )
        # Plain "cuda"/"cpu" strings are used directly as the input device;
        # any other device_map is a placement strategy and inputs go to "cuda".
        self.device = (
            device_map if any(key in device_map for key in ["cuda", "cpu"]) else "cuda"
        )
        if "cuda" in device_map or "cpu" in device_map:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="auto" if device_map == "cuda" else torch.float32,
                config=config,
                ignore_mismatched_sizes=True,
                trust_remote_code=True,
                token=token,
            ).to(device_map)
        else:
            # Only ``token`` is passed: transformers raises a ValueError when
            # both ``token`` and the deprecated ``use_auth_token`` are given.
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map=device_map,
                torch_dtype="auto",
                pad_token_id=tokenizer.pad_token_id,
                offload_folder="/tmp/offload",
                offload_state_dict=True,
                cache_dir="/tmp/cache",
                trust_remote_code=True,
                token=token,
            )
        self.tokenizer = tokenizer
        self.model = model
        # One ranking (list of context indices) is appended per
        # control_context_budget call.
        self.context_idxs = []
        self.max_position_embeddings = config.max_position_embeddings

    def get_ppl(
        self,
        text: str,
        granularity: str = "sentence",
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        return_kv=False,
        end=None,
        condition_mode: str = "none",
        condition_pos_id: int = 0,
    ):
        """Return next-token cross-entropy of the input under ``self.model``.

        Args:
            text: Text to score; only tokenized when ``input_ids`` is None.
            granularity: "sentence" returns the mean loss; any other value
                returns the per-token loss vector.
            input_ids / attention_mask: Pre-tokenized input. Callers in this
                class pass a batch of one (TODO confirm before other uses).
            past_key_values: KV cache covering the first ``past_length``
                positions; only positions ``past_length:end`` are run.
            return_kv: Also return the updated KV cache.
            end: Exclusive end position (default: whole input), capped at
                ``past_length + max_position_embeddings``.
            condition_mode: "before" keeps losses before ``condition_pos_id``,
                "after" keeps those from ``condition_pos_id`` on, "none" all.

        Returns:
            The loss tensor, or ``(loss, past_key_values)`` if ``return_kv``.
        """
        if input_ids is None:
            tokenized_text = self.tokenizer(text, return_tensors="pt")
            input_ids = tokenized_text["input_ids"].to(self.device)
            attention_mask = tokenized_text["attention_mask"].to(self.device)
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
        else:
            past_length = 0
        if end is None:
            end = input_ids.shape[1]
        end = min(end, past_length + self.max_position_embeddings)
        with torch.no_grad():
            response = self.model(
                input_ids[:, past_length:end],
                attention_mask=attention_mask[:, :end],
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = response.past_key_values

        # Logits at position i predict token i + 1.
        shift_logits = response.logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., past_length + 1 : end].contiguous()
        # Flatten the tokens and keep only attended positions.
        active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
        active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
        active_labels = shift_labels.view(-1)[active]
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(active_logits, active_labels)
        if condition_mode == "before":
            loss = loss[:condition_pos_id]
        elif condition_mode == "after":
            loss = loss[condition_pos_id:]
        res = loss.mean() if granularity == "sentence" else loss
        return (res, past_key_values) if return_kv else res

    def __call__(self, *args, **kwargs):
        """Alias for :meth:`compress`."""
        return self.compress(*args, **kwargs)

    def compress(
        self,
        context: List[str],
        instruction: str = "",
        question: str = " ",
        ratio: float = 0.5,
        target_token: float = -1,
        iterative_size: int = 200,
        force_context_ids: List[int] = None,
        force_context_number: int = None,
        use_sentence_level_filter: bool = False,
        use_context_level_filter: bool = True,
        use_token_level_filter: bool = True,
        keep_split: bool = False,
        keep_first_sentence: int = 0,
        keep_last_sentence: int = 0,
        keep_sentence_number: int = 0,
        high_priority_bonus: int = 100,
        context_budget: str = "+100",
        token_budget_ratio: float = 1.4,
        condition_in_question: str = "none",
        reorder_context: str = "original",
        dynamic_context_compression_ratio: float = 0.0,
        condition_compare: bool = False,
        add_instruction: bool = False,
        rank_method: str = "longllmlingua",
        concate_question: bool = True,
    ):
        """Compress ``context`` and assemble instruction + context + question.

        Pipeline:
        1. Optional coarse context selection (control_context_budget).
        2. Optional sentence filtering (control_sentence_budget).
        3. Optional iterative token filtering (iterative_compress_prompt).

        Args (selected):
            context: Documents to compress; a bare string is wrapped in a list.
            ratio: Fraction of the total token count to remove; only used
                when ``target_token`` is -1.
            target_token: Token budget for the compressed context.
            force_context_ids: Context indices to keep (in the given order),
                bypassing ranking.
            condition_in_question: "none", "before" or "after", optionally
                suffixed with "_condition" to condition token filtering on the
                question.
            rank_method: Context ranking method (see get_rank_results).
            concate_question: Append the question to the returned prompt.

        Returns:
            dict with "compressed_prompt", "origin_tokens", "compressed_tokens"
            (both counted with the module-level ``encoding``) and "ratio" =
            compressed_tokens / origin_tokens.
        """
        if isinstance(context, str):
            context = [context]
        assert not (
            rank_method == "longllmlingua" and not question
        ), "In the LongLLMLingua, it is necessary to set a question."
        if condition_compare and "_condition" not in condition_in_question:
            condition_in_question += "_condition"
        if rank_method == "longllmlingua":
            if condition_in_question == "none":
                condition_in_question = "after"
        elif rank_method == "llmlingua":
            condition_in_question = (
                "none"
                if "_condition" not in condition_in_question
                else "none_condition"
            )
        origin_tokens = len(
            encoding.encode("\n\n".join([instruction] + context + [question]).strip())
        )
        context_tokens_length = [self.get_token_length(c) for c in context]
        instruction_tokens_length = self.get_token_length(instruction)
        question_tokens_length = self.get_token_length(question)
        if target_token == -1:
            target_token = (
                (
                    instruction_tokens_length
                    + question_tokens_length
                    + sum(context_tokens_length)
                )
                * (1 - ratio)
                - instruction_tokens_length
                - (question_tokens_length if concate_question else 0)
            )
        condition_flag = "_condition" in condition_in_question
        condition_in_question = condition_in_question.replace("_condition", "")

        if len(context) > 1 and use_context_level_filter:
            context, dynamic_ratio = self.control_context_budget(
                context,
                context_tokens_length,
                target_token,
                force_context_ids,
                force_context_number,
                question,
                condition_in_question,
                reorder_context=reorder_context,
                dynamic_context_compression_ratio=dynamic_context_compression_ratio,
                rank_method=rank_method,
                context_budget=context_budget,
            )
        else:
            dynamic_ratio = [0.0] * len(context)

        if use_sentence_level_filter:
            context = self.control_sentence_budget(
                context,
                target_token,
                keep_first_sentence=keep_first_sentence,
                keep_last_sentence=keep_last_sentence,
                keep_sentence_number=keep_sentence_number,
                high_priority_bonus=high_priority_bonus,
                token_budget_ratio=token_budget_ratio,
                question=question,
                condition_in_question=condition_in_question,
                rank_method=rank_method,
            )

        # The condition prefix only conditions token-level scores and is cut
        # off again by iterative_compress_prompt via ``start``. Without
        # token-level filtering nothing removes it, so it is only added then.
        start = 0
        if condition_flag and use_token_level_filter:
            prefix = question + "\n\n" + instruction if add_instruction else question
            context = [prefix] + context
            start = self.get_token_length(prefix) + 2

        if use_token_level_filter:
            compressed_ids, _ = self.iterative_compress_prompt(
                context,
                target_token,
                iterative_size=iterative_size,
                keep_split=keep_split,
                start=start,
                dynamic_ratio=dynamic_ratio,
                condition_compare=condition_compare,
            )
            compressed_prompt = (
                self.tokenizer.batch_decode(compressed_ids)[0]
                .replace("<s> ", "")
                .replace("<s>", "")
            )
        else:
            compressed_prompt = "\n\n".join(context)

        if instruction:
            compressed_prompt = instruction + "\n\n" + compressed_prompt
        if question and concate_question:
            compressed_prompt = compressed_prompt + "\n\n" + question

        compressed_tokens = len(encoding.encode(compressed_prompt))
        return {
            "compressed_prompt": compressed_prompt,
            "origin_tokens": origin_tokens,
            "compressed_tokens": compressed_tokens,
            # Guard the all-empty input (0 original tokens).
            "ratio": compressed_tokens / origin_tokens if origin_tokens else 1.0,
        }

    def get_token_length(self, text: str, add_special_tokens: bool = True):
        """Return the number of scoring-model tokens in ``text``."""
        return len(
            self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids
        )

    def get_condition_ppl(
        self,
        text: str,
        question: str,
        condition_in_question: str = "none",
        granularity: str = "sentence",
    ):
        """Score ``text`` alone or jointly with ``question``.

        Modes:
            "none": loss of ``text`` alone.
            "before": loss of ``text`` after ``question`` + ``text``.
            "after": loss of ``question`` after ``text`` + ``question``.

        Raises:
            ValueError: for any other mode.
        """
        if condition_in_question == "none":
            return self.get_ppl(text, granularity=granularity)
        if condition_in_question == "before":
            return self.get_ppl(
                question + text,
                granularity=granularity,
                condition_mode="after",
                condition_pos_id=self.get_token_length(question) - 1,
            )
        if condition_in_question == "after":
            return self.get_ppl(
                text + question,
                granularity=granularity,
                condition_mode="after",
                condition_pos_id=self.get_token_length(text) - 1,
            )
        raise ValueError(
            f"Unknown condition_in_question: {condition_in_question!r}"
        )

    def get_dynamic_compression_ratio(
        self,
        context: list,
        target_token: float,
        iterative_size: int,
        dynamic_ratio: list,
        start: int,
    ):
        """Split the token stream into ``iterative_size`` windows with ratios.

        Each context counts as ``len(tokens) + 2`` tokens; the +2 presumably
        covers the "\\n\\n" joiner (TODO confirm). When ``start`` is non-zero
        the first context (the condition prefix) is excluded. The running
        count starts at 1, presumably for the BOS token.

        The base ratio is ``tau = target_token / (total + 1)``. Context i gets
        ``tau + dynamic_ratio[i]``, clamped to [0, 1].

        Returns:
            One list per window of ``(token_count, ratio)`` pairs for the
            context pieces falling into that window.
        """

        def get_ratio(base: float, delta: float):
            return max(min(1, base + delta), 0)

        context_length = [self.get_token_length(ii, False) + 2 for ii in context]
        if start:
            context_length = context_length[1:]
        tau = target_token / (sum(context_length) + 1)
        res, idx, last, last_target = [], 0, 1, []
        while idx < len(context_length):
            if last + context_length[idx] >= iterative_size:
                # This context closes the current window, possibly spanning
                # k further full windows.
                last_target.append(
                    (iterative_size - last, get_ratio(tau, dynamic_ratio[idx]))
                )
                res.append(last_target)
                last = last + context_length[idx] - iterative_size
                if last > iterative_size:
                    k = last // iterative_size
                    res.extend(
                        [[(iterative_size, get_ratio(tau, dynamic_ratio[idx]))]] * k
                    )
                    last -= k * iterative_size
                last_target = (
                    [(last, get_ratio(tau, dynamic_ratio[idx]))] if last else []
                )
            else:
                last += context_length[idx]
                last_target.append(
                    (context_length[idx], get_ratio(tau, dynamic_ratio[idx]))
                )
            idx += 1
        if last_target:
            res.append(last_target)
        return res

    def control_context_budget(
        self,
        context: List[str],
        context_tokens_length: List[int],
        target_token: float,
        force_context_ids: List[int] = None,
        force_context_number: int = None,
        question: str = "",
        condition_in_question: str = "none",
        reorder_context: str = "original",
        dynamic_context_compression_ratio: float = 0.0,
        rank_method: str = "longllmlingua",
        context_budget: str = "+100",
    ):
        """Select and order whole contexts to fit a token budget.

        Contexts are ranked with get_rank_results and taken greedily. The
        budget is ``target_token`` adjusted by the ``context_budget``
        expression (e.g. "+100"). Selection stops once the budget is
        exhausted (the context that overflows it is still included) or
        ``force_context_number`` contexts were taken.

        The selection is then ordered by ``reorder_context``:
            "original": document order.
            "two_stage": alternate ranks at the front and, reversed, at the
                back.
            anything else: rank order.

        Returns:
            ``(selected_contexts, dynamic_ratio)``. ``dynamic_ratio`` holds one
            per-context ratio offset; all are 0.0 unless
            ``dynamic_context_compression_ratio > 0``.
        """
        if force_context_ids is not None:
            # compress() unpacks two values, so return zero offsets as well.
            forced = [context[ii] for ii in force_context_ids]
            return forced, [0.0] * len(forced)
        demostrations_sort = self.get_rank_results(
            context,
            question,
            rank_method,
            condition_in_question,
            context_tokens_length,
        )

        if target_token < 0:
            target_token = 100
        # SECURITY: ``context_budget`` is evaluated as a Python expression
        # appended to the budget (e.g. "+100", "*1.5"). Builtins are removed,
        # but the value must still never come from untrusted input.
        target_token = eval(
            "target_token" + context_budget,
            {"__builtins__": {}},
            {"target_token": target_token},
        )
        used = []
        self.context_idxs.append([x for x, _ in demostrations_sort])
        for idx, _ in demostrations_sort:
            if idx >= len(context_tokens_length):
                continue
            target_token -= context_tokens_length[idx]
            if idx not in used:
                used.append(idx)
            # ``used`` is the selection; the previous check against an
            # always-empty list meant force_context_number never applied.
            if target_token < 0 or (
                force_context_number is not None and len(used) >= force_context_number
            ):
                break
        original_used = used
        if reorder_context == "original":
            used = sorted(used)
        elif reorder_context == "two_stage":
            l = [item for pos, item in enumerate(used) if pos % 2 == 0]
            r = [item for pos, item in enumerate(used) if pos % 2 == 1]
            used = l + r[::-1]

        if dynamic_context_compression_ratio > 0:
            N = len(used)
            # NOTE(review): condition_in_question is a non-empty mode string
            # here, so this branch always runs, and the re-sort overrides
            # ``reorder_context``. Confirm whether `!= "none"` was intended.
            if condition_in_question:
                rank = [
                    i
                    for i, _ in self.get_rank_results(
                        context,
                        question,
                        "longllmlingua",
                        "after",
                        context_tokens_length,
                    )
                ]
                used = sorted(used, key=lambda x: rank.index(x))
            # Linearly spaced offsets in [-r, r], largest for the best-ranked.
            dynamic_ratio = [
                i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0
                for i in range(-(N - 1), N, 2)
            ][::-1]
            dynamic_ratio_map = {i: j for i, j in zip(original_used, dynamic_ratio)}
            dynamic_ratio = [dynamic_ratio_map[i] for i in used]
        else:
            dynamic_ratio = [0.0] * len(used)

        res = [context[idx] for idx in used if idx < len(context)]
        return res, dynamic_ratio

    def control_sentence_budget(
        self,
        context: List[str],
        target_token: float,
        keep_first_sentence: int = 0,
        keep_last_sentence: int = 0,
        keep_sentence_number: int = 0,
        high_priority_bonus: int = 100,
        token_budget_ratio: float = 1.4,
        question: str = "",
        condition_in_question: str = "none",
        rank_method: str = "longllmlingua",
    ):
        """Keep the best-ranked sentences within a token budget.

        The budget is ``target_token * token_budget_ratio``. Each context is
        split with ``nltk.sent_tokenize``. Sentences are ranked by conditional
        perplexity for ``rank_method == "longllmlingua"``, otherwise by
        get_rank_results, then taken greedily until the budget is exhausted.

        In the longllmlingua branch only, selected sentences get a
        ``high_priority_bonus`` boost:
            keep_first_sentence / keep_last_sentence: the first/last N
                sentences of the concatenated list.
            keep_sentence_number: the best N sentences of each context.

        Returns:
            One string per context with its kept sentences joined by "\\n" in
            original order, or ``context`` unchanged when there is exactly
            one sentence overall.
        """
        # Set in the longllmlingua branch before keep_sentence is used.
        sort_direct, bonus = 1, high_priority_bonus

        def keep_sentence(dem_idx: int, sent_keep: int):
            # Boost the ``sent_keep`` best-ranked sentences of one context.
            idxs = sorted(
                dem_g[dem_idx], key=lambda x: sort_direct * sentence_ppl[x]
            )[:sent_keep]
            for idx in idxs:
                sentence_ppl[idx] += bonus

        sentences = [nltk.sent_tokenize(c) for c in context]
        # Context index -> flat indices of its sentences.
        dem_g, idx = defaultdict(set), 0
        for idx_d, s in enumerate(sentences):
            for _ in s:
                dem_g[idx_d].add(idx)
                idx += 1
        context_sentences = [s for ii in sentences for s in ii]
        sentence_tokens_length = [
            self.get_token_length(sentence) for sentence in context_sentences
        ]
        N = len(context_sentences)
        if len(sentence_tokens_length) == 1:
            return context
        if rank_method == "longllmlingua":
            sentence_ppl = [
                self.get_condition_ppl(sentence, question, condition_in_question)
                .cpu()
                .numpy()
                .item()
                for sentence in context_sentences
            ]
            # "none" ranks high perplexity first; the conditional modes rank
            # low first. The bonus must move a sentence toward the front in
            # both cases (adding it unconditionally demoted sentences in the
            # conditional modes, which are the longllmlingua default).
            sort_direct = -1 if condition_in_question == "none" else 1
            bonus = -sort_direct * high_priority_bonus
            if keep_first_sentence:
                sentence_ppl[:keep_first_sentence] = [
                    ii + bonus for ii in sentence_ppl[:keep_first_sentence]
                ]
            if keep_last_sentence:
                sentence_ppl[-keep_last_sentence:] = [
                    ii + bonus for ii in sentence_ppl[-keep_last_sentence:]
                ]
            if keep_sentence_number:
                for dem_idx in range(len(sentences)):
                    keep_sentence(dem_idx, keep_sentence_number)
            sent_sort = sorted(
                enumerate(sentence_ppl), key=lambda x: sort_direct * x[1]
            )
        else:
            sent_sort = self.get_rank_results(
                context_sentences,
                question,
                rank_method,
                condition_in_question,
                [0] * len(context_sentences),
            )

        sentence_flags = [False] * N
        if target_token < 0:
            target_token = 100
        target_token *= token_budget_ratio
        for idx, _ in sent_sort:
            target_token -= sentence_tokens_length[idx]
            sentence_flags[idx] = True
            if target_token < 0:
                break

        res, offset = [], 0
        for s in sentences:
            res.append(
                "\n".join(jj for ii, jj in enumerate(s) if sentence_flags[offset + ii])
            )
            offset += len(s)
        return res

    def get_compressed_input(
        self,
        loss,
        input_ids,
        attention_mask,
        end=200,
        iterative_size=200,
        threshold=0.5,
        keep_flag=None,
        split_token_id: int = 13,
        start: int = 0,
        self_loss=None,
        self_input_ids=None,
        self_attention_mask=None,
    ):
        """Drop low-scoring tokens inside the window ``end - iterative_size:end``.

        Positions outside the window are always kept. Inside it, the keep mask
        depends on the mode:
            Without ``self_loss``: ``loss > threshold``.
            With ``self_loss`` (condition-compare): positions before ``start``
                use ``loss > 0``; later ones ``self_loss - loss > threshold``.

        When ``keep_flag`` is given, flagged positions are always kept and
        repeated ``split_token_id`` tokens in the window are collapsed unless
        flagged.

        Returns:
            ``(input_ids, attention_mask, keep_flag, end, loss, self_loss,
            self_input_ids, self_attention_mask)`` filtered to kept positions.
            ``end`` is reduced by the number of positions dropped before it.
        """
        if self_loss is not None:
            need_idx = torch.concat(
                [
                    loss[:start] > 0,
                    self_loss[: loss[start:].shape[0]] - loss[start:] > threshold,
                    loss[:1] > 0,
                ]
            )
        else:
            need_idx = torch.concat([loss > threshold, loss[:1] > 0])
        need_idx[end:] = 1
        need_idx[: end - iterative_size] = 1
        loss = loss[need_idx[:-1]]
        if self_loss is not None:
            if need_idx.shape[0] < self_loss.shape[0] + start + 1:
                need_idx = torch.cat(
                    [
                        need_idx,
                        torch.ones(
                            self_loss.shape[0] - need_idx.shape[0] + start + 1,
                            dtype=torch.bool,
                        ).to(need_idx.device),
                    ]
                )
            self_loss = self_loss[need_idx[start:-1]]

        # Align the mask length with the sequence length.
        if need_idx.shape[0] < input_ids.shape[1]:
            need_idx = torch.cat(
                [
                    need_idx,
                    torch.ones(
                        input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool
                    ).to(need_idx.device),
                ]
            )
        elif need_idx.shape[0] > input_ids.shape[1]:
            need_idx = need_idx[: input_ids.shape[1]]

        if keep_flag is not None:
            need_idx[keep_flag == 1] = 1
        last = -1
        if keep_flag is not None:
            # Collapse consecutive split tokens in the window unless protected.
            for ii in range(end - iterative_size, end):
                if need_idx[ii] != 1:
                    continue
                now = input_ids[0][ii].detach().cpu().item()
                if (
                    now == split_token_id
                    and last == split_token_id
                    and keep_flag[ii].detach().cpu().item() == 0
                ):
                    need_idx[ii] = 0
                else:
                    last = now
        compressed_input_ids = input_ids[attention_mask == 1][need_idx].unsqueeze(0)
        compressed_attention_mask = attention_mask[attention_mask == 1][
            need_idx
        ].unsqueeze(0)

        if self_loss is not None:
            self_compressed_input_ids = self_input_ids[self_attention_mask == 1][
                need_idx[start:]
            ].unsqueeze(0)
            self_compressed_attention_mask = self_attention_mask[
                self_attention_mask == 1
            ][need_idx[start:]].unsqueeze(0)
        else:
            self_compressed_input_ids, self_compressed_attention_mask = None, None
        if keep_flag is not None:
            if len(keep_flag) > len(need_idx):
                keep_flag = torch.cat(
                    [
                        keep_flag[:start],
                        keep_flag[start : len(need_idx) + start][need_idx],
                        keep_flag[start + len(need_idx) :],
                    ]
                )
            else:
                keep_flag = keep_flag[need_idx]
        end -= (need_idx[:end] == 0).sum()
        return (
            compressed_input_ids,
            compressed_attention_mask,
            keep_flag,
            end,
            loss,
            self_loss,
            self_compressed_input_ids,
            self_compressed_attention_mask,
        )

    def get_estimate_threshold_base_distribution(
        self, ppl, ratio: float, condition_flag: bool = False
    ):
        """Return the score at rank ``int(len * ratio) - 1`` of the sorted scores.

        Entries equal to 10000 are dropped first. This is presumably a
        sentinel; nothing in this class writes it (TODO confirm). Scores are
        sorted descending, or ascending when ``condition_flag``. The index is
        clamped to the valid range and the value returned as a Python number.
        """
        ppl = ppl[ppl != 10000]
        target_token = max(0, min(len(ppl) - 1, int(len(ppl) * ratio) - 1))
        return (
            ppl.sort(descending=not condition_flag)
            .values[target_token]
            .detach()
            .cpu()
            .item()
        )

    def iterative_compress_prompt(
        self,
        context: List[str],
        target_token: float,
        iterative_size: int = 200,
        keep_split: bool = False,
        split_token_id: int = 13,
        start: int = 0,
        dynamic_ratio: list = None,
        condition_compare: bool = False,
    ):
        """Token-level compression of the joined context, window by window.

        The contexts are joined with "\\n\\n", tokenized, and processed in
        windows of ``iterative_size`` tokens. For each window:
            1. Per-token loss is computed, reusing a KV cache for the
               processed prefix.
            2. A threshold is estimated from the window's ratio (see
               get_dynamic_compression_ratio).
            3. get_compressed_input drops tokens.

        Cache eviction: once ``end`` exceeds ``max_position_embeddings``, the
        first ``e`` token ids are set aside (re-attached in front at the end)
        and the matching cache positions after the first ``cache_bos_num``
        are removed.

        With ``condition_compare``, a second loss stream over the sequence
        without its first ``start`` tokens is scored and the difference
        ``self_loss - loss`` is thresholded. With ``keep_split``, runs of two
        or more consecutive ``split_token_id`` tokens are protected.

        Returns:
            ``(input_ids, attention_mask)`` without the first ``start`` tokens.
            NOTE(review): set-aside ids are re-attached but their mask is
            not, so the two can differ in length after eviction.
        """
        iterative_ratios = self.get_dynamic_compression_ratio(
            context, target_token, iterative_size, dynamic_ratio, start
        )
        context = "\n\n".join(context)
        tokenized_text = self.tokenizer(context, return_tensors="pt")
        input_ids = tokenized_text["input_ids"].to(self.device)
        attention_mask = tokenized_text["attention_mask"].to(self.device)

        compressed_input_ids, compressed_attention_mask = input_ids, attention_mask
        if condition_compare:
            self_input_ids, self_attention_mask = (
                input_ids[:, start:],
                attention_mask[:, start:],
            )
            self_compressed_input_ids, self_compressed_attention_mask = (
                self_input_ids,
                self_attention_mask,
            )

        end = min(iterative_size + start, compressed_input_ids.shape[1])
        threshold, keep_flag = None, None
        if keep_split:
            input_ids_numpy = input_ids.cpu().detach().numpy()[0]
            N = len(input_ids_numpy)
            keep_flag = [
                int(
                    (
                        ii > 0
                        and input_ids_numpy[ii] == split_token_id
                        and input_ids_numpy[ii - 1] == split_token_id
                    )
                    or (
                        ii < N - 1
                        and input_ids_numpy[ii] == split_token_id
                        and input_ids_numpy[ii + 1] == split_token_id
                    )
                )
                for ii in range(N)
            ]
            keep_flag = torch.tensor(keep_flag).to(self.device)
        past_key_values, past_loss, ready_end = None, None, 0
        self_past_key_values, self_past_loss, self_ready_end = None, None, 0
        pop_compressed_input_ids, pop_self_compressed_input_ids = None, None
        idx = 0
        while end <= compressed_input_ids.shape[1]:
            if end > self.max_position_embeddings and past_key_values is not None:
                # KV-Cache Compression
                e, s = end - self.max_position_embeddings, self.cache_bos_num
                if pop_compressed_input_ids is None:
                    pop_compressed_input_ids = compressed_input_ids[:, :e]
                else:
                    pop_compressed_input_ids = torch.cat(
                        [pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1
                    )
                compressed_input_ids = compressed_input_ids[:, e:]
                compressed_attention_mask = compressed_attention_mask[:, e:]
                past_key_values = [
                    [
                        torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
                        torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
                    ]
                    for k, v in past_key_values
                ]
                end, ready_end = end - e, ready_end - e
                if condition_compare:
                    self_ready_end -= e
                    if pop_self_compressed_input_ids is None:
                        pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
                    else:
                        pop_self_compressed_input_ids = torch.cat(
                            [
                                pop_self_compressed_input_ids,
                                self_compressed_input_ids[:, :e],
                            ],
                            dim=-1,
                        )
                    self_compressed_input_ids = self_compressed_input_ids[:, e:]
                    self_compressed_attention_mask = self_compressed_attention_mask[
                        :, e:
                    ]
                    self_past_key_values = [
                        [
                            torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
                            torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
                        ]
                        for k, v in self_past_key_values
                    ]

            loss, past_key_values = self.get_ppl(
                "",
                "token",
                compressed_input_ids,
                compressed_attention_mask,
                past_key_values=past_key_values,
                return_kv=True,
                end=end if idx else None,
            )
            # Splice this window's losses into the running loss vector.
            if past_loss is not None:
                if end - 1 > len(past_loss):
                    past_loss = torch.cat(
                        [past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]]
                    )
                past_loss[ready_end : end - 1] = loss
                loss = past_loss
            else:
                past_loss = loss
            if idx:
                # Keep the cache only for the prefix before the current window.
                past_key_values = [
                    [k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]]
                    for k, v in past_key_values
                ]
            else:
                past_key_values = None

            if condition_compare:
                self_loss, self_past_key_values = self.get_ppl(
                    "",
                    "token",
                    self_compressed_input_ids,
                    self_compressed_attention_mask,
                    past_key_values=self_past_key_values,
                    return_kv=True,
                    end=end - start if idx else None,
                )
                if self_past_loss is not None:
                    if end - start - 1 > len(self_past_loss):
                        self_past_loss = torch.cat(
                            [
                                self_past_loss,
                                torch.zeros_like(self_loss)[
                                    : end - 1 - start - len(self_past_loss)
                                ],
                            ]
                        )
                    self_past_loss[self_ready_end : end - start - 1] = self_loss
                    self_loss = self_past_loss
                else:
                    self_past_loss = self_loss
                if idx:
                    self_past_key_values = [
                        [
                            k[:, :, : end - iterative_size - start],
                            v[:, :, : end - iterative_size - start],
                        ]
                        for k, v in self_past_key_values
                    ]
                else:
                    self_past_key_values = None

                self_ready_end = (
                    end - start - iterative_size if not (start and idx == 0) else 0
                )
            ready_end = end - iterative_size if not (start and idx == 0) else 0

            # One compression pass per context piece inside this window.
            for delta_end, ratio in iterative_ratios[idx]:
                loss = past_loss
                if condition_compare:
                    self_loss = self_past_loss
                    threshold = self.get_estimate_threshold_base_distribution(
                        self_loss[: loss[start:].shape[0]] - loss[start:], ratio, False
                    )
                else:
                    threshold = self.get_estimate_threshold_base_distribution(
                        loss, ratio, False
                    )

                (
                    compressed_input_ids,
                    compressed_attention_mask,
                    keep_flag,
                    end,
                    past_loss,
                    self_past_loss,
                    self_compressed_input_ids,
                    self_compressed_attention_mask,
                ) = self.get_compressed_input(
                    loss,
                    compressed_input_ids,
                    compressed_attention_mask,
                    end - iterative_size + delta_end,
                    iterative_size=delta_end,
                    threshold=threshold,
                    keep_flag=keep_flag,
                    split_token_id=split_token_id,
                    start=start,
                    self_loss=self_loss if condition_compare else None,
                    self_input_ids=self_compressed_input_ids
                    if condition_compare
                    else None,
                    self_attention_mask=self_compressed_attention_mask
                    if condition_compare
                    else None,
                )
                end += iterative_size
            idx += 1
        if pop_compressed_input_ids is not None:
            compressed_input_ids = torch.cat(
                [pop_compressed_input_ids, compressed_input_ids], dim=-1
            )
        return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]

    def recover(
        self,
        original_prompt: str,
        compressed_prompt: str,
        response: str,
    ):
        """Map spans of ``response`` found in the compressed prompt back to
        ``original_prompt`` text.

        ``response`` is split on single spaces. Maximal runs of words whose
        joined text occurs in ``compressed_prompt`` are located among the
        original prompt's token ids and replaced by the decoded original span.
        Matching tolerates gaps of up to 10 tokens between matched tokens and
        prefers the most matched tokens, then the shortest span. Other words
        are kept verbatim.
        """

        def match_from_compressed(response_word):
            response_input_ids = self.tokenizer(
                response_word, add_special_tokens=False
            )["input_ids"]
            # An empty word (from consecutive spaces) has no tokens; indexing
            # response_input_ids[0] below would raise IndexError.
            if not response_input_ids:
                return response_word
            response_set, response_c = set(response_input_ids), defaultdict(list)
            for pos, token_id in enumerate(original_input_ids):
                if token_id in response_set:
                    response_c[token_id].append(pos)
            res, res_min, res_c = None, float("inf"), 1
            n = len(response_input_ids)
            for left in response_c[response_input_ids[0]]:
                y, c = left, 1
                for x in range(1, n):
                    positions = response_c[response_input_ids[x]]
                    idx = bisect.bisect_right(positions, y)
                    if idx >= len(positions) or positions[idx] - y > 10:
                        continue
                    c += 1
                    y = positions[idx]
                if c > res_c:
                    res_c = c
                    res_min = y - left + 1
                    res = (left, y + 1)
                elif c == res_c and y - left + 1 < res_min:
                    res_min = y - left + 1
                    res = (left, y + 1)

            if res is None:
                return response_word
            return self.tokenizer.decode(original_input_ids[res[0] : res[1]])

        response_words = response.split(" ")
        original_input_ids = self.tokenizer(original_prompt, add_special_tokens=False)[
            "input_ids"
        ]
        N = len(response_words)
        recovered_response_words = []
        l = 0
        while l < N:
            if response_words[l] not in compressed_prompt:
                recovered_response_words.append(response_words[l])
                l += 1
                continue
            # Extend the run while the joined words still occur verbatim.
            r = l
            while (
                r + 1 < N and " ".join(response_words[l : r + 2]) in compressed_prompt
            ):
                r += 1

            match_words = match_from_compressed(" ".join(response_words[l : r + 1]))
            recovered_response_words.append(match_words)
            l = r + 1
        return " ".join(recovered_response_words)

    def get_rank_results(
        self,
        context: list,
        question: str,
        rank_method: str,
        condition_in_question: str,
        context_tokens_length: list,
    ):
        """Rank ``context`` entries against ``question``.

        Returns:
            ``(index, score)`` pairs in ranked order (first = most relevant
            according to the method). Only "longllmlingua"/"llmlingua" report
            real scores; the other methods use 0 as the score.

        Raises:
            ValueError: for an unknown ``rank_method``.
        """

        def get_distance_bm25(corpus, query):
            from rank_bm25 import BM25Okapi

            tokenized_corpus = [doc.split(" ") for doc in corpus]
            bm25 = BM25Okapi(tokenized_corpus)
            tokenized_query = query.split(" ")
            doc_scores = bm25.get_scores(tokenized_query)
            idx = [(ii, 0) for ii in (-doc_scores).argsort()]
            return idx

        def get_distance_gzip(corpus, query):
            import gzip

            def get_score(x, y):
                # Normalized compression distance: lower = more similar.
                cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode()))
                cxy = len(gzip.compress(f"{x} {y}".encode()))
                return (cxy - min(cx, cy)) / max(cx, cy)

            doc_scores = [get_score(doc, query) for doc in corpus]
            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
            return idx

        def get_distance_sentbert(corpus, query):
            from sentence_transformers import SentenceTransformer, util

            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                self.retrieval_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
                self.retrieval_model_name = rank_method
            doc_embeds = self.retrieval_model.encode(corpus)
            query = self.retrieval_model.encode(query)
            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
            return idx

        def get_distance_openai(corpus, query):
            import openai
            from sentence_transformers import util

            openai.api_key = self.open_api_config.get("api_key", "")
            openai.api_base = self.open_api_config.get(
                "api_base", "https://api.openai.com/v1"
            )
            openai.api_type = self.open_api_config.get("api_type", "open_ai")
            openai.api_version = self.open_api_config.get("api_version", "2023-05-15")
            engine = self.open_api_config.get("engine", "text-embedding-ada-002")

            def get_embed(text):
                # Embeddings are returned under response["data"][i]["embedding"].
                return openai.Embedding.create(
                    input=[text.replace("\n", " ")], engine=engine
                )["data"][0]["embedding"]

            doc_embeds = [get_embed(i) for i in corpus]
            query = get_embed(query)
            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
            return idx

        def get_distance_sentbert_bge(corpus, query):
            from sentence_transformers import SentenceTransformer, util

            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                self.retrieval_model = SentenceTransformer("BAAI/bge-large-en-v1.5")
                self.retrieval_model_name = rank_method
            doc_embeds = self.retrieval_model.encode(
                [i for i in corpus], normalize_embeddings=True
            )
            query = self.retrieval_model.encode(query, normalize_embeddings=True)
            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
            return idx

        def get_distance_bge_ranker(corpus, query):
            from transformers import AutoModelForSequenceClassification, AutoTokenizer

            pairs = [[i, query] for i in corpus]
            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large")
                model = (
                    AutoModelForSequenceClassification.from_pretrained(
                        "BAAI/bge-reranker-large"
                    )
                    .eval()
                    .to(self.device)
                )
                self.retrieval_model = [tokenizer, model]
                self.retrieval_model_name = rank_method
            with torch.no_grad():
                inputs = self.retrieval_model[0](
                    pairs,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512,
                ).to(self.device)
                scores = (
                    self.retrieval_model[1](**inputs, return_dict=True)
                    .logits.view(
                        -1,
                    )
                    .float()
                )
            idx = [(ii, 0) for ii in np.argsort(-scores.cpu())]
            return idx

        def get_distance_bge_llmembedder(corpus, query):
            from transformers import AutoModel, AutoTokenizer

            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                tokenizer = AutoTokenizer.from_pretrained("BAAI/llm-embedder")
                model = (
                    AutoModel.from_pretrained("BAAI/llm-embedder")
                    .eval()
                    .to(self.device)
                )
                self.retrieval_model = [tokenizer, model]
                self.retrieval_model_name = rank_method

            instruction_qa_query = (
                "Represent this query for retrieving relevant documents: "
            )
            instruction_qa_key = "Represent this document for retrieval: "
            queries = [instruction_qa_query + query for _ in corpus]
            keys = [instruction_qa_key + key for key in corpus]
            with torch.no_grad():
                query_inputs = self.retrieval_model[0](
                    queries,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512,
                ).to(self.device)
                key_inputs = self.retrieval_model[0](
                    keys,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512,
                ).to(self.device)
                query_outputs = self.retrieval_model[1](**query_inputs)
                key_outputs = self.retrieval_model[1](**key_inputs)
                # CLS pooling
                query_embeddings = query_outputs.last_hidden_state[:, 0]
                key_embeddings = key_outputs.last_hidden_state[:, 0]
                # Normalize
                query_embeddings = torch.nn.functional.normalize(
                    query_embeddings, p=2, dim=1
                )
                key_embeddings = torch.nn.functional.normalize(
                    key_embeddings, p=2, dim=1
                )
                similarity = query_embeddings @ key_embeddings.T
            idx = [(ii, 0) for ii in np.argsort(-similarity[0].cpu())]
            return idx

        def get_distance_jinza(corpus, query):
            from numpy.linalg import norm
            from transformers import AutoModel

            def cos_sim(a, b):
                return (a @ b.T) / (norm(a) * norm(b))

            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                model = (
                    AutoModel.from_pretrained(
                        "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
                    )
                    .eval()
                    .to(self.device)
                )
                self.retrieval_model = model
                self.retrieval_model_name = rank_method

            doc_embeds = self.retrieval_model.encode(corpus)
            query = self.retrieval_model.encode(query)
            doc_scores = cos_sim(doc_embeds, query)
            idx = [(ii, 0) for ii in np.argsort(-doc_scores)]
            return idx

        def get_distance_voyageai(corpus, query):
            import voyageai
            from sentence_transformers import util

            voyageai.api_key = self.open_api_config.get("voyageai_api_key", "")

            def get_embed(text):
                return voyageai.get_embedding(text, model="voyage-01")

            doc_embeds = [get_embed(i) for i in corpus]
            query = get_embed(query)
            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
            return idx

        def get_distance_cohere(corpus, query):
            import cohere

            api_key = self.open_api_config.get("cohere_api_key", "")
            co = cohere.Client(api_key)
            results = co.rerank(
                model="rerank-english-v2.0", query=query, documents=corpus, top_n=20
            )
            c_map = {jj: ii for ii, jj in enumerate(corpus)}
            doc_rank = [c_map[ii.document["text"]] for ii in results]
            idx = [(ii, 0) for ii in doc_rank]
            return idx

        def get_distance_longllmlingua(corpus, query):
            context_ppl = [
                self.get_condition_ppl(
                    d,
                    query
                    + " We can get the answer to this question in the given documents.",
                    condition_in_question,
                )
                # Length penalty, currently disabled (multiplied by 0).
                - dl * 2 / 250 * 0
                for d, dl in zip(corpus, context_tokens_length)
            ]
            sort_direct = -1 if condition_in_question == "none" else 1
            ys = sorted(enumerate(context_ppl), key=lambda x: sort_direct * x[1])
            return ys

        methods = {
            "bm25": get_distance_bm25,
            "gzip": get_distance_gzip,
            "sentbert": get_distance_sentbert,
            "openai": get_distance_openai,
            "longllmlingua": get_distance_longllmlingua,
            "llmlingua": get_distance_longllmlingua,
            "bge": get_distance_sentbert_bge,
            "bge_reranker": get_distance_bge_ranker,
            "bge_llmembedder": get_distance_bge_llmembedder,
            "jinza": get_distance_jinza,
            "voyageai": get_distance_voyageai,
            "cohere": get_distance_cohere,
        }
        if rank_method not in methods:
            raise ValueError(
                f"Unknown rank_method {rank_method!r}; expected one of {sorted(methods)}"
            )
        return methods[rank_method](context, question)