diff --git "a/llmlingua/prompt_compressor.py" "b/llmlingua/prompt_compressor.py"
new file mode 100644
--- /dev/null
+++ "b/llmlingua/prompt_compressor.py"
@@ -0,0 +1,2412 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+import bisect
+import re
+from collections import defaultdict
+from typing import List
+
+import numpy as np
+import torch
+
+import nltk
+import tiktoken
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoModelForTokenClassification,
+    AutoTokenizer,
+)
+import torch.nn.functional as F
+import string
+import copy
+from torch.utils.data import DataLoader
+
+from .utils import TokenClfDataset, seed_everything, is_begin_of_new_word, replace_added_token, get_pure_token
+
+
+class PromptCompressor:
+    """
+    PromptCompressor is designed for compressing prompts based on a given language model.
+
+    This class initializes with the language model and its configuration, preparing it for prompt compression tasks.
+    The PromptCompressor class is versatile and can be adapted for various models and specific requirements in prompt processing.
+    Users can specify different model names and configurations as needed for their particular use case. The architecture is
+    based on the paper "LLMLingua: Compressing Prompts for Accelerated Inference of Large Language Models". Jiang, Huiqiang, Qianhui Wu,
+    Chin-Yew Lin, Yuqing Yang, and Lili Qiu. "Llmlingua: Compressing prompts for accelerated inference of large language models."
+    arXiv preprint arXiv:2310.05736 (2023).
+
+    Args:
+        model_name (str, optional): The name of the language model to be loaded. Default is "NousResearch/Llama-2-7b-hf".
+        device_map (str, optional): The device to load the model onto, e.g., "cuda" for GPU. Default is "cuda".
+        model_config (dict, optional): A dictionary containing the configuration parameters for the model. Default is an empty dictionary.
+        open_api_config (dict, optional): A dictionary containing configuration for OpenAI APIs that may be used in conjunction with the model. Default is an empty dictionary.
+        use_llmlingua2 (bool, optional): Whether to use the llmlingua-2 compressor based on the paper
+            "LLMLingua-2: Context-Aware Data Distillation for Efficient and Faithful Task-Agnostic Prompt Compression".
+            Zhuoshi Pan, Qianhui Wu, Huiqiang Jiang, Menglin Xia, Xufang Luo, Jue Zhang, Qingwei Lin, Victor Ruhle, Yuqing Yang, Chin-Yew Lin, H. Vicky Zhao, Lili Qiu, Dongmei Zhang. arXiv preprint.
+            Default is True.
+        llmlingua2_config (dict, optional): A dictionary containing the configuration parameters for llmlingua-2. Default is 
+            {
+                "max_batch_size": 50, 
+                "max_force_token": 100, # max number of the tokens which will be forcely preserved
+            }
+    Example:
+        >>> compress_method = PromptCompressor(model_name="xxx/llmlingua-2-xlm-roberta-large-meetingbank", use_llmlingua2=True)
+        >>> context = ["This is the first context sentence.", "Here is another context sentence."]
+        >>> result = compress_method.compress_prompt(context, use_context_level_filter=True, target_token=5)
+        >>> print(result["compressed_prompt"])
+        # This will print the compressed version of the context.
+
+    Note:
+        The `PromptCompressor` class requires the Hugging Face Transformers library and an appropriate environment to load and run the models.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "NousResearch/Llama-2-7b-hf",
+        device_map: str = "cuda",
+        model_config: dict = {},
+        open_api_config: dict = {},
+        use_llmlingua2: bool = True,
+        llmlingua2_config: dict = {},
+    ):
+        self.model_name = model_name
+        self.use_llmlingua2 = use_llmlingua2
+        self.retrieval_model = None
+        self.retrieval_model_name = None
+        self.open_api_config = open_api_config
+        self.cache_bos_num = 10
+        self.prefix_bos_num = 100
+        self.oai_tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
+
+        self.load_model(model_name, device_map, model_config)
+        if use_llmlingua2:
+            self.init_llmlingua2(**llmlingua2_config)
+
+    def init_llmlingua2(
+        self,
+        max_batch_size: int = 50,
+        max_force_token: int = 100,
+    ):
+
+        seed_everything(42)
+        self.max_batch_size = max_batch_size
+        self.max_seq_len = 512
+        self.max_force_token = max_force_token
+        self.special_tokens = set(self.tokenizer.special_tokens_map.values())
+
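+        # Register placeholder special tokens ("[NEW0]", "[NEW1]", ...). Force tokens that the
+        # tokenizer would split into several pieces are mapped onto these single-token
+        # placeholders in compress_prompt_llmlingua2 so they can be preserved as one unit.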
+        self.added_tokens = [f"[NEW{i}]" for i in range(max_force_token)]
+        self.tokenizer.add_special_tokens(
+            {"additional_special_tokens": self.added_tokens}
+        )
+        self.model.resize_token_embeddings(len(self.tokenizer))
+
+    def load_model(
+        self, model_name: str, device_map: str = "cuda", model_config: dict = {}
+    ):
+        trust_remote_code = model_config.get("trust_remote_code", True)
+        if "trust_remote_code" not in model_config:
+            model_config["trust_remote_code"] = trust_remote_code
+        config = AutoConfig.from_pretrained(model_name, **model_config)
+        tokenizer = AutoTokenizer.from_pretrained(model_name, **model_config)
+        if model_config.get("pad_to_left", True):
+            tokenizer.padding_side = "left"
+            tokenizer.pad_token_id = (
+                config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id
+            )
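+        # llmlingua-2 checkpoints are token-classification models; the original
+        # LLMLingua / LongLLMLingua compressors use a causal LM for perplexity scoring.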
+        MODEL_CLASS = (
+            AutoModelForTokenClassification
+            if any("ForTokenClassification" in ar for ar in config.architectures)
+            else AutoModelForCausalLM
+        )
+        self.device = (
+            device_map
+            if any(key in device_map for key in ["cuda", "cpu", "mps"])
+            else "cuda"
+        )
+        if "cuda" in device_map or "cpu" in device_map:
+            model = MODEL_CLASS.from_pretrained(
+                model_name,
+                torch_dtype=model_config.get(
+                    "torch_dtype", "auto" if device_map == "cuda" else torch.float32
+                ),
+                device_map=device_map,
+                config=config,
+                ignore_mismatched_sizes=True,
+                **model_config,
+            )
+        else:
+            model = MODEL_CLASS.from_pretrained(
+                model_name,
+                device_map=device_map,
+                torch_dtype=model_config.get("torch_dtype", "auto"),
+                pad_token_id=tokenizer.pad_token_id,
+                **model_config,
+            )
+        self.tokenizer = tokenizer
+        self.model = model
+        self.context_idxs = []
+        self.max_position_embeddings = config.max_position_embeddings
+
+    def get_ppl(
+        self,
+        text: str,
+        granularity: str = "sentence",
+        input_ids=None,
+        attention_mask=None,
+        past_key_values=None,
+        return_kv=False,
+        end=None,
+        condition_mode: str = "none",
+        condition_pos_id: int = 0,
+    ):
+        if input_ids is None:
+            tokenized_text = self.tokenizer(text, return_tensors="pt")
+            input_ids = tokenized_text["input_ids"].to(self.device)
+            attention_mask = tokenized_text["attention_mask"].to(self.device)
+        if past_key_values is not None:
+            past_length = past_key_values[0][0].shape[2]
+        else:
+            past_length = 0
+        if end is None:
+            end = input_ids.shape[1]
+        end = min(end, past_length + self.max_position_embeddings)
+        with torch.no_grad():
+            response = self.model(
+                input_ids[:, past_length:end],
+                attention_mask=attention_mask[:, :end],
+                past_key_values=past_key_values,
+                use_cache=True,
+            )
+            past_key_values = response.past_key_values
+
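+        # Next-token shift: logits at position i are scored against the label at position i + 1,
+        # restricted to positions where attention_mask == 1.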
+        shift_logits = response.logits[..., :-1, :].contiguous()
+        shift_labels = input_ids[..., past_length + 1 : end].contiguous()
+        # Flatten the tokens
+        active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
+        active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
+        active_labels = shift_labels.view(-1)[active]
+        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+        loss = loss_fct(active_logits, active_labels)
+        if condition_mode == "before":
+            loss = loss[:condition_pos_id]
+        elif condition_mode == "after":
+            loss = loss[condition_pos_id:]
+        res = loss.mean() if granularity == "sentence" else loss
+        return (res, past_key_values) if return_kv else res
+
+    def __call__(self, *args, **kwargs):
+        return self.compress_prompt(*args, **kwargs)
+
+    def structured_compress_prompt(
+        self,
+        context: List[str],
+        instruction: str = "",
+        question: str = "",
+        rate: float = 0.5,
+        target_token: float = -1,
+        iterative_size: int = 200,
+        force_context_ids: List[int] = None,
+        force_context_number: int = None,
+        use_sentence_level_filter: bool = False,
+        use_context_level_filter: bool = True,
+        use_token_level_filter: bool = True,
+        keep_split: bool = False,
+        keep_first_sentence: int = 0,
+        keep_last_sentence: int = 0,
+        keep_sentence_number: int = 0,
+        high_priority_bonus: int = 100,
+        context_budget: str = "+100",
+        token_budget_ratio: float = 1.4,
+        condition_in_question: str = "none",
+        reorder_context: str = "original",
+        dynamic_context_compression_ratio: float = 0.0,
+        condition_compare: bool = False,
+        add_instruction: bool = False,
+        rank_method: str = "llmlingua",
+        concate_question: bool = True,
+    ):
+        """
+        Compresses the given prompt context based on a specified structure.
+
+        Each element of context should be segmented using one or more non-nested '<llmlingua></llmlingua>' tags.
+        Each '<llmlingua>' tag can include optional parameters 'rate' and 'compress' (e.g., '<llmlingua, rate=0.3, compress=True>'),
+        indicating the compression rate for that segment. Default values are 'rate=rate' and 'compress=True'.
+        When 'compress' is set to False, it overrides the 'rate' parameter, resulting in no compression for that segment.
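+        For example, a single context element might look like this (illustrative text; tag parameters are optional):
+            "<llmlingua, compress=False>Keep this system header verbatim.</llmlingua>"
+            "<llmlingua, rate=0.3>A long reference passage that can be compressed aggressively.</llmlingua>"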
+
+        Args:
+            context (List[str]): List of context strings divided by '<llmlingua></llmlingua>' tags with optional compression settings.
+            instruction (str, optional): Additional instruction text to be included in the prompt. Default is an empty string.
+            question (str, optional): A specific question that the prompt is addressing. Default is an empty string.
+            rate (float, optional): The compression rate is defined the same as in the paper "Language Modeling Is Compression".
+                Delétang, Grégoire, Anian Ruoss, Paul-Ambroise Duquenne, Elliot Catt, Tim Genewein, Christopher Mattern,
+                Jordi Grau-Moya et al. "Language modeling is compression." arXiv preprint arXiv:2309.10668 (2023):
+                .. math::\text{Compression Rate} = \frac{\text{Compressed Size}}{\text{Raw Size}}
+                Default is 0.5. The actual compression rate is generally lower than the specified target, but there can be
+                fluctuations due to differences in tokenizers. If specified, it should be a float less than or equal
+                to 1.0, representing the target compression rate. ``rate`` is applicable only within the context-level filter
+                and the sentence-level filter. In the token-level filter, the rate for each segment overrides the global rate.
+                However, for segments where no specific rate is defined, the global rate serves as the default value. The final
+                compression rate of the entire text is a composite result of multiple compression rates applied across different sections.
+            target_token (float, optional): The global maximum number of tokens to be achieved. Default is -1, indicating no
+                specific target. The actual number of tokens after compression should generally be less than the specified target_token,
+                but there can be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as
+                the sole criterion, overriding the ``rate``. ``target_token`` is applicable only within the context-level
+                filter and the sentence-level filter. In the token-level filter, the rate for each segment overrides the global target token.
+                However, for segments where no specific rate is defined, the global rate calculated from the global target token serves
+                as the default value. The final target token of the entire text is a composite result of multiple compression rates
+                applied across different sections.
+            iterative_size (int, optional): The number of tokens to consider in each iteration of compression. Default is 200.
+            force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
+            force_context_number (int, optional): The number of context sections to forcibly include. Default is None.
+            use_sentence_level_filter (bool, optional): Whether to apply sentence-level filtering in compression. Default is False.
+            use_context_level_filter (bool, optional): Whether to apply context-level filtering in compression. Default is True.
+            use_token_level_filter (bool, optional): Whether to apply token-level filtering in compression. Default is True.
+            keep_split (bool, optional): Whether to preserve the original separators without compression. Default is False.
+            keep_first_sentence (int, optional): Number of sentences to forcibly preserve from the start of the context. Default is 0.
+            keep_last_sentence (int, optional): Number of sentences to forcibly preserve from the end of the context. Default is 0.
+            keep_sentence_number (int, optional): Total number of sentences to forcibly preserve in the compression. Default is 0.
+            high_priority_bonus (int, optional): Bonus score for high-priority sentences to influence their likelihood of being retained. Default is 100.
+            context_budget (str, optional): Token budget for the context-level filtering, expressed as a string to indicate flexibility. Default is "+100".
+            token_budget_ratio (float, optional): Ratio to adjust token budget during sentence-level filtering. Default is 1.4.
+            condition_in_question (str, optional): Specific condition to apply to question in the context. Default is "none".
+            reorder_context (str, optional): Strategy for reordering context in the compressed result. Default is "original".
+            dynamic_context_compression_ratio (float, optional): Ratio for dynamically adjusting context compression. Default is 0.0.
+            condition_compare (bool, optional): Whether to enable condition comparison during token-level compression. Default is False.
+            add_instruction (bool, optional): Whether to add the instruction to the prompt prefix. Default is False.
+            rank_method (str, optional): Method used for ranking elements during compression. Default is "llmlingua".
+            concate_question (bool, optional): Whether to concatenate the question to the compressed prompt. Default is True.
+
+        Returns:
+            dict: A dictionary containing:
+                - "compressed_prompt" (str): The resulting compressed prompt.
+                - "origin_tokens" (int): The original number of tokens in the input.
+                - "compressed_tokens" (int): The number of tokens in the compressed output.
+                - "ratio" (str): The compression ratio achieved, calculated as the original token number divided by the token number after compression.
+                - "rate" (str): The compression rate achieved, in a human-readable format.
+                - "saving" (str): Estimated savings in GPT-4 token usage.
+        """
+        if not context:
+            context = [" "]
+        if isinstance(context, str):
+            context = [context]
+        context = [
+            self.tokenizer.decode(self.tokenizer(c, add_special_tokens=False).input_ids)
+            for c in context
+        ]
+        context_tokens_length = [self.get_token_length(c) for c in context]
+        instruction_tokens_length, question_tokens_length = self.get_token_length(
+            instruction
+        ), self.get_token_length(question)
+        if target_token == -1:
+            target_token = (
+                (
+                    instruction_tokens_length
+                    + question_tokens_length
+                    + sum(context_tokens_length)
+                )
+                * rate
+                - instruction_tokens_length
+                - (question_tokens_length if concate_question else 0)
+            )
+        else:
+            rate = target_token / sum(context_tokens_length)
+        (
+            context,
+            context_segs,
+            context_segs_rate,
+            context_segs_compress,
+        ) = self.segment_structured_context(context, rate)
+        return self.compress_prompt(
+            context,
+            instruction,
+            question,
+            rate,
+            target_token,
+            iterative_size,
+            force_context_ids,
+            force_context_number,
+            use_sentence_level_filter,
+            use_context_level_filter,
+            use_token_level_filter,
+            keep_split,
+            keep_first_sentence,
+            keep_last_sentence,
+            keep_sentence_number,
+            high_priority_bonus,
+            context_budget,
+            token_budget_ratio,
+            condition_in_question,
+            reorder_context,
+            dynamic_context_compression_ratio,
+            condition_compare,
+            add_instruction,
+            rank_method,
+            concate_question,
+            context_segs=context_segs,
+            context_segs_rate=context_segs_rate,
+            context_segs_compress=context_segs_compress,
+        )
+
+    def compress_prompt(
+        self,
+        context: List[str],
+        instruction: str = "",
+        question: str = "",
+        rate: float = 0.5,
+        target_token: float = -1,
+        iterative_size: int = 200,
+        force_context_ids: List[int] = None,
+        force_context_number: int = None,
+        use_sentence_level_filter: bool = False,
+        use_context_level_filter: bool = True,
+        use_token_level_filter: bool = True,
+        keep_split: bool = False,
+        keep_first_sentence: int = 0,
+        keep_last_sentence: int = 0,
+        keep_sentence_number: int = 0,
+        high_priority_bonus: int = 100,
+        context_budget: str = "+100",
+        token_budget_ratio: float = 1.4,
+        condition_in_question: str = "none",
+        reorder_context: str = "original",
+        dynamic_context_compression_ratio: float = 0.0,
+        condition_compare: bool = False,
+        add_instruction: bool = False,
+        rank_method: str = "llmlingua",
+        concate_question: bool = True,
+        context_segs: List[str] = None,
+        context_segs_rate: List[float] = None,
+        context_segs_compress: List[bool] = None,
+        target_context: int = -1,
+        context_level_rate: float = 1.0,
+        context_level_target_token: int = -1,
+        return_word_label: bool = False,
+        word_sep: str = "\t\t|\t\t",
+        label_sep: str = " ",
+        token_to_word: str = "mean",
+        force_tokens: List[str] = [],
+        force_reserve_digit: bool = False,
+        drop_consecutive: bool = False,
+        chunk_end_tokens: List[str] = [".", "\n"],
+    ):
+        """
+        Compresses the given context.
+
+        Args:
+            context (List[str]): List of context strings that form the basis of the prompt.
+            instruction (str, optional): Additional instruction text to be included in the prompt. Default is an empty string.
+            question (str, optional): A specific question that the prompt is addressing. Default is an empty string.
+            rate (float, optional): The maximum compression rate target to be achieved. The compression rate is defined
+                the same as in the paper "Language Modeling Is Compression". Delétang, Grégoire, Anian Ruoss, Paul-Ambroise Duquenne,
+                Elliot Catt, Tim Genewein, Christopher Mattern, Jordi Grau-Moya et al. "Language modeling is compression."
+                arXiv preprint arXiv:2309.10668 (2023):
+                .. math::\text{Compression Rate} = \frac{\text{Compressed Size}}{\text{Raw Size}}
+                Default is 0.5. The actual compression rate is generally lower than the specified target, but there can be
+                fluctuations due to differences in tokenizers. If specified, it should be a float less than or equal
+                to 1.0, representing the target compression rate.
+            target_token (float, optional): The maximum number of tokens to be achieved. Default is -1, indicating no specific target.
+                The actual number of tokens after compression should generally be less than the specified target_token, but there can
+                be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as
+                the sole criterion, overriding the ``rate``.
+            iterative_size (int, optional): The number of tokens to consider in each iteration of compression. Default is 200.
+            force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
+            force_context_number (int, optional): The number of context sections to forcibly include. Default is None.
+            use_sentence_level_filter (bool, optional): Whether to apply sentence-level filtering in compression. Default is False.
+            use_context_level_filter (bool, optional): Whether to apply context-level filtering in compression. Default is True.
+            use_token_level_filter (bool, optional): Whether to apply token-level filtering in compression. Default is True.
+            keep_split (bool, optional): Whether to preserve the original separators without compression. Default is False.
+            keep_first_sentence (int, optional): Number of sentences to forcibly preserve from the start of the context. Default is 0.
+            keep_last_sentence (int, optional): Number of sentences to forcibly preserve from the end of the context. Default is 0.
+            keep_sentence_number (int, optional): Total number of sentences to forcibly preserve in the compression. Default is 0.
+            high_priority_bonus (int, optional): Bonus score for high-priority sentences to influence their likelihood of being retained. Default is 100.
+            context_budget (str, optional): Token budget for the context-level filtering, expressed as a string to indicate flexibility. Default is "+100".
+            token_budget_ratio (float, optional): Ratio to adjust token budget during sentence-level filtering. Default is 1.4.
+            condition_in_question (str, optional): Specific condition to apply to question in the context. Default is "none".
+            reorder_context (str, optional): Strategy for reordering context in the compressed result. Default is "original".
+            dynamic_context_compression_ratio (float, optional): Ratio for dynamically adjusting context compression. Default is 0.0.
+            condition_compare (bool, optional): Whether to enable condition comparison during token-level compression. Default is False.
+            add_instruction (bool, optional): Whether to add the instruction to the prompt prefix. Default is False.
+            rank_method (str, optional): Method used for ranking elements during compression. Default is "llmlingua".
+            concate_question (bool, optional): Whether to concatenate the question to the compressed prompt. Default is True.
+
+            target_context (int, optional): The maximum number of contexts to be achieved. Default is -1, indicating no specific target.
+            context_level_rate (float, optional): The minimum compression rate target to be achieved at the context level. Default is 1.0.
+            context_level_target_token (float, optional): The maximum number of tokens to be achieved in context-level compression.
+                Default is -1, indicating no specific target. Only used in the coarse-to-fine compression scenario.
+            force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
+            return_word_label (bool, optional): Whether to return words with their corresponding labels. Default is False.
+            word_sep (str, optional): The separator token used in fn_labeled_original_prompt to partition words. Default is "\t\t|\t\t".
+            label_sep (str, optional): The separator token used in fn_labeled_original_prompt to partition a word and its label. Default is " ".
+            token_to_word (str, optional): How to convert token probability to word probability. Default is "mean".
+            force_tokens (List[str], optional): List of specific tokens to always include in the compressed result. Default is [].
+            force_reserve_digit (bool, optional): Whether to forcibly reserve tokens containing digits (0-9). Default is False.
+            drop_consecutive (bool, optional): Whether to drop tokens that are in 'force_tokens' but appear consecutively in the
+                compressed prompt. Default is False.
+            chunk_end_tokens (List[str], optional): The early-stop tokens for segmenting chunks. Default is [".", "\n"].
+        Returns:
+            dict: A dictionary containing:
+                - "compressed_prompt" (str): The resulting compressed prompt.
+                - "compressed_prompt_list" (List[str]): List of the resulting compressed prompt. Only used in llmlingua2.
+                - "fn_labeled_original_prompt" (str): original words along with their labels
+                    indicating whether to reserve in compressed prompt, in the format (word label_sep label)
+                    Only used in llmlingua2 when return_word_label = True.
+                - "origin_tokens" (int): The original number of tokens in the input.
+                - "compressed_tokens" (int): The number of tokens in the compressed output.
+                - "ratio" (str): The compression ratio achieved, calculated as the original token number divided by the token number after compression.
+                - "rate" (str): The compression rate achieved, in a human-readable format.
+                - "saving" (str): Estimated savings in GPT-4 token usage.
+        """
+        if self.use_llmlingua2:
+            return self.compress_prompt_llmlingua2(
+                context,
+                rate=rate,
+                target_token=target_token,
+                use_context_level_filter=use_context_level_filter,
+                use_token_level_filter=use_token_level_filter,
+                target_context=target_context,
+                context_level_rate=context_level_rate,
+                context_level_target_token=context_level_target_token,
+                force_context_ids=force_context_ids,
+                return_word_label=return_word_label,
+                word_sep=word_sep,
+                label_sep=label_sep,
+                token_to_word=token_to_word,
+                force_tokens=force_tokens,
+                force_reserve_digit=force_reserve_digit,
+                drop_consecutive=drop_consecutive,
+                chunk_end_tokens=chunk_end_tokens,
+            )
+        assert (
+            rate <= 1.0
+        ), "Error: 'rate' must not exceed 1.0. The value of 'rate' indicates compression rate and must be within the range [0, 1]."
+
+        if not context:
+            context = [" "]
+        if isinstance(context, str):
+            context = [context]
+        assert not (
+            rank_method == "longllmlingua" and not question
+        ), "In the LongLLMLingua, it is necessary to set a question."
+        if condition_compare and "_condition" not in condition_in_question:
+            condition_in_question += "_condition"
+        if rank_method == "longllmlingua":
+            if condition_in_question == "none":
+                condition_in_question = "after"
+        elif rank_method == "llmlingua":
+            condition_in_question = (
+                "none"
+                if "_condition" not in condition_in_question
+                else "none_condition"
+            )
+        origin_tokens = len(
+            self.oai_tokenizer.encode(
+                "\n\n".join([instruction] + context + [question]).strip()
+            )
+        )
+        context_tokens_length = [self.get_token_length(c) for c in context]
+        instruction_tokens_length, question_tokens_length = self.get_token_length(
+            instruction
+        ), self.get_token_length(question)
+        if target_token == -1:
+            target_token = (
+                (
+                    instruction_tokens_length
+                    + question_tokens_length
+                    + sum(context_tokens_length)
+                )
+                * rate
+                - instruction_tokens_length
+                - (question_tokens_length if concate_question else 0)
+            )
+        condition_flag = "_condition" in condition_in_question
+        condition_in_question = condition_in_question.replace("_condition", "")
+
+        if len(context) > 1 and use_context_level_filter:
+            context, dynamic_ratio, context_used = self.control_context_budget(
+                context,
+                context_tokens_length,
+                target_token,
+                force_context_ids,
+                force_context_number,
+                question,
+                condition_in_question,
+                reorder_context=reorder_context,
+                dynamic_context_compression_ratio=dynamic_context_compression_ratio,
+                rank_method=rank_method,
+                context_budget=context_budget,
+                context_segs=context_segs,
+                context_segs_rate=context_segs_rate,
+                context_segs_compress=context_segs_compress,
+            )
+            if context_segs is not None:
+                context_segs = [context_segs[idx] for idx in context_used]
+                context_segs_rate = [context_segs_rate[idx] for idx in context_used]
+                context_segs_compress = [
+                    context_segs_compress[idx] for idx in context_used
+                ]
+        else:
+            dynamic_ratio = [0.0] * len(context)
+
+        segments_info = []
+        if use_sentence_level_filter:
+            context, segments_info = self.control_sentence_budget(
+                context,
+                target_token,
+                keep_first_sentence=keep_first_sentence,
+                keep_last_sentence=keep_last_sentence,
+                keep_sentence_number=keep_sentence_number,
+                high_priority_bonus=high_priority_bonus,
+                token_budget_ratio=token_budget_ratio,
+                question=question,
+                condition_in_question=condition_in_question,
+                rank_method=rank_method,
+                context_segs=context_segs,
+                context_segs_rate=context_segs_rate,
+                context_segs_compress=context_segs_compress,
+            )
+        elif context_segs is not None:
+            for context_idx in range(len(context)):
+                segments_info.append(
+                    [
+                        (len(seg_text), seg_rate, seg_compress)
+                        for seg_text, seg_rate, seg_compress in zip(
+                            context_segs[context_idx],
+                            context_segs_rate[context_idx],
+                            context_segs_compress[context_idx],
+                        )
+                    ]
+                )
+        segments_info = [
+            self.concate_segment_info(segment_info) for segment_info in segments_info
+        ]
+
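+        # For conditional (question-aware) compression, prepend the question (plus the instruction if
+        # add_instruction) as a prefix. If the prefix would not fit alongside two iterative windows,
+        # keep its first `prefix_bos_num` tokens and its tail so it stays within the context window.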
+        if condition_flag:
+            prefix = question + "\n\n" + instruction if add_instruction else question
+            if (
+                self.get_token_length(prefix + "\n\n") + iterative_size * 2
+                > self.max_position_embeddings
+            ):
+                tokens = self.tokenizer(prefix, add_special_tokens=False).input_ids
+                prefix = self.tokenizer.decode(
+                    tokens[: self.prefix_bos_num]
+                    + tokens[
+                        len(tokens)
+                        - self.max_position_embeddings
+                        + 2
+                        + self.prefix_bos_num
+                        + 2 * iterative_size :
+                    ]
+                )
+            start = self.get_prefix_length(prefix + "\n\n", context[0])
+            context = [prefix] + context
+        else:
+            start = 0
+
+        if use_token_level_filter:
+            context = self.iterative_compress_prompt(
+                context,
+                target_token,
+                iterative_size=iterative_size,
+                keep_split=keep_split,
+                start=start,
+                dynamic_ratio=dynamic_ratio,
+                condition_compare=condition_compare,
+                segments_info=segments_info,
+            )
+            compressed_prompt = (
+                self.tokenizer.batch_decode(context[0])[0]
+                .replace("<s> ", "")
+                .replace("<s>", "")
+            )
+        else:
+            if condition_flag:
+                context = context[1:]
+            compressed_prompt = "\n\n".join(context)
+
+        res = []
+        if instruction:
+            res.append(instruction)
+        if compressed_prompt.strip():
+            res.append(compressed_prompt)
+        if question and concate_question:
+            res.append(question)
+
+        compressed_prompt = "\n\n".join(res)
+
+        compressed_tokens = len(self.oai_tokenizer.encode(compressed_prompt))
+        saving = (origin_tokens - compressed_tokens) * 0.06 / 1000
+        ratio = 1 if compressed_tokens == 0 else origin_tokens / compressed_tokens
+        rate = 1 / ratio
+        return {
+            "compressed_prompt": compressed_prompt,
+            "origin_tokens": origin_tokens,
+            "compressed_tokens": compressed_tokens,
+            "ratio": f"{ratio:.1f}x",
+            "rate": f"{rate * 100:.1f}%",
+            "saving": f", Saving ${saving:.1f} in GPT-4.",
+        }
+
+    def compress_prompt_llmlingua2(
+        self,
+        context: List[str],
+        rate: float = 0.5,
+        target_token: int = -1,
+        use_context_level_filter: bool = False,
+        use_token_level_filter: bool = True,
+        target_context: int = -1,
+        context_level_rate: float = 1.0,
+        context_level_target_token: int = -1,
+        force_context_ids: List[int] = [],
+        return_word_label: bool = False,
+        word_sep: str = "\t\t|\t\t",
+        label_sep: str = " ",
+        token_to_word: str = "mean",
+        force_tokens: List[str] = [],
+        force_reserve_digit: bool = False,
+        drop_consecutive: bool = False,
+        chunk_end_tokens: List[str] = [".", "\n"],
+    ):
+        """
+        Compresses the given context.
+
+        Args:
+            context (List[str]): List of context strings that form the basis of the prompt.
+            rate (float, optional): The minimum compression rate target to be achieved. Default is 0.5. The actual compression rate
+                generally exceeds the specified target, but there can be fluctuations due to differences in tokenizers. If specified,
+                it should be a float less than or equal to 1.0, representing the target compression rate.
+            target_token (int, optional): The maximum number of tokens to be achieved. Default is -1, indicating no specific target.
+                The actual number of tokens after compression should generally be less than the specified target_token, but there can
+                be fluctuations due to differences in tokenizers. If specified, compression will be based on the target_token as
+                the sole criterion, overriding the rate.
+            target_context (int, optional): The maximum number of contexts to be achieved. Default is -1, indicating no specific target.
+                Only used in the coarse-to-fine compression.
+            context_level_rate (float, optional): The minimum compression rate target to be achieved at the context level. Default is 1.0.
+                Only used in the coarse-to-fine compression.
+            context_level_target_token (float, optional): The maximum number of tokens to be achieved in context-level compression.
+                Default is -1, indicating no specific target. Only used in the coarse-to-fine compression scenario.
+            force_context_ids (List[int], optional): List of specific context IDs to always include in the compressed result. Default is None.
+            return_word_label (bool, optional): Whether to return words with their corresponding labels. Default is False.
+            word_sep (str, optional): The separator token used in fn_labeled_original_prompt to partition words. Default is "\t\t|\t\t".
+            label_sep (str, optional): The separator token used in fn_labeled_original_prompt to partition a word and its label. Default is " ".
+            token_to_word (str, optional): How to convert token probability to word probability. Default is "mean".
+            force_tokens (List[str], optional): List of specific tokens to always include in the compressed result. Default is [].
+            force_reserve_digit (bool, optional): Whether to forcibly reserve tokens containing digits (0-9). Default is False.
+            drop_consecutive (bool, optional): Whether to drop tokens that are in 'force_tokens' but appear consecutively in the
+                compressed prompt. Default is False.
+            chunk_end_tokens (List[str], optional): The early-stop tokens for segmenting chunks. Default is [".", "\n"].
+        Returns:
+            dict: A dictionary containing:
+                - "compressed_prompt" (str): The resulting compressed prompt.
+                - "compressed_prompt_list" (List[str]): List of the resulting compressed prompt.
+                - "fn_labeled_original_prompt" (str): original words along with their labels
+                    indicating whether to reserve in compressed prompt, in the format (word label_sep label)
+                - "origin_tokens" (int): The original number of tokens in the input.
+                - "compressed_tokens" (int): The number of tokens in the compressed output.
+                - "ratio" (str): The compression ratio achieved, in a human-readable format.
+                - "rate" (str): The compression rate achieved, in a human-readable format.
+                - "saving" (str): Estimated savings in GPT-4 token usage.
+
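+        Example (illustrative; assumes ``compressor`` is a PromptCompressor built with a
+        llmlingua-2 checkpoint, and the exact output depends on that model):
+            >>> res = compressor.compress_prompt_llmlingua2(
+            ...     ["Item 1: apples cost $2.", "Item 2: oranges cost $3."],
+            ...     rate=0.5,
+            ...     force_tokens=["$"],
+            ...     force_reserve_digit=True,
+            ... )
+            >>> print(res["compressed_tokens"], res["rate"])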
+        """
+        assert len(force_tokens) <= self.max_force_token
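+        # Force tokens that the tokenizer splits into several pieces are swapped for single
+        # "[NEW{i}]" placeholder tokens here and mapped back to the original strings after
+        # compression (see replace_added_token in utils).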
+        token_map = {}
+        for i, t in enumerate(force_tokens):
+            if len(self.tokenizer.tokenize(t)) != 1:
+                token_map[t] = self.added_tokens[i]
+        chunk_end_tokens = copy.deepcopy(chunk_end_tokens)
+        for c in chunk_end_tokens:
+            if c in token_map:
+                chunk_end_tokens.append(token_map[c])
+        chunk_end_tokens = set(chunk_end_tokens)
+
+        if type(context) == str:
+            context = [context]
+        context = copy.deepcopy(context)
+
+        if len(context) == 1 and use_context_level_filter:
+            use_context_level_filter = False
+
+        n_original_token = 0
+        context_chunked = []
+        for i in range(len(context)):
+            n_original_token += self.get_token_length(context[i], use_oai_tokenizer=True)
+            for ori_token, new_token in token_map.items():
+                context[i] = context[i].replace(ori_token, new_token)
+            context_chunked.append(self.__chunk_context(context[i], chunk_end_tokens=chunk_end_tokens))
+
+        if use_context_level_filter:
+            # use_context_level_filter is requested but no context-level parameters were given:
+            # fall back to context_level_rate = (rate + 1.0) / 2 when rate is specified, or
+            # context_level_target_token = target_token * 2 when target_token is specified.
+            if (
+                target_context <= 0
+                and context_level_rate >= 1.0
+                and context_level_target_token <= 0
+            ):
+                if target_token < 0 and rate < 1.0:
+                    context_level_rate = (
+                        (rate + 1.0) / 2 if use_token_level_filter else rate
+                    )
+                    print(
+                        f"set context level compression rate to {context_level_rate}."
+                    )
+                if target_token >= 0:
+                    context_level_target_token = (
+                        target_token * 2 if use_token_level_filter else target_token
+                    )
+                    print(
+                        f"set context level target token to {context_level_target_token}."
+                    )
+
+            if target_context >= 0:
+                context_level_rate = min(target_context / len(context), 1.0)
+                # print(f'override context level compression rate to {context_level_rate} because you specified target_context = {target_context}.')
+            if context_level_target_token >= 0:
+                context_level_rate = min(
+                    context_level_target_token / n_original_token, 1.0
+                )
+                # print(f'override context level compression rate to {context_level_rate} because you specified context_level_target_token = {context_level_target_token}.')
+
+            context_probs, context_words = self.__get_context_prob(
+                context_chunked, 
+                token_to_word=token_to_word, 
+                force_tokens=force_tokens,
+                token_map=token_map, 
+                force_reserve_digit=force_reserve_digit, 
+            )
+
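+            # Keep the contexts whose score is at or above the (1 - context_level_rate) percentile,
+            # i.e. roughly the top context_level_rate fraction, plus any context in force_context_ids.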
+            threshold = np.percentile(
+                context_probs, int(100 * (1 - context_level_rate))
+            )
+
+            reserved_context = []
+            context_label = [False] * len(context_probs)
+            for i, p in enumerate(context_probs):
+                if p >= threshold or (
+                    force_context_ids is not None and i in force_context_ids
+                ):
+                    reserved_context.append(context_chunked[i])
+                    context_label[i] = True
+            n_reserved_token = 0
+            for chunks in reserved_context:
+                for c in chunks:
+                    n_reserved_token += self.get_token_length(c, use_oai_tokenizer=True)
+            if target_token >= 0:
+                rate = min(target_token / n_reserved_token, 1.0)
+                print(
+                    f"override compression rate to {rate} because you specified target_token = {target_token}."
+                )
+
+            if use_token_level_filter:
+                compressed_context, word_list, word_label_list = self.__compress(
+                    reserved_context,
+                    reduce_rate=max(0, 1 - rate),
+                    token_to_word=token_to_word,
+                    force_tokens=force_tokens,
+                    token_map=token_map, 
+                    force_reserve_digit=force_reserve_digit, 
+                    drop_consecutive=drop_consecutive,
+                )
+            else:
+                compressed_context, word_list, word_label_list = self.__compress(
+                    reserved_context, 
+                    reduce_rate=0, 
+                    token_to_word=token_to_word,
+                    force_tokens=force_tokens,
+                    token_map=token_map, 
+                    force_reserve_digit=force_reserve_digit, 
+                    drop_consecutive=drop_consecutive,
+                )
+                print(
+                    "return the original text because you specify use_token_level_filter=False"
+                )
+
+            n_compressed_token = 0
+            for c in compressed_context:
+                n_compressed_token += self.get_token_length(c, use_oai_tokenizer=True)
+            saving = (n_original_token - n_compressed_token) * 0.06 / 1000
+            ratio = (
+                1 if n_compressed_token == 0 else n_original_token / n_compressed_token
+            )
+            res = {
+                "compressed_prompt": "\n\n".join(compressed_context),
+                "compressed_prompt_list": compressed_context,
+                "origin_tokens": n_original_token,
+                "compressed_tokens": n_compressed_token,
+                "ratio": f"{ratio:.1f}x",
+                "rate": f"{1 / ratio * 100:.1f}%",
+                "saving": f", Saving ${saving:.1f} in GPT-4.",
+            }
+            if return_word_label:
+                words = []
+                labels = []
+                j = 0
+                for i in range(len(context)):
+                    if context_label[i]:
+                        words.extend(word_list[j])
+                        labels.extend(word_label_list[j])
+                        j += 1
+                    else:
+                        words.extend(context_words[i])
+                        labels.extend([0] * len(context_words[i]))
+                word_label_lines = word_sep.join(
+                    [f"{word}{label_sep}{label}" for word, label in zip(words, labels)]
+                )
+                res["fn_labeled_original_prompt"] = word_label_lines
+            return res
+
+        if target_token > 0:
+            rate = min(target_token / n_original_token, 1.0)
+            print(
+                f"override compression rate to {rate} \
+                  because you specified target_token = {target_token}."
+            )
+
+        if use_token_level_filter:
+            compressed_context, word_list, word_label_list = self.__compress(
+                context_chunked,
+                reduce_rate=max(0, 1 - rate),
+                token_to_word=token_to_word,
+                force_tokens=force_tokens,
+                token_map=token_map, 
+                force_reserve_digit=force_reserve_digit, 
+                drop_consecutive=drop_consecutive,
+            )
+        else:
+            compressed_context, word_list, word_label_list = self.__compress(
+                context_chunked, 
+                reduce_rate=0, 
+                token_to_word=token_to_word,
+                force_tokens=force_tokens,
+                token_map=token_map, 
+                force_reserve_digit=force_reserve_digit, 
+                drop_consecutive=drop_consecutive,
+            )
+            print(
+                "return the original text because you specify use_token_level_filter=False"
+            )
+
+        n_compressed_token = 0
+        for c in compressed_context:
+            n_compressed_token += self.get_token_length(c, use_oai_tokenizer=True)
+        saving = (n_original_token - n_compressed_token) * 0.06 / 1000
+        ratio = 1 if n_compressed_token == 0 else n_original_token / n_compressed_token
+        res = {
+            "compressed_prompt": "\n\n".join(compressed_context),
+            "compressed_prompt_list": compressed_context,
+            "origin_tokens": n_original_token,
+            "compressed_tokens": n_compressed_token,
+            "ratio": f"{ratio:.1f}x",
+            "rate": f"{1 / ratio * 100:.1f}%",
+            "saving": f", Saving ${saving:.1f} in GPT-4.",
+        }
+        if return_word_label:
+            words = []
+            labels = []
+            for w_list, l_list in zip(word_list, word_label_list):
+                words.extend(w_list)
+                labels.extend(l_list)
+
+            # new_words = []
+            # new_labels = []
+            # for i in range(len(words)):
+            #     word, label = words[i], labels[i]
+            #     if word in string.punctuation:
+            #         if labels[i-1] == 1 and label == 1 and i > 0:
+            #             new_words[-1] += word
+            #     else:
+            #         new_words.append(word)
+            #         new_labels.append(label)
+            # word_label_lines = word_sep.join([f'{word}{label_sep}{label}' for word, label in zip(new_words, new_labels)])
+
+            word_label_lines = word_sep.join(
+                [f"{word}{label_sep}{label}" for word, label in zip(words, labels)]
+            )
+            res["fn_labeled_original_prompt"] = word_label_lines
+        return res
+
+    def get_token_length(self, text: str, add_special_tokens: bool = True, use_oai_tokenizer: bool = False):
+        if use_oai_tokenizer:
+            return len(self.oai_tokenizer.encode(text))
+        else:
+            return len(
+                self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids
+            )
+
+    def get_prefix_length(self, prefix: str, text: str):
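+        # Tokenize prefix + (the first 100 chars of) text jointly, then search for the token index
+        # at which the decoded prefix ends, so callers can skip exactly the prefix tokens.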
+        possible_prefix_token = max(self.get_token_length(prefix, False) - 3, 1)
+        full_input_ids = self.tokenizer(
+            prefix + text[:100], add_special_tokens=False
+        ).input_ids
+        for i in range(possible_prefix_token, len(full_input_ids)):
+            cur_prefix = self.tokenizer.decode(full_input_ids[:i])
+            if cur_prefix == prefix:
+                break
+        assert self.tokenizer.decode(full_input_ids[i:]) == text[:100]
+        return i
+
+    def get_condition_ppl(
+        self,
+        text: str,
+        question: str,
+        condition_in_question: str = "none",
+        granularity: str = "sentence",
+    ):
+        if condition_in_question == "none":
+            return self.get_ppl(text, granularity=granularity)
+        elif condition_in_question == "before":
+            return self.get_ppl(
+                question + text,
+                granularity=granularity,
+                condition_mode="after",
+                condition_pos_id=self.get_token_length(question) - 1,
+            )
+        elif condition_in_question == "after":
+            return self.get_ppl(
+                text + question,
+                granularity=granularity,
+                condition_mode="after",
+                condition_pos_id=self.get_token_length(text) - 1,
+            )
+
+    def get_dynamic_compression_ratio(
+        self,
+        context: list,
+        target_token: float,
+        iterative_size: int,
+        dynamic_ratio: list,
+        start: int,
+        seg_info: List[List[tuple]] = None,
+    ):
+        def get_ratio(base: float, delta: float):
+            return max(min(1, base + delta), 0)
+
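+        # Split the concatenated contexts into iterative_size-token windows; each window becomes a
+        # list of (token_count, keep_ratio) pairs, where keep_ratio is tau (budget / total length)
+        # shifted by the per-context dynamic_ratio.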
+        context_length = [self.get_token_length(ii, False) + 2 for ii in context]
+        if start:
+            context_length = context_length[1:]
+        tau = target_token / (sum(context_length) + 1)
+        res, idx, last, last_target = [], 0, 1, []
+        while idx < len(context_length):
+            if last + context_length[idx] >= iterative_size:
+                last_target.append(
+                    (iterative_size - last, get_ratio(tau, dynamic_ratio[idx]))
+                )
+                res.append(last_target)
+                last = last + context_length[idx] - iterative_size
+                if last > iterative_size:
+                    k = last // iterative_size
+                    res.extend(
+                        [[(iterative_size, get_ratio(tau, dynamic_ratio[idx]))]] * k
+                    )
+                    last -= k * iterative_size
+
+                last_target = (
+                    [(last, get_ratio(tau, dynamic_ratio[idx]))] if last else []
+                )
+            else:
+                last += context_length[idx]
+                last_target.append(
+                    (context_length[idx], get_ratio(tau, dynamic_ratio[idx]))
+                )
+            idx += 1
+        if last_target:
+            res.append(last_target)
+        return res
+
+    def get_structured_dynamic_compression_ratio(
+        self,
+        context: list,
+        iterative_size: int,
+        dynamic_ratio: list,
+        start: int,
+        seg_info: List[List[tuple]] = None,
+    ):
+        if start:
+            pure_context = context[1:]
+        else:
+            pure_context = context
+        global_dynamic_rate, global_dynamic_compress, segments = [], [], []
+        for context_idx, text in enumerate(pure_context):
+            text_seen = 0
+            for seg_idx, (seg_len, seg_rate, seg_compress) in enumerate(
+                seg_info[context_idx]
+            ):
+                seg_text = text[text_seen : text_seen + seg_len]
+                if (
+                    seg_idx == len(seg_info[context_idx]) - 1
+                    and context_idx != len(pure_context) - 1
+                ):
+                    seg_text += "\n\n"
+                segments.append(seg_text)
+                if seg_compress:
+                    global_dynamic_rate.append(seg_rate)
+                else:
+                    global_dynamic_rate.append(1.0)
+                global_dynamic_compress.append(seg_compress)
+                text_seen += seg_len
+        origin_text = "\n\n".join(pure_context)
+        assert len("".join(segments)) == len(origin_text)
+        assert len(segments) == len(global_dynamic_rate) == len(global_dynamic_compress)
+
+        text_input_ids = self.tokenizer(
+            "\n\n".join(context), add_special_tokens=False
+        ).input_ids[start:]
+        assert self.tokenizer.decode(text_input_ids) == origin_text
+        dynamic_compression_ratio = self.token_segment(
+            text_input_ids,
+            iterative_size,
+            segments,
+            global_dynamic_rate,
+            global_dynamic_compress,
+        )
+        return dynamic_compression_ratio
+
+    def token_segment(
+        self,
+        text_input_ids: List[int],
+        iterative_size: int,
+        segments: List[str],
+        global_dynamic_rate: List[float],
+        global_dynamic_compress: List[bool],
+    ):
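+        # Walk the token ids in parallel with the character-level `segments`, assigning each run of
+        # tokens the rate of the segment(s) it covers. Ids are decoded over a small sliding window
+        # (decode_window) so subword pieces line up with character offsets.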
+        decode_window = 3
+        seg_idx, seg_seen, token_seen_num, last_rate = 0, 0, 0, -1
+        dynamic_compression_rate, local_compresssion_rate = [], []
+        for i in range(len(text_input_ids)):
+            if i < decode_window:
+                id_pre, id_cur = text_input_ids[:i], text_input_ids[: i + 1]
+            else:
+                id_pre, id_cur = (
+                    text_input_ids[i - decode_window + 1 : i],
+                    text_input_ids[i - decode_window + 1 : i + 1],
+                )
+            cur_word = self.tokenizer.decode(id_cur)[
+                len(self.tokenizer.decode(id_pre)) :
+            ]
+            cur_word_len = len(cur_word)
+            if cur_word_len and cur_word_len >= len(segments[seg_idx]) - seg_seen:
+                possible_rate, possible_compress = [], []
+                while (
+                    cur_word_len and cur_word_len >= len(segments[seg_idx]) - seg_seen
+                ):
+                    possible_rate.append(global_dynamic_rate[seg_idx])
+                    possible_compress.append(global_dynamic_compress[seg_idx])
+                    cur_word_len -= len(segments[seg_idx]) - seg_seen
+                    seg_idx += 1
+                    seg_seen = 0
+                if cur_word_len:
+                    possible_rate.append(global_dynamic_rate[seg_idx])
+                    possible_compress.append(global_dynamic_compress[seg_idx])
+                new_rate = 1.0 if False in possible_compress else min(possible_rate)
+            else:
+                new_rate = global_dynamic_rate[seg_idx]
+            if new_rate != last_rate and i - token_seen_num:
+                local_compression_rate.append((i - token_seen_num, last_rate))
+                token_seen_num = i
+            last_rate = new_rate
+            seg_seen += cur_word_len
+            if (i + 1) % iterative_size == 0:
+                if token_seen_num != i + 1:
+                    local_compression_rate.append((i + 1 - token_seen_num, last_rate))
+                    token_seen_num = i + 1
+                dynamic_compression_rate.append(local_compression_rate[:])
+                local_compression_rate = []
+        if token_seen_num != len(text_input_ids):
+            local_compression_rate.append(
+                (len(text_input_ids) - token_seen_num, last_rate)
+            )
+        if local_compression_rate:
+            dynamic_compression_rate.append(local_compression_rate[:])
+        return dynamic_compression_rate
+
+    def control_context_budget(
+        self,
+        context: List[str],
+        context_tokens_length: List[int],
+        target_token: float,
+        force_context_ids: List[int] = None,
+        force_context_number: int = None,
+        question: str = "",
+        condition_in_question: str = "none",
+        reorder_context: str = "original",
+        dynamic_context_compression_ratio: float = 0.0,
+        rank_method: str = "longllmlingua",
+        context_budget: str = "+100",
+        context_segs: List[List[str]] = None,
+        context_segs_rate: List[List[float]] = None,
+        context_segs_compress: List[List[bool]] = None,
+    ):
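+        """Select the contexts (demonstrations) that fit the token budget.
+
+        Contexts are ranked with rank_method, then added greedily until the
+        budget, i.e. target_token adjusted by the context_budget expression
+        (e.g. "+100"), runs out or force_context_number is reached. Contexts in
+        force_context_ids or containing segments marked compress=False are
+        always kept. Returns the selected contexts, their dynamic compression
+        ratios, and the selected indices.
+        """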
+        demonstrations_sort = self.get_rank_results(
+            context,
+            question,
+            rank_method,
+            condition_in_question,
+            context_tokens_length,
+        )
+
+        if target_token < 0:
+            target_token = 100
+        target_token = eval("target_token" + context_budget)
+        res = []
+        used = force_context_ids if force_context_ids is not None else []
+        if context_segs is not None:
+            for idx, _ in enumerate(context):
+                if False in context_segs_compress[idx]:
+                    used.append(idx)
+
+        self.context_idxs.append([x for x, _ in demonstrations_sort])
+        for idx, _ in demonstrations_sort:
+            if idx >= len(context_tokens_length):
+                continue
+            target_token -= context_tokens_length[idx]
+            if idx not in used:
+                used.append(idx)
+            if target_token < 0 or (
+                force_context_number is not None and len(res) >= force_context_number
+            ):
+                break
+        original_used = used
+        if reorder_context == "original":
+            used = sorted(used)
+        elif reorder_context == "two_stage":
+            l, r = [_ for idx, _ in enumerate(used) if idx % 2 == 0], [
+                _ for idx, _ in enumerate(used) if idx % 2 == 1
+            ]
+            used = l + r[::-1]
+
+        if dynamic_context_compression_ratio > 0:
+            N = len(used)
+            dynamic_ratio = [
+                i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0
+                for i in range(-(N - 1), N, 2)
+            ][::-1]
+            dynamic_ratio_map = {i: j for i, j in zip(original_used, dynamic_ratio)}
+            dynamic_ratio = [dynamic_ratio_map[i] for i in used]
+        else:
+            dynamic_ratio = [0.0] * len(used)
+
+        res = [context[idx] for idx in used if idx < len(context)]
+        return res, dynamic_ratio, used
+
+    def control_sentence_budget(
+        self,
+        context: List[str],
+        target_token: float,
+        keep_first_sentence: int = 0,
+        keep_last_sentence: int = 0,
+        keep_sentence_number: int = 0,
+        high_priority_bonus: int = 100,
+        token_budget_ratio: float = 1.4,
+        question: str = "",
+        condition_in_question: str = "none",
+        rank_method: str = "longllmlingua",
+        context_segs: List[List[str]] = None,
+        context_segs_rate: List[List[float]] = None,
+        context_segs_compress: List[List[bool]] = None,
+    ):
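+        """Drop whole sentences until the sentence-level token budget is met.
+
+        Contexts are split into sentences with nltk and ranked either by
+        question-conditioned perplexity ("longllmlingua") or via
+        get_rank_results; sentences are kept greedily until
+        target_token * token_budget_ratio is exhausted. Sentences selected by
+        keep_first_sentence / keep_last_sentence / keep_sentence_number get
+        high_priority_bonus added to their score, and sentences containing
+        segments marked compress=False are always kept. Returns the pruned
+        contexts and the per-context segment information (empty when no
+        structured segments are provided).
+        """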
+        def keep_sentence(dem_idx: int, sent_keep: int):
+            idxs = sorted(dem_g[dem_idx], key=lambda x: sentence_ppl[x])[:sent_keep]
+            for idx in idxs:
+                sentence_ppl[idx] += high_priority_bonus
+
+        def sync_sentence(segments, text):
+            seg_num = len(segments)
+            new_segments = []
+            text_seen = 0
+            seg_idx, cur_seg_seen = 0, 0
+            for i, s in enumerate(text):
+                while seg_idx < seg_num and s != segments[seg_idx][cur_seg_seen]:
+                    if cur_seg_seen < len(segments[seg_idx]) - 1:
+                        cur_seg_seen += 1
+                        continue
+                    new_segments.append(text[text_seen:i])
+                    text_seen = i
+                    seg_idx += 1
+                    cur_seg_seen = 0
+                cur_seg_seen += 1
+                if seg_idx == seg_num:
+                    break
+                if cur_seg_seen == len(segments[seg_idx]):
+                    new_segments.append(text[text_seen : i + 1])
+                    text_seen = i + 1
+                    seg_idx += 1
+                    cur_seg_seen = 0
+            if text_seen < len(text):
+                new_segments.append(text[text_seen:])
+            assert len("".join(new_segments)) == len(text)
+            return new_segments
+
+        sentences = [nltk.sent_tokenize(c) for c in context]
+        dem_g, s2de, idx = defaultdict(set), defaultdict(int), 0
+        for idx_d, s in enumerate(sentences):
+            for _ in s:
+                dem_g[idx_d].add(idx)
+                s2de[idx] = idx_d
+                idx += 1
+
+        if context_segs is not None:
+            context_segs = [
+                sync_sentence(s, "".join(c)) for s, c in zip(context_segs, sentences)
+            ]
+            sen2seg_ratio = {}
+            idx = 0
+            for idx_d, sentences_each_context in enumerate(sentences):
+                segments_length = [len(s) for s in context_segs[idx_d]]
+                seg_idx, cur_seg_seen = 0, 0
+                for sentence in sentences_each_context:
+                    sentence_seg_ratio = []
+                    remain = len(sentence)
+                    while remain:
+                        if segments_length[seg_idx] - cur_seg_seen <= remain:
+                            new_seg_len = segments_length[seg_idx] - cur_seg_seen
+                            sentence_seg_ratio.append(
+                                (
+                                    new_seg_len,
+                                    context_segs_rate[idx_d][seg_idx],
+                                    context_segs_compress[idx_d][seg_idx],
+                                )
+                            )
+                            seg_idx += 1
+                            cur_seg_seen = 0
+                            remain -= new_seg_len
+                        else:
+                            sentence_seg_ratio.append(
+                                (
+                                    remain,
+                                    context_segs_rate[idx_d][seg_idx],
+                                    context_segs_compress[idx_d][seg_idx],
+                                )
+                            )
+                            cur_seg_seen += remain
+                            remain = 0
+                    sen2seg_ratio[idx] = sentence_seg_ratio
+                    idx += 1
+
+        context_sentences = [s for ii in sentences for s in ii]
+        sentence_tokens_length = [
+            self.get_token_length(sentence) for sentence in context_sentences
+        ]
+        N = len(context_sentences)
+        flags = list(range(len(context_sentences)))
+        if len(sentence_tokens_length) == 1:
+            return context, []
+        if rank_method == "longllmlingua":
+            sentence_ppl = [
+                self.get_condition_ppl(sentence, question, condition_in_question)
+                .cpu()
+                .numpy()
+                .item()
+                for sentence in context_sentences
+            ]
+            if keep_first_sentence:
+                sentence_ppl[:keep_first_sentence] = [
+                    ii + high_priority_bonus
+                    for ii in sentence_ppl[:keep_first_sentence]
+                ]
+            if keep_last_sentence:
+                sentence_ppl[-keep_last_sentence:] = [
+                    ii + high_priority_bonus
+                    for ii in sentence_ppl[-keep_last_sentence:]
+                ]
+            if keep_sentence_number:
+                for dem_idx in range(len(sentences)):
+                    keep_sentence(dem_idx, keep_sentence_number)
+            sort_direct = -1 if condition_in_question == "none" else 1
+            sent_sort = sorted(
+                enumerate(sentence_ppl), key=lambda x: sort_direct * x[1]
+            )
+        else:
+            sent_sort = self.get_rank_results(
+                context_sentences,
+                question,
+                rank_method,
+                condition_in_question,
+                [0] * len(context_sentences),
+            )
+
+        sentence_flags = [False] * N
+        if target_token < 0:
+            target_token = 100
+        target_token *= token_budget_ratio
+        res = []
+        for idx, _ in sent_sort:
+            idx = flags[idx]
+            target_token -= sentence_tokens_length[idx]
+            sentence_flags[idx] = True
+            if target_token < 0:
+                break
+
+        if context_segs is not None:
+            for idx in range(N):
+                preserved = [sen_seg_info[2] for sen_seg_info in sen2seg_ratio[idx]]
+                if False in preserved:
+                    sentence_flags[idx] = True
+
+        idx = 0
+        res = []
+        new_segments_info = []
+        for s in sentences:
+            tmp = [jj for ii, jj in enumerate(s) if sentence_flags[idx + ii]]
+            res.append("".join(tmp))
+            if context_segs is not None:
+                segment_ratio = []
+                for ii in range(len(s)):
+                    if sentence_flags[idx + ii]:
+                        segment_ratio.extend(sen2seg_ratio[idx + ii])
+                new_segments_info.append(segment_ratio)
+            idx += len(s)
+        if context_segs is not None:
+            new_segments_info = [
+                self.concate_segment_info(segment_info)
+                for segment_info in new_segments_info
+            ]
+        return res, new_segments_info
+
+    def get_compressed_input(
+        self,
+        loss,
+        input_ids,
+        attention_mask,
+        end=200,
+        iterative_size=200,
+        threshold=0.5,
+        keep_flag=None,
+        split_token_id: int = 13,
+        start: int = 0,
+        self_loss=None,
+        self_input_ids=None,
+        self_attention_mask=None,
+    ):
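+        """Drop low-information tokens from the current iteration window.
+
+        A boolean need_idx mask keeps tokens whose loss (or contrastive loss
+        difference, when self_loss is given) exceeds the threshold, leaves
+        everything outside [end - iterative_size, end) untouched, honors
+        keep_flag, and collapses consecutive split tokens. Returns the filtered
+        input ids / attention masks, the updated keep_flag and window end, and
+        the filtered loss tensors (plus their question-conditioned counterparts).
+        """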
+        if self_loss is not None:
+            need_idx = torch.concat(
+                [
+                    loss[:start] > 0,
+                    self_loss[: loss[start:].shape[0]] - loss[start:] > threshold,
+                    loss[:1] > 0,
+                ]
+            )
+        else:
+            need_idx = torch.concat([loss > threshold, loss[:1] > 0])
+        need_idx[end:] = 1
+        need_idx[: end - iterative_size] = 1
+        loss = loss[need_idx[:-1]]
+        if self_loss is not None:
+            if need_idx.shape[0] < self_loss.shape[0] + start + 1:
+                need_idx = torch.cat(
+                    [
+                        need_idx,
+                        torch.ones(
+                            self_loss.shape[0] - need_idx.shape[0] + start + 1,
+                            dtype=torch.bool,
+                        ).to(need_idx.device),
+                    ]
+                )
+            self_loss = self_loss[need_idx[start:-1]]
+
+        if need_idx.shape[0] < input_ids.shape[1]:
+            need_idx = torch.cat(
+                [
+                    need_idx,
+                    torch.ones(
+                        input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool
+                    ).to(need_idx.device),
+                ]
+            )
+        elif need_idx.shape[0] > input_ids.shape[1]:
+            need_idx = need_idx[: input_ids.shape[1]]
+
+        if keep_flag is not None:
+            need_idx[keep_flag == 1] = 1
+        last = -1
+        if keep_flag is not None:
+            for ii in range(max(0, end - iterative_size), end):
+                if need_idx[ii] != 1:
+                    continue
+                now = input_ids[0][ii].detach().cpu().item()
+                if (
+                    now == split_token_id
+                    and last == split_token_id
+                    and keep_flag[ii].detach().cpu().item() == 0
+                ):
+                    need_idx[ii] = 0
+                else:
+                    last = now
+        compressed_input_ids = input_ids[attention_mask == 1][need_idx].unsqueeze(0)
+        compressed_attention_mask = attention_mask[attention_mask == 1][
+            need_idx
+        ].unsqueeze(0)
+
+        if self_loss is not None:
+            self_compressed_input_ids = self_input_ids[self_attention_mask == 1][
+                need_idx[start:]
+            ].unsqueeze(0)
+            self_compressed_attention_mask = self_attention_mask[
+                self_attention_mask == 1
+            ][need_idx[start:]].unsqueeze(0)
+        else:
+            self_compressed_input_ids, self_compressed_attention_mask = None, None
+        if keep_flag is not None:
+            if len(keep_flag) > len(need_idx):
+                keep_flag = torch.cat(
+                    [
+                        keep_flag[:start],
+                        keep_flag[start : len(need_idx) + start][need_idx],
+                        keep_flag[start + len(need_idx) :],
+                    ]
+                )
+            else:
+                keep_flag = keep_flag[need_idx]
+        end -= (need_idx[:end] == 0).sum()
+        return (
+            compressed_input_ids,
+            compressed_attention_mask,
+            keep_flag,
+            end,
+            loss,
+            self_loss,
+            self_compressed_input_ids,
+            self_compressed_attention_mask,
+        )
+
+    def get_estimate_threshold_base_distribution(
+        self, ppl, ratio: float, condition_flag: bool = False
+    ):
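+        """Estimate the loss threshold that keeps roughly `ratio` of the tokens.
+
+        Padding entries (value 10000) are ignored, the remaining values are
+        sorted, and the value at the position corresponding to the requested
+        keep ratio is returned; a ratio of 1.0 keeps everything via -inf.
+
+        Example (illustrative only):
+            >>> # assuming `ppl` is a 1-D tensor of per-token losses
+            >>> # threshold = compressor.get_estimate_threshold_base_distribution(ppl, 0.5)
+        """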
+        if ratio == 1.0:
+            return float("-inf")
+        ppl = ppl[ppl != 10000]
+        target_token = max(0, min(len(ppl) - 1, int(len(ppl) * ratio) - 1))
+        return (
+            ppl.sort(descending=not condition_flag)
+            .values[target_token]
+            .detach()
+            .cpu()
+            .item()
+        )
+
+    def iterative_compress_prompt(
+        self,
+        context: List[str],
+        target_token: float,
+        iterative_size: int = 200,
+        keep_split: bool = False,
+        split_token_id: int = 13,
+        start: int = 0,
+        dynamic_ratio: list = None,
+        condition_compare: bool = False,
+        segments_info: List[List[tuple]] = None,
+    ):
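+        """Token-level compression loop over windows of iterative_size tokens.
+
+        For each window the causal LM perplexity is computed (reusing and, once
+        max_position_embeddings is exceeded, pruning the KV cache), a threshold
+        is estimated from the requested ratio, and low-information tokens are
+        removed via get_compressed_input. With condition_compare a second,
+        question-conditioned pass is kept in sync and the contrastive loss
+        difference is thresholded instead. Returns the compressed input ids and
+        attention mask, excluding the first `start` tokens.
+        """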
+        if segments_info is None or segments_info == []:
+            iterative_ratios = self.get_dynamic_compression_ratio(
+                context, target_token, iterative_size, dynamic_ratio, start
+            )
+        else:
+            iterative_ratios = self.get_structured_dynamic_compression_ratio(
+                context, iterative_size, dynamic_ratio, start, segments_info
+            )
+        context = "\n\n".join(context)
+        tokenized_text = self.tokenizer(
+            context, return_tensors="pt", add_special_tokens=False
+        )
+        input_ids = tokenized_text["input_ids"].to(self.device)
+        attention_mask = tokenized_text["attention_mask"].to(self.device)
+
+        N = (attention_mask == 1).sum()
+        compressed_input_ids, compressed_attention_mask = input_ids, attention_mask
+        if condition_compare:
+            self_input_ids, self_attention_mask = (
+                input_ids[:, start:],
+                attention_mask[:, start:],
+            )
+            self_compressed_input_ids, self_compressed_attention_mask = (
+                self_input_ids,
+                self_attention_mask,
+            )
+
+        end = min(iterative_size + start, compressed_input_ids.shape[1])
+        threshold, keep_flag = None, None
+        if keep_split:
+            input_ids_numpy = input_ids.cpu().detach().numpy()[0]
+            N = len(input_ids_numpy)
+            keep_flag = [
+                int(
+                    (
+                        ii > 0
+                        and input_ids_numpy[ii] == split_token_id
+                        and input_ids_numpy[ii - 1] == split_token_id
+                    )
+                    or (
+                        ii < N - 1
+                        and input_ids_numpy[ii] == split_token_id
+                        and input_ids_numpy[ii + 1] == split_token_id
+                    )
+                )
+                for ii in range(N)
+            ]
+            keep_flag = torch.tensor(keep_flag).to(self.device)
+        past_key_values, past_loss, ready_end = None, None, 0
+        self_past_key_values, self_past_loss, self_ready_end = None, None, 0
+        pop_compressed_input_ids, pop_self_compressed_input_ids = None, None
+        idx = 0
+        while end <= compressed_input_ids.shape[1]:
+            if end > self.max_position_embeddings and past_key_values is not None:
+                # KV-Cache Compression
+                e, s = end - self.max_position_embeddings, min(
+                    self.cache_bos_num + start, self.max_position_embeddings
+                )
+                if pop_compressed_input_ids is None:
+                    pop_compressed_input_ids = compressed_input_ids[:, :e]
+                else:
+                    pop_compressed_input_ids = torch.cat(
+                        [pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1
+                    )
+                compressed_input_ids = compressed_input_ids[:, e:]
+                compressed_attention_mask = compressed_attention_mask[:, e:]
+                past_key_values = [
+                    [
+                        torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
+                        torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
+                    ]
+                    for k, v in past_key_values
+                ]
+                if keep_flag is not None:
+                    keep_flag = keep_flag[e:]
+                end, ready_end = end - e, ready_end - e
+                if condition_compare:
+                    s = min(s, self_past_key_values[0][0].shape[2] - e)
+                    self_ready_end -= e
+                    if pop_self_compressed_input_ids is None:
+                        pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
+                    else:
+                        pop_self_compressed_input_ids = torch.cat(
+                            [
+                                pop_self_compressed_input_ids,
+                                self_compressed_input_ids[:, :e],
+                            ],
+                            dim=-1,
+                        )
+                    self_compressed_input_ids = self_compressed_input_ids[:, e:]
+                    self_compressed_attention_mask = self_compressed_attention_mask[
+                        :, e:
+                    ]
+                    self_past_key_values = [
+                        [
+                            torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
+                            torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
+                        ]
+                        for k, v in self_past_key_values
+                    ]
+
+            loss, past_key_values = self.get_ppl(
+                "",
+                "token",
+                compressed_input_ids,
+                compressed_attention_mask,
+                past_key_values=past_key_values,
+                return_kv=True,
+                end=end if idx else None,
+            )
+            if loss.shape[0] == 0:
+                break
+            if past_loss is not None:
+                if end - 1 > len(past_loss):
+                    past_loss = torch.cat(
+                        [past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]]
+                    )
+                past_loss[ready_end : end - 1] = loss
+                loss = past_loss
+            else:
+                past_loss = loss
+            if idx:
+                past_key_values = [
+                    [k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]]
+                    for k, v in past_key_values
+                ]
+            else:
+                past_key_values = None
+
+            if condition_compare:
+                self_loss, self_past_key_values = self.get_ppl(
+                    "",
+                    "token",
+                    self_compressed_input_ids,
+                    self_compressed_attention_mask,
+                    past_key_values=self_past_key_values,
+                    return_kv=True,
+                    end=end - start if idx else None,
+                )
+                if self_past_loss is not None:
+                    if end - start - 1 > len(self_past_loss):
+                        self_past_loss = torch.cat(
+                            [
+                                self_past_loss,
+                                torch.zeros_like(self_loss)[
+                                    : end - 1 - start - len(self_past_loss)
+                                ],
+                            ]
+                        )
+                    self_past_loss[self_ready_end : end - start - 1] = self_loss
+                    self_loss = self_past_loss
+                else:
+                    self_past_loss = self_loss
+                if idx:
+                    self_past_key_values = [
+                        [
+                            k[:, :, : end - iterative_size - start],
+                            v[:, :, : end - iterative_size - start],
+                        ]
+                        for k, v in self_past_key_values
+                    ]
+                else:
+                    self_past_key_values = None
+
+                self_ready_end = (
+                    end - start - iterative_size if not (start and idx == 0) else 0
+                )
+            ready_end = end - iterative_size if not (start and idx == 0) else 0
+
+            for delta_end, ratio in iterative_ratios[idx]:
+                loss = past_loss
+                if condition_compare:
+                    self_loss = self_past_loss
+                    threshold = self.get_estimate_threshold_base_distribution(
+                        self_loss[: loss[start:].shape[0]] - loss[start:], ratio, False
+                    )
+                else:
+                    threshold = self.get_estimate_threshold_base_distribution(
+                        loss, ratio, False
+                    )
+
+                (
+                    compressed_input_ids,
+                    compressed_attention_mask,
+                    keep_flag,
+                    end,
+                    past_loss,
+                    self_past_loss,
+                    self_compressed_input_ids,
+                    self_compressed_attention_mask,
+                ) = self.get_compressed_input(
+                    loss,
+                    compressed_input_ids,
+                    compressed_attention_mask,
+                    end - iterative_size + delta_end,
+                    iterative_size=delta_end,
+                    threshold=threshold,
+                    keep_flag=keep_flag,
+                    split_token_id=split_token_id,
+                    start=start,
+                    self_loss=self_loss if condition_compare else None,
+                    self_input_ids=(
+                        self_compressed_input_ids if condition_compare else None
+                    ),
+                    self_attention_mask=(
+                        self_compressed_attention_mask if condition_compare else None
+                    ),
+                )
+                end += iterative_size
+            idx += 1
+        if pop_compressed_input_ids is not None:
+            compressed_input_ids = torch.cat(
+                [pop_compressed_input_ids, compressed_input_ids], dim=-1
+            )
+        return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]
+
+    def recover(
+        self,
+        original_prompt: str,
+        compressed_prompt: str,
+        response: str,
+    ):
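+        """Restore spans of the response from the original, uncompressed prompt.
+
+        Maximal runs of response words that also appear in the compressed prompt
+        are matched against the token ids of the original prompt, and the
+        shortest original span covering the most matched tokens is substituted
+        back, so that text truncated during compression can be recovered in the
+        model output.
+
+        Example (illustrative only):
+            >>> # recovered = compressor.recover(original_prompt, compressed_prompt, response)
+        """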
+        def match_from_compressed(response_word):
+            response_input_ids = self.tokenizer(
+                response_word, add_special_tokens=False
+            )["input_ids"]
+            response_set, response_c = set(response_input_ids), defaultdict(list)
+            for idx in range(M):
+                if original_input_ids[idx] in response_set:
+                    response_c[original_input_ids[idx]].append(idx)
+            res, res_min, res_c = None, float("inf"), 1
+            n = len(response_input_ids)
+            for l in response_c[response_input_ids[0]]:
+                x, y, c = 0, l, 1
+                for x in range(1, n):
+                    idx = bisect.bisect_right(response_c[response_input_ids[x]], y)
+                    if (
+                        idx >= len(response_c[response_input_ids[x]])
+                        or response_c[response_input_ids[x]][idx] - y > 10
+                    ):
+                        continue
+                    c += 1
+                    y = response_c[response_input_ids[x]][idx]
+                if c > res_c:
+                    res_c = c
+                    res_min = y - l + 1
+                    res = (l, y + 1)
+                elif c == res_c and y - l + 1 < res_min:
+                    res_min = y - l + 1
+                    res = (l, y + 1)
+
+            if res is None:
+                return response_word
+            return self.tokenizer.decode(original_input_ids[res[0] : res[1]])
+
+        response_words = response.split(" ")
+
+        original_input_ids = self.tokenizer(original_prompt, add_special_tokens=False)[
+            "input_ids"
+        ]
+        N, M = len(response_words), len(original_input_ids)
+        recovered_response_words = []
+        l = 0
+        while l < N:
+            if response_words[l] not in compressed_prompt:
+                recovered_response_words.append(response_words[l])
+                l += 1
+                continue
+            r = l
+            while (
+                r + 1 < N and " ".join(response_words[l : r + 2]) in compressed_prompt
+            ):
+                r += 1
+
+            match_words = match_from_compressed(" ".join(response_words[l : r + 1]))
+            recovered_response_words.append(match_words)
+            l = r + 1
+        return " ".join(recovered_response_words)
+
+    def get_rank_results(
+        self,
+        context: list,
+        question: str,
+        rank_method: str,
+        condition_in_question: str,
+        context_tokens_length: list,
+    ):
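+        """Rank contexts against the question using the chosen retrieval method.
+
+        Supported rank_method values: "bm25", "gzip", "sentbert", "openai",
+        "bge", "bge_reranker", "bge_llmembedder", "jinza", "voyageai", "cohere",
+        and the perplexity-based "llmlingua" / "longllmlingua". Each backend is
+        imported lazily (heavier models are cached on self.retrieval_model) and
+        returns a list of (context_index, score) pairs ordered from most to
+        least relevant.
+        """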
+        def get_distance_bm25(corpus, query):
+            from rank_bm25 import BM25Okapi
+
+            tokenized_corpus = [doc.split(" ") for doc in corpus]
+            bm25 = BM25Okapi(tokenized_corpus)
+            tokenized_query = query.split(" ")
+            doc_scores = bm25.get_scores(tokenized_query)
+            idx = [(ii, 0) for ii in (-doc_scores).argsort()]
+            return idx
+
+        def get_distance_gzip(corpus, query):
+            def get_score(x, y):
+                cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode()))
+                cxy = len(gzip.compress(f"{x} {y}".encode()))
+                return (cxy - min(cx, cy)) / max(cx, cy)
+
+            import gzip
+
+            doc_scores = [get_score(doc, query) for doc in corpus]
+            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
+            return idx
+
+        def get_distance_sentbert(corpus, query):
+            from sentence_transformers import SentenceTransformer, util
+
+            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
+                self.retrieval_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
+                self.retrieval_model_name = rank_method
+            doc_embeds = self.retrieval_model.encode(corpus)
+            query = self.retrieval_model.encode(query)
+            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
+            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
+            return idx
+
+        def get_distance_openai(corpus, query):
+            import openai
+            from sentence_transformers import util
+
+            openai.api_key = self.open_api_config.get("api_key", "")
+            openai.api_base = self.open_api_config.get(
+                "api_base", "https://api.openai.com/v1"
+            )
+            openai.api_type = self.open_api_config.get("api_type", "open_ai")
+            openai.api_version = self.open_api_config.get("api_version", "2023-05-15")
+            engine = self.open_api_config.get("engine", "text-embedding-ada-002")
+
+            def get_embed(text):
+                return openai.Embedding.create(
+                    input=[text.replace("\n", " ")], engine=engine
+                )["data"][0]["embedding"]
+
+            doc_embeds = [get_embed(i) for i in corpus]
+            query = get_embed(query)
+            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
+            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
+            return idx
+
+        def get_distance_sentbert_bge(corpus, query):
+            from sentence_transformers import SentenceTransformer, util
+
+            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
+                self.retrieval_model = SentenceTransformer("BAAI/bge-large-en-v1.5")
+                self.retrieval_model_name = rank_method
+            doc_embeds = self.retrieval_model.encode(
+                [i for i in corpus], normalize_embeddings=True
+            )
+            query = self.retrieval_model.encode(query, normalize_embeddings=True)
+            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
+            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
+            return idx
+
+        def get_distance_bge_ranker(corpus, query):
+            from transformers import AutoModelForSequenceClassification, AutoTokenizer
+
+            pairs = [[i, query] for i in corpus]
+            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
+                tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large")
+                model = (
+                    AutoModelForSequenceClassification.from_pretrained(
+                        "BAAI/bge-reranker-large"
+                    )
+                    .eval()
+                    .to(self.device)
+                )
+                self.retrieval_model = [tokenizer, model]
+                self.retrieval_model_name = rank_method
+            with torch.no_grad():
+                inputs = self.retrieval_model[0](
+                    pairs,
+                    padding=True,
+                    truncation=True,
+                    return_tensors="pt",
+                    max_length=512,
+                ).to(self.device)
+                scores = (
+                    self.retrieval_model[1](**inputs, return_dict=True)
+                    .logits.view(
+                        -1,
+                    )
+                    .float()
+                )
+            idx = [(ii, 0) for ii in np.argsort(-scores.cpu())]
+            return idx
+
+        def get_distance_bge_llmembedder(corpus, query):
+            from transformers import AutoModel, AutoTokenizer
+
+            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
+                tokenizer = AutoTokenizer.from_pretrained("BAAI/llm-embedder")
+                model = (
+                    AutoModel.from_pretrained("BAAI/llm-embedder")
+                    .eval()
+                    .to(self.device)
+                )
+                self.retrieval_model = [tokenizer, model]
+                self.retrieval_model_name = rank_method
+
+            instruction_qa_query = (
+                "Represent this query for retrieving relevant documents: "
+            )
+            instruction_qa_key = "Represent this document for retrieval: "
+            queries = [instruction_qa_query + query for _ in corpus]
+            keys = [instruction_qa_key + key for key in corpus]
+            with torch.no_grad():
+                query_inputs = self.retrieval_model[0](
+                    queries,
+                    padding=True,
+                    truncation=True,
+                    return_tensors="pt",
+                    max_length=512,
+                ).to(self.device)
+                key_inputs = self.retrieval_model[0](
+                    keys,
+                    padding=True,
+                    truncation=True,
+                    return_tensors="pt",
+                    max_length=512,
+                ).to(self.device)
+                query_outputs = self.retrieval_model[1](**query_inputs)
+                key_outputs = self.retrieval_model[1](**key_inputs)
+                # CLS pooling
+                query_embeddings = query_outputs.last_hidden_state[:, 0]
+                key_embeddings = key_outputs.last_hidden_state[:, 0]
+                # Normalize
+                query_embeddings = torch.nn.functional.normalize(
+                    query_embeddings, p=2, dim=1
+                )
+                key_embeddings = torch.nn.functional.normalize(
+                    key_embeddings, p=2, dim=1
+                )
+                similarity = query_embeddings @ key_embeddings.T
+            idx = [(ii, 0) for ii in np.argsort(-similarity[0].cpu())]
+            return idx
+
+        def get_distance_jinza(corpus, query):
+            from numpy.linalg import norm
+
+            from transformers import AutoModel
+
+            def cos_sim(a, b):
+                return (a @ b.T) / (norm(a) * norm(b))
+
+            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
+                model = (
+                    AutoModel.from_pretrained(
+                        "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
+                    )
+                    .eval()
+                    .to(self.device)
+                )
+                self.retrieval_model = model
+                self.retrieval_model_name = rank_method
+
+            doc_embeds = self.retrieval_model.encode(corpus)
+            query = self.retrieval_model.encode(query)
+            doc_scores = cos_sim(doc_embeds, query)
+            idx = [(ii, 0) for ii in np.argsort(-doc_scores)]
+            return idx
+
+        def get_distance_voyageai(corpus, query):
+            import voyageai
+            from sentence_transformers import util
+
+            voyageai.api_key = self.open_api_config.get("voyageai_api_key", "")
+
+            def get_embed(text):
+                return voyageai.get_embedding(text, model="voyage-01")
+
+            doc_embeds = [get_embed(i) for i in corpus]
+            query = get_embed(query)
+            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
+            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
+            return idx
+
+        def get_distance_cohere(corpus, query):
+            import cohere
+
+            api_key = self.open_api_config.get("cohere_api_key", "")
+            co = cohere.Client(api_key)
+            results = co.rerank(
+                model="rerank-english-v2.0", query=query, documents=corpus, top_n=20
+            )
+            c_map = {jj: ii for ii, jj in enumerate(corpus)}
+            doc_rank = [c_map[ii.document["text"]] for ii in results]
+            idx = [(ii, 0) for ii in doc_rank]
+            return idx
+
+        def get_distance_longllmlingua(corpus, query):
+            context_ppl = [
+                self.get_condition_ppl(
+                    d,
+                    query
+                    + " We can get the answer to this question in the given documents.",
+                    condition_in_question,
+                )
+                for d in corpus
+            ]
+            sort_direct = -1 if condition_in_question == "none" else 1
+            ys = sorted(enumerate(context_ppl), key=lambda x: sort_direct * x[1])
+            return ys
+
+        method = None
+        if rank_method == "bm25":
+            method = get_distance_bm25
+        elif rank_method == "gzip":
+            method = get_distance_gzip
+        elif rank_method == "sentbert":
+            method = get_distance_sentbert
+        elif rank_method == "openai":
+            method = get_distance_openai
+        elif rank_method in ["longllmlingua", "llmlingua"]:
+            method = get_distance_longllmlingua
+        elif rank_method == "bge":
+            method = get_distance_sentbert_bge
+        elif rank_method == "bge_reranker":
+            method = get_distance_bge_ranker
+        elif rank_method == "bge_llmembedder":
+            method = get_distance_bge_llmembedder
+        elif rank_method == "jinza":
+            method = get_distance_jinza
+        elif rank_method == "voyageai":
+            method = get_distance_voyageai
+        elif rank_method == "cohere":
+            method = get_distance_cohere
+        else:
+            raise ValueError(f"Unsupported rank_method: {rank_method}")
+        return method(context, question)
+
+    def segment_structured_context(
+        self,
+        context: List[str],
+        global_rate: float,
+    ):
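+        """Parse <llmlingua, rate=..., compress=...> tags in structured prompts.
+
+        Each context is wrapped in default tags if necessary and split into
+        segments; rate and compress may appear in either order. Missing values
+        default to compress=True and to global_rate (or 1.0 for segments marked
+        compress=False). Returns the tag-free contexts together with parallel
+        lists of segment texts, rates, and compress flags.
+
+        Example (illustrative only):
+            >>> # context = ['<llmlingua, compress=False>Keep this.</llmlingua>'
+            >>> #            '<llmlingua, rate=0.4>Compress this part.</llmlingua>']
+            >>> # new_ctx, segs, rates, flags = self.segment_structured_context(context, 0.5)
+        """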
+        new_context, context_segs, context_segs_rate, context_segs_compress = (
+            [],
+            [],
+            [],
+            [],
+        )
+        for text in context:
+            if not text.startswith("<llmlingua"):
+                text = "<llmlingua>" + text
+            if not text.endswith("</llmlingua>"):
+                text = text + "</llmlingua>"
+
+            # Regular expression to match <llmlingua, rate=x, compress=y>content</llmlingua>, allowing rate and compress in any order
+            pattern = r"<llmlingua\s*(?:,\s*rate\s*=\s*([\d\.]+))?\s*(?:,\s*compress\s*=\s*(True|False))?\s*(?:,\s*rate\s*=\s*([\d\.]+))?\s*(?:,\s*compress\s*=\s*(True|False))?\s*>([^<]+)</llmlingua>"
+            matches = re.findall(pattern, text)
+
+            # Extracting segment contents
+            segments = [match[4] for match in matches]
+
+            # Extracting rate and compress, considering their possible positions
+            segs_rate = [
+                float(match[0]) if match[0] else (float(match[2]) if match[2] else None)
+                for match in matches
+            ]
+            segs_compress = [
+                (
+                    match[1] == "True"
+                    if match[1]
+                    else (match[3] == "True" if match[3] else None)
+                )
+                for match in matches
+            ]
+
+            segs_compress = [
+                compress if compress is not None else True for compress in segs_compress
+            ]
+            segs_rate = [
+                rate if rate else (global_rate if compress else 1.0)
+                for rate, compress in zip(segs_rate, segs_compress)
+            ]
+            assert (
+                len(segments) == len(segs_rate) == len(segs_compress)
+            ), "The number of segments, rates, and compress flags should be the same."
+            assert all(
+                seg_rate <= 1.0 for seg_rate in segs_rate
+            ), "Error: 'rate' must not exceed 1.0. The value of 'rate' indicates compression rate and must be within the range [0, 1]."
+
+            new_context.append("".join(segments))
+            context_segs.append(segments)
+            context_segs_rate.append(segs_rate)
+            context_segs_compress.append(segs_compress)
+
+        return new_context, context_segs, context_segs_rate, context_segs_compress
+
+    def concate_segment_info(
+        self,
+        segment_info: List[List[tuple]],
+    ):
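+        """Merge adjacent segments that share the same rate and compress flag.
+
+        Example:
+            [(5, 0.5, True), (3, 0.5, True), (2, 1.0, False)]
+            becomes [(8, 0.5, True), (2, 1.0, False)].
+        """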
+        new_segment_info = []
+        for i, (seg_len, seg_ratio, seg_compress) in enumerate(segment_info):
+            if (
+                new_segment_info
+                and new_segment_info[-1][1] == seg_ratio
+                and new_segment_info[-1][2] == seg_compress
+            ):
+                new_segment_info[-1] = (
+                    new_segment_info[-1][0] + seg_len,
+                    seg_ratio,
+                    seg_compress,
+                )
+            else:
+                new_segment_info.append((seg_len, seg_ratio, seg_compress))
+        return new_segment_info
+
+    def __get_context_prob(
+        self,
+        context_list: list,
+        token_to_word="mean",
+        force_tokens: List[str] = [],
+        token_map: dict = {},
+        force_reserve_digit: bool = False,
+    ):
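+        """Score every context with the llmlingua-2 token classifier.
+
+        All chunks are batched through the classification model, per-token
+        probabilities of the "preserve" class are merged into word
+        probabilities (the unforced probabilities are used, so force_tokens do
+        not inflate a context's score), and each context is scored by its mean
+        word probability. Returns the per-context scores and word lists.
+        """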
+        chunk_list = []
+        for chunks in context_list:
+            for c in chunks:
+                chunk_list.append(c)
+
+        dataset = TokenClfDataset(
+            chunk_list, tokenizer=self.tokenizer, max_len=self.max_seq_len
+        )
+        dataloader = DataLoader(
+            dataset, batch_size=self.max_batch_size, shuffle=False, drop_last=False
+        )
+
+        chunk_probs = []
+        chunk_words = []
+        with torch.no_grad():
+            for batch in dataloader:
+                ids = batch["ids"].to(self.device, dtype=torch.long)
+                mask = batch["mask"].to(self.device, dtype=torch.long) == 1
+
+                outputs = self.model(input_ids=ids, attention_mask=mask)
+                loss, logits = outputs.loss, outputs.logits
+                probs = F.softmax(logits, dim=-1)
+
+                for j in range(ids.shape[0]):
+                    _probs = probs[j, :, 1]
+                    _ids = ids[j]
+                    _mask = mask[j]
+
+                    active_probs = torch.masked_select(_probs, _mask)
+                    active_ids = torch.masked_select(_ids, _mask)
+
+                    tokens = self.tokenizer.convert_ids_to_tokens(
+                        active_ids.squeeze().tolist()
+                    )
+                    token_probs = [prob for prob in active_probs.cpu().numpy()]
+
+                    (
+                        words,
+                        valid_token_probs,
+                        valid_token_probs_no_force,
+                    ) = self.__merge_token_to_word(
+                        tokens,
+                        token_probs,
+                        force_tokens=force_tokens,
+                        token_map=token_map,
+                        force_reserve_digit=force_reserve_digit,
+                    )
+                    word_probs_no_force = self.__token_prob_to_word_prob(
+                        valid_token_probs_no_force, convert_mode=token_to_word
+                    )
+
+                    if "xlm-roberta-large" in self.model_name:
+                        for i in range(len(words)):
+                            words[i] = words[i].lstrip("▁")
+                    chunk_words.append(words)
+                    chunk_probs.append(word_probs_no_force)
+
+        prev_idx = 0
+        context_probs = []
+        context_words = []
+        for chunk_list in context_list:
+            n_chunk = len(chunk_list)
+            context_probs.append([])
+            context_words.append([])
+            for i in range(n_chunk):
+                context_probs[-1].extend(chunk_probs[prev_idx + i])
+                context_words[-1].extend(chunk_words[prev_idx + i])
+            prev_idx = prev_idx + n_chunk
+        context_probs = [sum(probs) / len(probs) for probs in context_probs]
+        return context_probs, context_words
+
+    def __chunk_context(self, origin_text, chunk_end_tokens):
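+        """Split a long text into chunks of at most max_seq_len tokens.
+
+        Chunk boundaries are shifted backwards to the nearest token found in
+        chunk_end_tokens (typically sentence-ending tokens) so that chunks end
+        on natural breaks whenever possible.
+        """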
+        origin_list = []
+        origin_tokens = self.tokenizer.tokenize(origin_text)
+        n = len(origin_tokens)
+        st = 0
+        while st < n:
+            if st + self.max_seq_len > n - 1:
+                chunk = self.tokenizer.convert_tokens_to_string(origin_tokens[st:n])
+                origin_list.append(chunk)
+                break
+            else:
+                ed = st + self.max_seq_len
+                for j in range(0, ed - st):
+                    if origin_tokens[ed - j] in chunk_end_tokens:
+                        ed = ed - j
+                        break
+                chunk = self.tokenizer.convert_tokens_to_string(
+                    origin_tokens[st : ed + 1]
+                )
+                origin_list.append(chunk)
+                st = ed + 1
+        return origin_list
+
+    def __merge_token_to_word(
+        self, tokens, token_probs, force_tokens, token_map, force_reserve_digit
+    ):
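+        """Merge sub-word tokens into words and collect their probabilities.
+
+        Special tokens are skipped; a token that begins a new word opens a new
+        entry, while continuation tokens are appended to the previous word.
+        Tokens in force_tokens / token_map (and digits, when force_reserve_digit
+        is set) are forced to probability 1.0 in word_probs, while
+        word_probs_no_force keeps the raw classifier probabilities.
+        """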
+        words = []
+        word_probs = []
+        word_probs_no_force = []
+
+        for token, prob in zip(tokens, token_probs):
+            if token in self.special_tokens:
+                continue
+            # add a new word
+            elif is_begin_of_new_word(token, self.model_name, force_tokens, token_map):
+                pure_token = get_pure_token(token, self.model_name)
+                prob_no_force = prob
+                if pure_token in force_tokens or pure_token in set(token_map.values()):
+                    prob = 1.0
+                token = replace_added_token(token, token_map)
+                words.append(token)
+                word_probs.append(
+                    [
+                        1.0
+                        if force_reserve_digit
+                        and bool(re.search(r"\d", token))
+                        else prob
+                    ]
+                )
+                word_probs_no_force.append([prob_no_force])
+            # concatenate with previous token
+            else:
+                pure_token = get_pure_token(token, self.model_name)
+                words[-1] += pure_token
+                word_probs[-1].append(
+                    1.0
+                    if force_reserve_digit
+                    and bool(re.search(r"\d", token))
+                    else prob
+                )
+                word_probs_no_force[-1].append(prob)  # raw probability of this continuation token
+
+        return words, word_probs, word_probs_no_force
+
+    def __token_prob_to_word_prob(self, token_probs, convert_mode="mean"):
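+        """Reduce the per-token probabilities of each word to a single value.
+
+        Example:
+            [[0.2, 0.4], [0.9]] -> [0.3, 0.9] with convert_mode="mean",
+            or [0.2, 0.9] with convert_mode="first".
+        """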
+        if convert_mode == "mean":
+            word_probs = [sum(p) / len(p) for p in token_probs]
+        elif convert_mode == "first":
+            word_probs = [p[0] for p in token_probs]
+        else:
+            raise NotImplementedError()
+
+        return word_probs
+
+    def __compress(
+        self,
+        context_list: list,
+        reduce_rate: float = 0.5,
+        token_to_word: str = "mean",
+        force_tokens: List[str] = [],
+        token_map: dict = {},
+        force_reserve_digit: bool = False,
+        drop_consecutive: bool = False,
+    ):
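+        """Token-classification based compression used by llmlingua-2.
+
+        With reduce_rate <= 0 nothing is dropped and all words are labeled 1.
+        Otherwise every chunk is scored by the classifier, token probabilities
+        are merged into word probabilities, a percentile threshold matching
+        reduce_rate is computed with respect to the GPT-4 tokenizer, and only
+        words above it are kept (optionally dropping consecutive repetitions of
+        force tokens). Returns the compressed contexts, the original words, and
+        their 0/1 keep labels.
+        """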
+        def split_string_to_words(input_string):
+            pattern = r'\b\w+\b|[<>=/!@#$%^&*()?":{}|\\`~;_+-]'
+            result = re.findall(pattern, input_string)
+            return result
+        if reduce_rate <= 0:
+            words, word_labels = [], []
+            for i in range(len(context_list)):
+                chunk_list = context_list[i]
+                chunk_words = []
+                chunk_word_labels = []
+                for j in range(len(chunk_list)):
+                    # replace to original token
+                    for ori_token, new_token in token_map.items():
+                        chunk_list[j] = chunk_list[j].replace(new_token, ori_token)
+                    ws = split_string_to_words(chunk_list[j])
+                    chunk_words.extend(ws)
+                    chunk_word_labels.extend([1 for _ in range(len(ws))])
+                context_list[i] = "".join(chunk_list)
+                words.append(chunk_words)
+                word_labels.append(chunk_word_labels)
+            return context_list, words, word_labels
+
+        chunk_list = []
+        for chunks in context_list:
+            for c in chunks:
+                chunk_list.append(c)
+
+        dataset = TokenClfDataset(
+            chunk_list, tokenizer=self.tokenizer, max_len=self.max_seq_len
+        )
+        dataloader = DataLoader(
+            dataset, batch_size=self.max_batch_size, shuffle=False, drop_last=False
+        )
+
+        compressed_chunk_list = []
+        word_list = []
+        word_label_list = []
+        with torch.no_grad():
+            for batch in dataloader:
+                ids = batch["ids"].to(self.device, dtype=torch.long)
+                mask = batch["mask"].to(self.device, dtype=torch.long) == 1
+
+                outputs = self.model(input_ids=ids, attention_mask=mask)
+                loss, logits = outputs.loss, outputs.logits
+                probs = F.softmax(logits, dim=-1)
+
+                for j in range(ids.shape[0]):
+                    chunk_probs = probs[j, :, 1]
+                    chunk_ids = ids[j]
+                    chunk_mask = mask[j]
+
+                    active_probs = torch.masked_select(chunk_probs, chunk_mask)
+                    active_ids = torch.masked_select(chunk_ids, chunk_mask)
+
+                    tokens = self.tokenizer.convert_ids_to_tokens(
+                        active_ids.squeeze().tolist()
+                    )
+                    token_probs = [prob for prob in active_probs.cpu().numpy()]
+
+                    words, valid_token_probs, _ = self.__merge_token_to_word(
+                        tokens=tokens, 
+                        token_probs=token_probs, 
+                        force_tokens=force_tokens, 
+                        token_map=token_map,
+                        force_reserve_digit=force_reserve_digit,
+                    )
+                    word_probs = self.__token_prob_to_word_prob(
+                        valid_token_probs, convert_mode=token_to_word
+                    )
+
+                    if drop_consecutive:
+                        threshold = np.percentile(word_probs, int(100 * reduce_rate))
+                        is_token_between = False
+                        prev = None
+                        for i, (word, word_prob) in enumerate(zip(words, word_probs)):
+                            if word in force_tokens:
+                                if is_token_between:
+                                    is_token_between = False
+                                elif not is_token_between and word == prev:
+                                    word_probs[i] = 0.0
+                                prev = word
+                            else:
+                                is_token_between |= word_prob > threshold
+
+                    # calculate compression ratio w.r.t. gpt-4 tokenizer
+                    new_token_probs = []
+                    for word, word_prob in zip(words, word_probs):
+                        num_token = len(self.oai_tokenizer.encode(word))
+                        new_token_probs.extend([word_prob for _ in range(num_token)])
+                    threshold = np.percentile(
+                        new_token_probs, int(100 * reduce_rate + 1)
+                    )
+
+                    keep_words = []
+                    word_labels = []
+                    assert len(words) == len(word_probs)
+                    for word, word_prob in zip(words, word_probs):
+                        if word_prob > threshold:
+                            if (
+                                drop_consecutive
+                                and word in force_tokens
+                                and len(keep_words) > 0
+                                and keep_words[-1] == word
+                            ):
+                                word_labels.append(0)
+                            else:
+                                keep_words.append(word)
+                                word_labels.append(1)
+                        else:
+                            word_labels.append(0)
+                    keep_str = self.tokenizer.convert_tokens_to_string(keep_words)
+                    if "xlm-roberta-large" in self.model_name:
+                        for i in range(len(words)):
+                            words[i] = words[i].lstrip("▁")
+
+                    compressed_chunk_list.append(keep_str)
+                    word_list.append(words[:])
+                    word_label_list.append(word_labels[:])
+
+        compressed_context_list = []
+        original_word_list = []
+        original_word_label_list = []
+        prev_idx = 0
+        for chunk_list in context_list:
+            n_chunk = len(chunk_list)
+            compressed_context_list.append(
+                "".join(compressed_chunk_list[prev_idx : prev_idx + n_chunk])
+            )
+            original_word_list.append([])
+            original_word_label_list.append([])
+            for i in range(n_chunk):
+                original_word_list[-1].extend(word_list[prev_idx + i])
+                original_word_label_list[-1].extend(word_label_list[prev_idx + i])
+            prev_idx = prev_idx + n_chunk
+
+        return compressed_context_list, original_word_list, original_word_label_list