import base64
from collections import defaultdict
from io import BytesIO
from typing import Any, Dict

import torch
from PIL import Image
from transformers import (
    NougatProcessor,
    StoppingCriteria,
    StoppingCriteriaList,
    VisionEncoderDecoderModel,
)
|
|
class RunningVarTorch:
    """Keeps a sliding window of the last `L` pushed values per batch element and reports their variance."""

    def __init__(self, L=15, norm=False):
        self.values = None
        self.L = L
        self.norm = norm

    def push(self, x: torch.Tensor):
        assert x.dim() == 1
        if self.values is None:
            self.values = x[:, None]
        elif self.values.shape[1] < self.L:
            self.values = torch.cat((self.values, x[:, None]), 1)
        else:
            # Window is full: drop the oldest column and append the new value.
            self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)

    def variance(self):
        if self.values is None:
            return
        if self.norm:
            return torch.var(self.values, 1) / self.values.shape[1]
        else:
            return torch.var(self.values, 1)
|
|
class StoppingCriteriaScores(StoppingCriteria):
    """Heuristic stopping criterion: tracks the variance of the running variance of the
    per-step max logits and stops once it stays below `threshold`, which signals that the
    decoder has fallen into a repetition loop."""

    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        self.threshold = threshold
        self.vars = RunningVarTorch(norm=True)
        self.varvars = RunningVarTorch(L=window_size)
        self.stop_inds = defaultdict(int)
        self.stopped = defaultdict(bool)
        self.size = 0
        self.window_size = window_size

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_scores = scores[-1]
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        if self.size < self.window_size:
            return False

        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    # Schedule a stop index a bit further ahead, capped at 4095 steps.
                    self.stop_inds[b] = int(
                        min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
                    )
            else:
                # Variance recovered: clear any pending stop for this batch element.
                self.stop_inds[b] = 0
                self.stopped[b] = False
        return all(self.stopped.values()) and len(self.stopped) > 0
|
|
class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints wrapping the Nougat model."""

    def __init__(self, path="facebook/nougat-small") -> None:
        self.processor = NougatProcessor.from_pretrained(path)
        self.model = VisionEncoderDecoderModel.from_pretrained(path)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> str:
        # The request payload carries a base64-encoded page image under "inputs".
        image = data.pop("inputs", data)
        image_data = Image.open(BytesIO(base64.b64decode(image)))
        pixel_values = self.processor(image_data, return_tensors="pt").pixel_values

        outputs = self.model.generate(
            pixel_values.to(self.device),
            min_length=1,
            max_length=3584,
            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            output_scores=True,
            stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),
        )

        text = self.processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
        text = self.processor.post_process_generation(text, fix_markdown=False)

        return text
|
|
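# Illustrative local smoke test (not part of the Inference Endpoints request flow):
# a minimal sketch that assumes a local image file named "page.png" exists and mimics
# the base64 payload the deployed endpoint would receive.
if __name__ == "__main__":
    with open("page.png", "rb") as f:
        payload = {"inputs": base64.b64encode(f.read()).decode("utf-8")}
    handler = EndpointHandler()
    print(handler(payload))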