Layoutlm_Inkaso_2 / handler.py
Szczotar93's picture
Update handler.py
84522a1 verified
from typing import Dict, List, Any
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor
import torch
from subprocess import run
import pandas as pd
# install tesseract-ocr and pytesseract
# run("apt install -y tesseract-ocr", shell=True, check=True)
# run("pip install pytesseract", shell=True, check=True)
# helper function to unnormalize bboxes for drawing onto the image
def unnormalize_box(bbox, width, height):
    """Scale a box normalized to a 0-1000 range back to pixel coordinates.

    Args:
        bbox: indexable holding at least four numbers ``x0, y0, x1, y1``
            expressed on a 0-1000 scale; only indices 0-3 are read.
        width: image width in pixels (scales indices 0 and 2).
        height: image height in pixels (scales indices 1 and 3).

    Returns:
        list: the four scaled coordinates as floats (no rounding).

    NOTE(review): a second ``unnormalize_box`` defined later in this module
    (which truncates to ``int``) rebinds this name, so this version is
    shadowed at import time.
    """
    # x-coordinates scale with the width, y-coordinates with the height.
    axis_sizes = (width, height, width, height)
    return [axis_sizes[i] * (bbox[i] / 1000) for i in range(4)]
def predict(Image, processor, model, threshold=0.75):
    """Run LayoutLM token classification on one document image.

    Args:
        Image: the document image handed to ``processor(images=...)``.
            Presumably a PIL image; confirm against the endpoint's deserializer.
        processor: callable processor (a ``LayoutLMv2Processor`` in this
            module) whose output contains an ``"image"`` entry that the
            text+layout model does not consume.
        model: token-classification model; its ``device`` decides where the
            encoded inputs are placed.
        threshold: minimum token confidence (exclusive) forwarded to
            ``process_outputs``. Defaults to the previously hard-coded 0.75.

    Returns:
        tuple: ``(results, encoding)``. ``results`` is what ``process_outputs``
        returns. ``encoding`` is the processor output without its ``"image"``
        entry, placed on the model's device.
    """
    encoding = processor(
        images=Image,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
    )
    # LayoutLM (v1) takes only token ids + boxes, so drop the pixel input.
    # Do this before moving to the device so only tensors are transferred.
    encoding.pop("image", None)
    # The handler may have put the model on CUDA; the inputs must follow it,
    # otherwise the forward pass fails with a device mismatch.
    encoding = encoding.to(model.device)
    # Inference only: avoid building an autograd graph (saves memory/time).
    with torch.no_grad():
        outputs = model(**encoding)
    results = process_outputs(
        outputs,
        encoding=encoding,
        images=Image,
        model=model,
        processor=processor,
        threshold=threshold,
    )
    return results, encoding
def get_uniqueLabelList(labels):
    """Return the distinct entity names predicted for the first sequence.

    A label with a tagging prefix such as ``"B-NAME"`` or ``"I-NAME"`` is
    reduced to the text after its first hyphen (``"NAME"``). A label without a
    hyphen (e.g. ``"O"``) is kept unchanged. Order of first appearance is
    preserved.

    Args:
        labels: nested sequence of label strings; only ``labels[0]`` is read.

    Returns:
        list: unique entity names in first-seen order.
    """
    # A dict keeps insertion order and gives O(1) membership tests.
    unique = {}
    for label in labels[0]:
        # Split only on the first hyphen so names such as "B-DUE-DATE" become
        # "DUE-DATE" rather than being truncated to "DUE".
        _, hyphen, entity = label.partition("-")
        unique.setdefault(entity if hyphen else label, None)
    return list(unique)
def process_outputs(outputs, encoding, images, model, processor, threshold):
    """Turn raw token-classification logits into grouped entity predictions.

    Each token gets the label with the highest logit, mapped through
    ``model.config.id2label``. Its confidence is the highest softmax
    probability along the last axis. Both are handed to ``_process_outputs``
    for grouping.

    Args:
        outputs: model output exposing ``logits``; the last axis is treated as
            the label axis.
        encoding: processor output, forwarded unchanged.
        images: document image(s), forwarded unchanged.
        model: model whose ``config.id2label`` maps class indices to names.
        processor: processor whose ``tokenizer`` and itself are forwarded.
        threshold: confidence threshold forwarded to ``_process_outputs``.

    Returns:
        Whatever ``_process_outputs`` returns.
    """
    logits = outputs.logits
    id2label = model.config.id2label

    # Winning class per token, translated to its label name, row by row.
    best_index = logits.argmax(-1)
    label_rows = []
    for row in best_index:
        label_rows.append([id2label[class_idx.item()] for class_idx in row])

    # Probability of the winning class per token, as nested Python floats.
    confidence = logits.softmax(dim=-1).max(dim=-1).values.tolist()

    return _process_outputs(
        encoding=encoding,
        tokenizer=processor.tokenizer,
        processor=processor,
        labels=label_rows,
        scores=confidence,
        images=images,
        threshold=threshold,
    )
