nougat-api / handler.py
kevin-pek
test without postprocessing
835c1c9
from collections import defaultdict
from io import BytesIO
from typing import Dict, Any

import torch
from PIL import Image
from transformers import NougatProcessor, VisionEncoderDecoderModel, StoppingCriteria, StoppingCriteriaList
from transformers.image_utils import base64
class RunningVarTorch:
    """Track the variance of a stream of 1-D tensors over a sliding window.

    Every pushed tensor is stored as a new column of ``values``; once ``L``
    columns are held, the oldest column is dropped before appending.

    Args:
        L: maximum number of columns (pushes) kept in the window.
        norm: if true, ``variance()`` is additionally divided by the number of
            columns currently in the window.
    """

    def __init__(self, L=15, norm=False):
        # Column-stacked history, shape (len(x), <= L); None until the first push.
        self.values = None
        self.L = L
        self.norm = norm

    def push(self, x: torch.Tensor):
        """Append ``x`` as the newest column, evicting the oldest when the window is full."""
        assert x.dim() == 1
        column = x.unsqueeze(1)
        if self.values is None:
            self.values = column
            return
        history = self.values
        if history.shape[1] >= self.L:
            # Window is full: discard the oldest column first.
            history = history[:, 1:]
        self.values = torch.cat((history, column), 1)

    def variance(self):
        """Return the per-row variance across the window, or None if nothing was pushed."""
        if self.values is None:
            return None
        var = torch.var(self.values, 1)
        return var / self.values.shape[1] if self.norm else var
class StoppingCriteriaScores(StoppingCriteria):
    """Stop generation once each sequence's top-score signal stops fluctuating.

    At every step the highest next-token score of each sequence is pushed into
    a short running-variance tracker (window 15, length-normalised), and that
    variance is pushed into a second tracker of length ``window_size``. The
    variance of the variance ("varvar") is compared against ``threshold``;
    sequences whose varvar stays below it are flagged as stopped, and generation
    halts once every tracked sequence is flagged. Presumably this is meant to
    catch repetitive/degenerate output — TODO confirm.

    Instances accumulate state (``size``, ``stop_inds``, ``stopped``) and never
    reset it, so create a fresh instance for every ``generate()`` call.

    Args:
        threshold: varvar value below which a sequence counts as stagnant.
        window_size: number of steps before any stop decision is made; also the
            window length of the varvar tracker.
    """

    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        self.threshold = threshold
        self.vars = RunningVarTorch(norm=True)
        self.varvars = RunningVarTorch(L=window_size)
        # Per-batch-index bookkeeping; missing keys default to 0 / False.
        self.stop_inds = defaultdict(int)
        self.stopped = defaultdict(bool)
        self.size = 0
        self.window_size = window_size

    @torch.no_grad()
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Update the score statistics and report whether generation should stop.

        Args:
            input_ids: tokens generated so far (unused here).
            scores: indexed with ``[-1]`` and reduced over dim 1 below, so this
                expects a sequence of per-step (batch, vocab) score tensors.
                NOTE(review): presumably what ``generate()`` passes when
                ``output_scores=True`` and ``return_dict_in_generate=True`` —
                confirm for the installed transformers version.
            **kwargs: extra arguments forwarded by ``StoppingCriteriaList``;
                accepted for interface compatibility and ignored.

        Returns:
            True once every tracked batch index is flagged as stopped.
        """
        last_scores = scores[-1]
        # Highest score per sequence at this step, as float32 on the CPU.
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        if self.size < self.window_size:
            # Not enough history yet to judge stability.
            return False

        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    # NOTE(review): stop_inds[b] is armed to a value above the
                    # current size (unless capped at 4095), so this flags the
                    # sequence as stopped on the second consecutive low-varvar
                    # step. The comparison reads as if it might be inverted
                    # (size >= stop_inds[b]?) — confirm intent before changing.
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    # First low-varvar step (or already stopped): arm stop index.
                    self.stop_inds[b] = int(
                        min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095)
                    )
            else:
                # Score signal is fluctuating again: clear any pending stop.
                self.stop_inds[b] = 0
                self.stopped[b] = False
        return all(self.stopped.values()) and len(self.stopped) > 0
class EndpointHandler():
    """Inference-endpoint handler that transcribes a page image with Nougat.

    Args:
        path: model id or local directory passed to ``from_pretrained`` for both
            the processor and the vision-encoder-decoder model.
    """

    def __init__(self, path="facebook/nougat-small") -> None:
        self.processor = NougatProcessor.from_pretrained(path)
        self.model = VisionEncoderDecoderModel.from_pretrained(path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Transcribe a single image and return the post-processed text.

        Args:
            data: request payload. ``data["inputs"]`` must be either a
                base64-encoded image (``str``/``bytes``) or an already-decoded
                ``PIL.Image.Image``. The ``"inputs"`` key is popped from ``data``.

        Returns:
            The text decoded from the first generated sequence, passed through
            ``post_process_generation(..., fix_markdown=False)``.

        Raises:
            ValueError: if no image payload is present under ``"inputs"``.
        """
        image = data.pop("inputs", data)
        if isinstance(image, dict):
            # pop() fell back to the payload dict itself: no image was sent.
            raise ValueError("expected an 'inputs' field containing a base64-encoded image")
        if isinstance(image, Image.Image):
            image_data = image
        else:
            image_data = Image.open(BytesIO(base64.b64decode(image)))
        # Greyscale (2-D) or RGBA/palette images would make the image processor
        # fail to infer the channel layout; normalise to 3-channel RGB.
        image_data = image_data.convert("RGB")

        pixel_values = self.processor(image_data, return_tensors="pt").pixel_values
        outputs = self.model.generate(
            pixel_values.to(self.device),
            min_length=1,
            max_length=3584,
            bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
            # Scores are required by StoppingCriteriaScores, which reads scores[-1].
            return_dict_in_generate=True,
            output_scores=True,
            # A fresh instance per request, since the criterion is stateful.
            stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()]),
        )
        text = self.processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
        # Return the decoded text (the declared contract), not the raw, non-JSON-
        # serialisable generate() output object.
        return self.processor.post_process_generation(text, fix_markdown=False)