# Donut_Receipt_v2 / handler.py
# Last commit: d75ed94 "Update handler.py" (by Jeney)
# Page links: raw | history | blame | contribute | delete
# Scan status: No virus
# File size: 2.13 kB
import ast
import io
import re
from typing import Any, Dict

import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel
class EndpointHandler:
    """Inference Endpoint handler that parses document images with Donut (CORD-v2).

    A request dict carries an image payload; the handler decodes it, runs the
    Donut vision encoder-decoder and returns the structure produced by
    ``DonutProcessor.token2json``.
    """

    # Hub checkpoint the handler loads. NOTE: the ``path`` argument given to
    # ``__init__`` is NOT used for loading; weights always come from this id.
    MODEL_ID = "naver-clova-ix/donut-base-finetuned-cord-v2"

    # Decoder prompt that selects the CORD-v2 receipt-parsing task.
    TASK_PROMPT = "<s_cord-v2>"

    def __init__(self, path: str = ""):
        """Load the processor and model and place the model on the chosen device.

        Args:
            path: Repository path supplied by the endpoint runtime. Currently
                unused; the checkpoint is always loaded from ``MODEL_ID``.
        """
        self.processor = DonutProcessor.from_pretrained(self.MODEL_ID)
        self.model = VisionEncoderDecoderModel.from_pretrained(self.MODEL_ID)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # process_document moves its input tensors to self.device, so the model
        # must live there too; otherwise generation fails on GPU hosts.
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Args:
            data: Request payload. The image is read from ``data["inputs"]["image"]``,
                or from ``data["image"]`` when there is no ``"inputs"`` key.

        Returns:
            The parsed document as returned by ``process_document``.

        Raises:
            KeyError: if the payload has no ``"image"`` entry.
            TypeError: if the image payload is neither bytes nor a string.
            ValueError: if a string payload is not a Python bytes literal.
        """
        # ``get`` rather than ``pop`` so the caller's dict is left untouched.
        inputs = data.get("inputs", data)
        image_bytes = self._to_bytes(inputs["image"])
        # Force 3 channels so RGBA / palette / grayscale uploads are handled uniformly.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        return self.process_document(image)

    @staticmethod
    def _to_bytes(payload: Any) -> bytes:
        """Return raw image bytes from an incoming image payload.

        ``bytes``/``bytearray`` are returned as bytes unchanged. A string must
        contain a Python bytes literal such as ``"b'...'"`` (the format the
        handler has always accepted).

        Raises:
            TypeError: if ``payload`` is neither bytes-like nor a string.
            ValueError: if the string is not a literal that evaluates to bytes.
        """
        if isinstance(payload, (bytes, bytearray)):
            return bytes(payload)
        if not isinstance(payload, str):
            raise TypeError(f"unsupported image payload type: {type(payload).__name__}")
        # SECURITY: this string comes straight from the HTTP request. It used to
        # be passed to eval(), which executes arbitrary client expressions;
        # ast.literal_eval only accepts Python literals.
        try:
            value = ast.literal_eval(payload)
        except (ValueError, SyntaxError, TypeError) as err:
            raise ValueError("image payload is not a valid Python bytes literal") from err
        if not isinstance(value, bytes):
            raise ValueError(f"image payload must be a bytes literal, got {type(value).__name__}")
        return value

    def process_document(self, image: "Image.Image") -> Dict[str, Any]:
        """Run Donut on ``image`` and return the parsed fields.

        Args:
            image: PIL image of the document to parse.

        Returns:
            The result of ``DonutProcessor.token2json`` on the generated sequence.
        """
        tokenizer = self.processor.tokenizer

        # Encoder input: pixel tensor produced by the processor.
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # Decoder input: the task prompt from which generation continues.
        decoder_input_ids = tokenizer(
            self.TASK_PROMPT, add_special_tokens=False, return_tensors="pt"
        ).input_ids

        outputs = self.model.generate(
            pixel_values.to(self.device),
            decoder_input_ids=decoder_input_ids.to(self.device),
            max_length=self.model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            # Forbid the unknown token in the output.
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

        sequence = self.processor.batch_decode(outputs.sequences)[0]
        sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
        # Remove only the first <...> tag: the task start token echoed by the decoder.
        sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
        return self.processor.token2json(sequence)