Layoutlm_Inkaso_2 / handler.py
Szczotar93's picture
Update handler.py
ecabb16 verified
raw
history blame
4.8 kB
from typing import Dict, List, Any
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor
import torch
from subprocess import run
import pandas as pd
# install tesseract-ocr and pytesseract
# run("apt install -y tesseract-ocr", shell=True, check=True)
# run("pip install pytesseract", shell=True, check=True)
# helper function to unnormalize bboxes for drawing onto the image
def unnormalize_box(bbox, width, height):
    """Scale a box from the 0-1000 normalized grid to integer pixel coordinates.

    A second definition of ``unnormalize_box`` further down this file replaces
    this one at import time and truncates to ``int``. This definition does the
    same so the two can never disagree.

    Args:
        bbox: Indexable with at least four numeric items. Items 0 and 2 are
            scaled by ``width``, items 1 and 3 by ``height``.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        list[int]: The four scaled coordinates, truncated toward zero.
    """
    # Pair each coordinate index with the image extent it is scaled by.
    extents = (width, height, width, height)
    return [int(extent * (bbox[i] / 1000)) for i, extent in enumerate(extents)]
def predict(Image, processor, model, threshold=0.75):
    """Run LayoutLM token classification on a document image.

    Args:
        Image: The document image, passed to ``processor`` as ``images``
            (presumably a ``PIL.Image`` -- confirm against the caller).
        processor: Callable processor that returns a tensor encoding; it is
            also forwarded to ``process_outputs`` for decoding.
        model: Token-classification model called with the encoding.
        threshold: Minimum per-token confidence, forwarded to
            ``process_outputs``. Defaults to the previously hard-coded 0.75.

    Returns:
        tuple: ``(results, encoding)`` where ``results`` is whatever
        ``process_outputs`` returns and ``encoding`` is the processor output
        (without the ``"image"`` entry, moved to the model's device).
    """
    encoding = processor(
        images=Image,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
    )
    # LayoutLM doesn't require the image; tolerate processors that omit it.
    encoding.pop("image", None)
    # The model may live on the GPU while the processor emits CPU tensors;
    # a device mismatch would make the forward pass raise.
    encoding = encoding.to(model.device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**encoding)

    results = process_outputs(
        outputs,
        encoding=encoding,
        images=Image,
        model=model,
        processor=processor,
        threshold=threshold,
    )
    return results, encoding
def get_uniqueLabelList(labels):
    """Collect the distinct entity names among the labels of the first sequence.

    A label containing ``-`` (for example ``B-NAME``) contributes the segment
    after the first dash; any other label (for example ``O``) contributes
    itself. First-seen order is preserved.

    Args:
        labels: List of per-sequence label lists; only ``labels[0]`` is read.

    Returns:
        list: Unique entity names in order of first appearance. Empty when
        ``labels`` is empty.
    """
    if not labels:
        return []

    unique_labels = []
    for label in labels[0]:
        # Non-string labels are kept as-is rather than raising.
        parts = label.split("-") if isinstance(label, str) else [label]
        # NOTE(review): only the segment right after the first dash is kept,
        # so "B-FOO-BAR" yields "FOO" -- confirm this is intended.
        short = parts[1] if len(parts) > 1 else label
        if short not in unique_labels:
            unique_labels.append(short)
    return unique_labels
def process_outputs(outputs, encoding, images, model, processor, threshold):
    """Turn raw model outputs into label strings and confidences, then group them.

    Args:
        outputs: Model output exposing a ``logits`` tensor.
        encoding: Processor encoding, forwarded to ``_process_outputs``.
        images: Input image, forwarded to ``_process_outputs``.
        model: Model whose ``config.id2label`` maps class ids to label strings.
        processor: Processor; it and its ``tokenizer`` are forwarded.
        threshold: Confidence cut-off, forwarded to ``_process_outputs``.

    Returns:
        Whatever ``_process_outputs`` returns.
    """
    logits = outputs.logits

    # Confidence of a token is its highest softmax probability.
    probabilities = logits.softmax(dim=-1)
    confidences = probabilities.max(dim=-1).values.tolist()

    # Translate each predicted class id into its label string.
    id2label = model.config.id2label
    label_names = [
        [id2label[class_id] for class_id in sequence]
        for sequence in logits.argmax(-1).tolist()
    ]

    return _process_outputs(
        encoding=encoding,
        tokenizer=processor.tokenizer,
        processor=processor,
        labels=label_names,
        scores=confidences,
        images=images,
        threshold=threshold,
    )