def _process_outputs(encoding, tokenizer, labels, scores, images, processor, threshold):
    """Group confident token predictions of the first sequence into entities.

    For every entity name returned by ``get_uniqueLabelList`` (except ``"O"``):
    - collect the tokens whose predicted label maps to exactly that name and
      whose score is strictly above ``threshold``;
    - decode those tokens into one string;
    - average their scores.

    Args:
        encoding: processor output; ``encoding.input_ids[0]`` holds the token
            ids of the sequence described by ``labels[0]`` / ``scores[0]``.
        tokenizer: tokenizer whose cls/sep/pad token ids are excluded from
            entities, since they carry no document text.
        labels: per-token label strings; only ``labels[0]`` is used.
        scores: per-token confidences aligned with ``labels``; only
            ``scores[0]`` is used.
        images: unused; kept for interface compatibility.
        processor: object whose ``decode`` turns token ids back into text.
        threshold: exclusive lower bound on a token's score.

    Returns:
        list: a single-element list holding one dict per entity. Each dict has:
        - ``"word"``: decoded text, or ``""`` when no token qualified;
        - ``"label"``: the entity name;
        - ``"score"``: the mean score as a 2-decimal string, ``"0.00"`` when
          no token qualified.
    """
    token_ids = encoding.input_ids[0].tolist()

    # Special tokens (including the max_length padding) are not document text.
    # Keep them out of both the decoded words and the mean score.
    skip_ids = {
        tok_id
        for tok_id in (
            tokenizer.cls_token_id,
            tokenizer.sep_token_id,
            tokenizer.pad_token_id,
        )
        if tok_id is not None
    }

    # Reuse get_uniqueLabelList's prefix stripping for each distinct label, so
    # token grouping always agrees with the entity list built below. Exact
    # equality (not a substring test) keeps e.g. "NAME" from absorbing
    # "B-SURNAME".
    entity_of = {pred: get_uniqueLabelList([[pred]])[0] for pred in set(labels[0])}

    # Single pass over the tokens: entity name -> [(token id, score), ...].
    grouped = {}
    for pred, score, tok_id in zip(labels[0], scores[0], token_ids):
        if score > threshold and tok_id not in skip_ids:
            grouped.setdefault(entity_of[pred], []).append((tok_id, score))

    entities = []
    for label in get_uniqueLabelList(labels):
        if label == "O":
            continue
        members = grouped.get(label, [])
        ids = [tok_id for tok_id, _ in members]
        mean = sum(score for _, score in members) / len(members) if members else 0.0
        entities.append(
            {
                "word": processor.decode(ids),
                "label": label,
                # Always a string so the response type is consistent.
                "score": f"{mean:.2f}",
            }
        )
    return [entities]
def unnormalize_box(bbox, width, height):
    """Convert a 0-1000-normalized box into integer pixel coordinates.

    Each coordinate is scaled by the matching image dimension and truncated
    with ``int``.

    Args:
        bbox: indexable holding at least four numbers ``x0, y0, x1, y1`` on a
            0-1000 scale; only indices 0-3 are read.
        width: image width in pixels.
        height: image height in pixels.

    Returns:
        list: four ints ``[x0, y0, x1, y1]`` in pixels.

    NOTE(review): this definition rebinds an earlier ``unnormalize_box`` in
    this module (the float-returning one); this int version is the one in
    effect after import.
    """
    left, top, right, bottom = bbox[0], bbox[1], bbox[2], bbox[3]
    scaled = (
        width * (left / 1000),
        height * (top / 1000),
        width * (right / 1000),
        height * (bottom / 1000),
    )
    return [int(value) for value in scaled]
def get_image_from_url(Image):
    """Return ``Image`` converted to RGB mode.

    The RGB conversion is kept because ``LayoutLMv2Processor`` requires RGB
    input.

    Callers must pass an already-loaded image object exposing
    ``convert(mode)`` (e.g. a ``PIL.Image.Image``).

    NOTE(review): despite the name, no URL download happens here. The
    previous body called ``Image.open(f)`` with ``f`` never defined, so every
    call raised ``NameError``. Extend this (and rename the parameter) if URL
    loading is really needed.

    Args:
        Image: image object with a ``convert`` method.

    Returns:
        The result of ``Image.convert("RGB")``.
    """
    return Image.convert("RGB")
# Choose the compute device once, at import time: the default CUDA device
# when PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class EndpointHandler:
    """Inference-endpoint wrapper around a LayoutLM token-classification model."""

    def __init__(self, path=""):
        """Load the model and processor from ``path``.

        Args:
            path: directory or hub id passed to ``from_pretrained``.
        """
        # Place the weights on the module-level ``device`` (CUDA when available).
        self.model = LayoutLMForTokenClassification.from_pretrained(path).to(device)
        # apply_ocr=True: the processor extracts words and boxes itself via OCR.
        self.processor = LayoutLMv2Processor.from_pretrained(path, apply_ocr=True)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
        """Run prediction on one request payload.

        Args:
            data: request payload. The document image is read from
                ``data["inputs"]``; when that key is absent, ``data`` itself
                is passed on as the image. Presumably a deserialized PIL
                image; confirm against the endpoint's deserializer.

        Returns:
            dict: ``{"predictions": results}``, where ``results`` is the
            entity list produced by ``predict``.
        """
        # Read without mutating the caller's payload (pop would remove the key).
        image = data.get("inputs", data)
        results, _ = predict(image, self.processor, self.model)
        return {"predictions": results}