Spaces:
Runtime error
Runtime error
File size: 1,708 Bytes
10b912d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
from SCRL_new.scrl.model import load_model
from transformers import AutoTokenizer
import re
from abs_compressor import AbstractCompressor
class SCRLCompressor(AbstractCompressor):
    """Prompt compressor backed by an SCRL sentence-compression model.

    The prompt is split into fixed-size character chunks, each chunk is
    passed to ``self.model.predict``, and the per-chunk outputs are
    concatenated into the compressed prompt.
    """

    def __init__(
        self,
        model_dir: str,
        device: str = "cpu",
        tokenizer_dir: str = "sentence-transformers/paraphrase-distilroberta-base-v2",
    ):
        """Load the SCRL model and the tokenizer used for its inputs.

        Args:
            model_dir: Directory passed to ``load_model`` to load the model.
            device: Device string forwarded to ``load_model`` and ``predict``.
            tokenizer_dir: Name or path passed to
                ``AutoTokenizer.from_pretrained``.
        """
        self.model_dir = model_dir
        self.device = device
        self.model = load_model(self.model_dir, self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

    @staticmethod
    def _split_into_chunks(text: str, max_length: int) -> list:
        """Split *text* into consecutive chunks of at most *max_length* chars."""
        return [text[i:i + max_length] for i in range(0, len(text), max_length)]

    def compress(self, original_prompt: str, ratio: float = 0.5, max_length: int = 256) -> dict:
        """Compress *original_prompt* with the SCRL model.

        Args:
            original_prompt: Text to compress. Leading/trailing whitespace is
                stripped before chunking.
            ratio: Accepted for interface compatibility; it is not used here —
                the model itself decides what to keep.
            max_length: Maximum chunk size in characters (not tokens) passed
                to the model per input.

        Returns:
            A dict with ``compressed_prompt`` (str), ``ratio`` (compressed /
            original token count, or 0 when the original has no tokens),
            ``original_tokens`` (int) and ``compressed_tokens`` (int).

        Raises:
            ValueError: If *max_length* is smaller than 1.
        """
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")

        # NOTE(review): gpt_tokenizer is not assigned in this class; presumably
        # it is provided by AbstractCompressor — confirm.
        original_tokens = len(self.gpt_tokenizer.encode(original_prompt))

        # Plain slicing instead of re.findall(r'.{N}'): the regex only returned
        # complete N-character runs of non-newline text, silently dropping the
        # trailing remainder (and any prompt shorter than N) and any text
        # segment broken up by newlines.
        sources = self._split_into_chunks(original_prompt.strip(), max_length)
        if not sources:
            return {
                'compressed_prompt': "",
                'ratio': 0,
                'original_tokens': original_tokens,
                'compressed_tokens': 0,
            }

        summaries = self.model.predict(sources, self.tokenizer, self.device)
        # Chunk outputs are concatenated without a separator, as before.
        compressed_prompt = "".join(summaries)
        compressed_tokens = len(self.gpt_tokenizer.encode(compressed_prompt))
        return {
            'compressed_prompt': compressed_prompt,
            'ratio': compressed_tokens / original_tokens if original_tokens else 0,
            'original_tokens': original_tokens,
            'compressed_tokens': compressed_tokens,
        }
|