def _process_outputs(encoding, tokenizer, labels, scores, images, processor, threshold):
    """Group the confident tokens of the first sequence into one entity per label.

    For every entity name other than ``"O"``, the positions whose predicted
    label has that entity name and whose score is strictly above
    ``threshold`` are collected, decoded together, and reported with their
    mean score.

    Args:
        encoding: Processor output; ``encoding.input_ids[0]`` is indexed with
            the selected token positions.
        tokenizer: Unused; kept so existing callers keep working.
        labels: Per-sequence lists of label strings; only ``labels[0]`` is read.
        scores: Per-sequence lists of token confidences; only ``scores[0]``
            is read.
        images: Unused; kept so existing callers keep working.
        processor: Object whose ``decode`` turns a list of token ids into text.
        threshold: Tokens scoring at or below this value are ignored.

    Returns:
        list: A one-element list holding the list of entity dicts with keys
        ``"word"``, ``"label"`` and ``"score"``. ``"score"`` is a string
        formatted to two decimals, or the float ``0.0`` when no token
        qualified (mixed types retained for output compatibility).
    """
    token_labels = labels[0]
    token_scores = scores[0]
    # Compare entity names exactly. A substring test would make "NAME" also
    # claim tokens tagged "B-SURNAME".
    token_entities = [_entity_name(tag) for tag in token_labels]

    entities = []
    for label in get_uniqueLabelList(labels):
        if label == "O":
            continue

        token_indices = []
        score_sum = float(0)
        for ix, entity in enumerate(token_entities):
            if token_scores[ix] > threshold and entity == label:
                score_sum += token_scores[ix]
                token_indices.append(ix)

        if token_indices:
            score_mean = f'{score_sum/len(token_indices):.2f}'
        else:
            # No token cleared the threshold for this label.
            score_mean = 0.0

        entities.append(
            {
                "word": processor.decode(encoding.input_ids[0][token_indices].tolist()),
                "label": label,
                "score": score_mean,
            }
        )

    return [entities]


def _entity_name(tag):
    """Return the entity part of a tag, using the same rule as get_uniqueLabelList."""
    if not isinstance(tag, str):
        return tag
    parts = tag.split("-")
    return parts[1] if len(parts) > 1 else tag
def unnormalize_box(bbox, width, height):
    """Map a box from the 0-1000 normalized grid onto integer pixel coordinates.

    Args:
        bbox: Indexable with at least four numeric items. Items 0 and 2 are
            scaled by ``width``, items 1 and 3 by ``height``.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        list[int]: The four scaled coordinates, truncated toward zero.
    """
    first_x = int(width * (bbox[0] / 1000))
    first_y = int(height * (bbox[1] / 1000))
    second_x = int(width * (bbox[2] / 1000))
    second_y = int(height * (bbox[3] / 1000))
    return [first_x, first_y, second_x, second_y]
def get_image_from_url(Image):
    """Return ``Image`` converted to RGB.

    The earlier body evaluated ``Image.open(f)`` with ``f`` undefined, so it
    could never succeed. This version accepts an already-loaded image object,
    i.e. anything exposing a callable ``convert(mode)``.

    Args:
        Image: Image object with a ``convert`` method (presumably a
            ``PIL.Image.Image`` -- confirm against callers).

    Returns:
        The result of ``Image.convert("RGB")``.

    Raises:
        TypeError: If ``Image`` has no callable ``convert`` attribute.
    """
    convert = getattr(Image, "convert", None)
    if not callable(convert):
        raise TypeError(
            f"expected an image object with a convert() method, got {type(Image).__name__}"
        )
    return convert("RGB")  # LayoutLMv2Processor requires RGB format
# Select the compute device once at import time: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class EndpointHandler:
    """Inference-endpoint entry point wrapping a LayoutLM token classifier."""

    def __init__(self, path=""):
        """Load the model (moved to the module-level ``device``) and the processor.

        Args:
            path: Location passed to both ``from_pretrained`` calls.
        """
        classifier = LayoutLMForTokenClassification.from_pretrained(path)
        self.model = classifier.to(device)
        self.processor = LayoutLMv2Processor.from_pretrained(path)

    def __call__(self, data: Dict[str, bytes]) -> Dict[str, List[Any]]:
        """Run prediction on one request payload.

        Args:
            data (:obj:):
                Request payload. Its ``"inputs"`` entry is removed and used as
                the image; when the key is absent, ``data`` itself is used.

        Returns:
            Dict with a single ``"predictions"`` key holding the first value
            returned by ``predict``.
        """
        # Note: pop() removes "inputs" from the caller's dict.
        document = data.pop("inputs", data)
        predictions, _encoding = predict(document, self.processor, self.model)
        return {"predictions": predictions}