File size: 5,954 Bytes
c6a14bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from selective_context_compressor import SCCompressor
from kis import KiSCompressor
from scrl_compressor import SCRLCompressor
from llmlingua_compressor_pro import LLMLinguaCompressor
from typing import List


class PromptCompressor:
    """Single entry point wrapping one of several prompt-compression backends.

    The backend is selected once, in ``__init__``, from the ``type`` string and
    stored in ``self.compressor``. ``compressgo`` then forwards only the
    arguments that the selected backend's ``compress`` method is called with.

    Supported ``type`` values are listed in ``_SUPPORTED_TYPES``.
    ``'LLMLinguaCompressor'`` and ``'LongLLMLinguaCompressor'`` are both served
    by ``LLMLinguaCompressor`` and are handled identically by this wrapper.
    """

    # Both LLMLingua variants share one backend class and one ``compress`` call.
    _LLMLINGUA_TYPES = ('LLMLinguaCompressor', 'LongLLMLinguaCompressor')
    _SUPPORTED_TYPES = ('SCCompressor', 'KiSCompressor', 'SCRLCompressor') + _LLMLINGUA_TYPES

    def __init__(self, type: str = 'SCCompressor', lang: str = 'en', model='gpt2', device='cuda', model_dir: str = '',
                 use_auth_token: bool = False, open_api_config: dict = None, token: str = '',
                 tokenizer_dir: str = "sentence-transformers/paraphrase-distilroberta-base-v2"):
        """Instantiate the backend selected by ``type``.

        Args:
            type: Backend name; must be one of ``_SUPPORTED_TYPES``.
            lang: Language code forwarded to ``SCCompressor``.
            model: Model name forwarded to ``SCCompressor``.
            device: Device forwarded to every backend (as ``device``,
                ``DEVICE`` or ``device_map`` depending on the backend).
            model_dir: Model location for ``KiSCompressor``,
                ``SCRLCompressor`` (required there) and ``LLMLinguaCompressor``
                (passed as ``model_name``).
            use_auth_token: Forwarded to ``LLMLinguaCompressor``.
            open_api_config: Forwarded to ``LLMLinguaCompressor``; ``None``
                (the default) becomes a fresh empty dict for this instance.
            token: Forwarded to ``LLMLinguaCompressor``.
            tokenizer_dir: Tokenizer location forwarded to ``SCRLCompressor``.

        Raises:
            ValueError: If ``type`` is not supported, or if ``type`` is
                ``'SCRLCompressor'`` and ``model_dir`` is empty.
        """
        self.type = type
        if open_api_config is None:
            # A fresh dict per instance: a shared ``{}`` default would leak any
            # mutation made by one backend instance into every later one.
            open_api_config = {}

        if type == 'SCCompressor':
            self.compressor = SCCompressor(lang=lang, model=model, device=device)
        elif type == 'KiSCompressor':
            self.compressor = KiSCompressor(DEVICE=device, model_dir=model_dir)
        elif type in self._LLMLINGUA_TYPES:
            self.compressor = LLMLinguaCompressor(device_map=device, model_name=model_dir,
                                                  use_auth_token=use_auth_token,
                                                  open_api_config=open_api_config, token=token)
        elif type == 'SCRLCompressor':
            if not model_dir:
                # Fail fast: without a model there is no usable compressor.
                raise ValueError("model_dir parameter is required")
            self.compressor = SCRLCompressor(model_dir=model_dir, device=device, tokenizer_dir=tokenizer_dir)
        else:
            raise ValueError(
                f"Unsupported compressor type {type!r}; expected one of {self._SUPPORTED_TYPES}"
            )

    def compressgo(self, original_prompt: str = '', ratio: float = 0.5, level: str = 'phrase',
                   max_length: int = 256, num_beams: int = 4, do_sample: bool = True, num_return_sequences: int = 1,
                   target_index: int = 0, instruction: str = "", question: str = "", target_token: float = -1,
                   iterative_size: int = 200, force_context_ids: List[int] = None, force_context_number: int = None,
                   use_sentence_level_filter: bool = False, use_context_level_filter: bool = True,
                   use_token_level_filter: bool = True, keep_split: bool = False, keep_first_sentence: int = 0,
                   keep_last_sentence: int = 0, keep_sentence_number: int = 0, high_priority_bonus: int = 100,
                   context_budget: str = "+100", token_budget_ratio: float = 1.4, condition_in_question: str = "none",
                   reorder_context: str = "original", dynamic_context_compression_ratio: float = 0.0,
                   condition_compare: bool = False, add_instruction: bool = False, rank_method: str = "llmlingua",
                   concate_question: bool = True,):
        """Compress ``original_prompt`` with the backend chosen at construction.

        Only the arguments relevant to the active backend are forwarded:

        * ``SCCompressor``: ``original_prompt``, ``ratio``, ``level``.
        * ``KiSCompressor``: ``original_prompt``, ``ratio``, ``max_length``,
          ``num_beams``, ``do_sample``, ``num_return_sequences``,
          ``target_index``.
        * ``SCRLCompressor``: ``original_prompt``, ``ratio``, ``max_length``.
        * ``LLMLinguaCompressor`` / ``LongLLMLinguaCompressor``: the prompt is
          passed as ``context`` together with ``ratio`` and every LLMLingua
          option (``instruction`` through ``concate_question``), unchanged.
        * Any other ``self.type`` (e.g. a compressor assigned externally):
          ``original_prompt`` and ``ratio`` only.

        The meaning of ``ratio`` and of the LLMLingua options is defined by the
        respective backend, not by this wrapper.

        Returns:
            Whatever the backend's ``compress`` method returns.
        """
        if self.type == 'SCCompressor':
            return self.compressor.compress(original_prompt=original_prompt, ratio=ratio, level=level)
        if self.type == 'KiSCompressor':
            return self.compressor.compress(original_prompt=original_prompt, ratio=ratio, max_length=max_length,
                                            num_beams=num_beams, do_sample=do_sample,
                                            num_return_sequences=num_return_sequences, target_index=target_index)
        if self.type == 'SCRLCompressor':
            return self.compressor.compress(original_prompt=original_prompt, ratio=ratio, max_length=max_length)
        if self.type in self._LLMLINGUA_TYPES:
            # Both LLMLingua variants take the prompt as ``context``.
            return self.compressor.compress(
                context=original_prompt, ratio=ratio, instruction=instruction, question=question,
                target_token=target_token, iterative_size=iterative_size,
                force_context_ids=force_context_ids, force_context_number=force_context_number,
                use_token_level_filter=use_token_level_filter,
                use_context_level_filter=use_context_level_filter,
                use_sentence_level_filter=use_sentence_level_filter, keep_split=keep_split,
                keep_first_sentence=keep_first_sentence, keep_last_sentence=keep_last_sentence,
                keep_sentence_number=keep_sentence_number, high_priority_bonus=high_priority_bonus,
                context_budget=context_budget, token_budget_ratio=token_budget_ratio,
                condition_in_question=condition_in_question, reorder_context=reorder_context,
                dynamic_context_compression_ratio=dynamic_context_compression_ratio,
                condition_compare=condition_compare, add_instruction=add_instruction,
                rank_method=rank_method, concate_question=concate_question,
            )
        # Fallback for compressors not known to this wrapper.
        return self.compressor.compress(original_prompt=original_prompt, ratio=ratio)