|
from typing import Dict, List, Any |
|
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor |
|
import torch |
|
from subprocess import run |
|
import pandas as pd |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def unnormalize_box(bbox, width, height):
    """Scale a box from a 0-1000 normalised space back to pixel coordinates.

    Args:
        bbox: Sequence whose first four items are ``x0, y0, x1, y1`` in the
            0-1000 range.
        width: Target width in pixels (used for the x coordinates).
        height: Target height in pixels (used for the y coordinates).

    Returns:
        list: The four scaled coordinates, not rounded.

    NOTE(review): this module defines ``unnormalize_box`` a second time further
    down (an integer-truncating variant); that later definition replaces this
    one when the module is imported.
    """
    # x coordinates scale with the width, y coordinates with the height.
    axis_sizes = (width, height, width, height)
    return [size * (bbox[pos] / 1000) for pos, size in enumerate(axis_sizes)]
|
|
|
def predict(Image, processor, model, threshold=0.75):
    """Run LayoutLM token classification on one document image.

    The processor builds the token / bounding-box encoding from the image
    (EndpointHandler loads it with ``apply_ocr=True``). The model scores every
    token, and ``process_outputs`` groups the confident tokens into entities.

    Args:
        Image: The document image, passed to ``processor`` as ``images``.
            Presumably a PIL image, as described by EndpointHandler; confirm.
        processor: A ``LayoutLMv2Processor``-compatible callable that also
            exposes ``tokenizer`` and ``decode``.
        model: A ``LayoutLMForTokenClassification`` model.
        threshold: Tokens must score strictly above this to be kept. Defaults
            to 0.75, the value previously hard-coded here.

    Returns:
        tuple: ``(results, encoding)``. ``results`` is the grouped entity
        output of ``process_outputs``. ``encoding`` is the processor output
        that was fed to the model, without its ``"image"`` entry.
    """
    encoding = processor(
        images=Image,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
    )
    # LayoutLM (v1) has no visual branch, so the pixel tensor added by the
    # LayoutLMv2 processor is not an accepted model input.
    del encoding["image"]

    # EndpointHandler moves the model to CUDA when available, while the
    # processor always returns CPU tensors. Feeding CPU inputs to a CUDA model
    # fails, so move the encoding to wherever the model lives.
    model_device = getattr(model, "device", None)
    if model_device is not None:
        encoding = encoding.to(model_device)

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        outputs = model(**encoding)

    results = process_outputs(
        outputs,
        encoding=encoding,
        images=Image,
        model=model,
        processor=processor,
        threshold=threshold,
    )
    return results, encoding
|
def get_uniqueLabelList(labels):
    """Return the distinct entity names used in the first sequence of ``labels``.

    A tag containing ``-`` is reduced to the text between its first and second
    ``-`` (``"B-HEADER"`` -> ``"HEADER"``). A tag without ``-`` (e.g. ``"O"``)
    is kept unchanged.

    NOTE(review): because only the second ``-``-separated piece is kept, a tag
    such as ``"B-TOTAL-AMOUNT"`` becomes ``"TOTAL"``. Confirm that the model's
    label names never contain ``-``.

    Args:
        labels: Nested sequence of tag strings. Only ``labels[0]`` is read.

    Returns:
        list: Entity names in order of first appearance, without duplicates.
        Empty when ``labels`` is empty.
    """
    if not labels:
        return []

    unique_labels = []
    for label in labels[0]:
        try:
            short_label = label.split("-")[1]
        except (AttributeError, IndexError):
            # No "-" in the tag (IndexError) or not a string (AttributeError):
            # use the tag as-is.
            short_label = label
        if short_label not in unique_labels:
            unique_labels.append(short_label)
    return unique_labels
|
|
|
def process_outputs(outputs, encoding, images, model, processor, threshold):
    """Convert raw token-classification logits into grouped entity predictions.

    For every token this takes the highest softmax probability (its score)
    and the arg-max class id mapped through ``model.config.id2label`` (its tag
    name). Grouping is then delegated to ``_process_outputs``.

    Args:
        outputs: Model output exposing ``logits``.
        encoding: The encoding the model was run on.
        images: Forwarded unchanged to ``_process_outputs``.
        model: Supplies ``config.id2label``.
        processor: Supplies ``tokenizer`` and is forwarded for decoding.
        threshold: Forwarded score threshold.

    Returns:
        Whatever ``_process_outputs`` produces for these tags and scores.
    """
    logits = outputs.logits
    probabilities = logits.softmax(dim=-1)
    token_scores = torch.max(probabilities, dim=-1).values.tolist()

    id2label = model.config.id2label
    token_tags = []
    for row in logits.argmax(-1):
        token_tags.append([id2label[class_id.item()] for class_id in row])

    return _process_outputs(
        encoding=encoding,
        tokenizer=processor.tokenizer,
        processor=processor,
        labels=token_tags,
        scores=token_scores,
        images=images,
        threshold=threshold,
    )
|
|
|
def _process_outputs(encoding, tokenizer, labels, scores, images, processor, threshold):
    """Group confidently-tagged tokens into one entity entry per label.

    For every distinct entity name other than ``"O"`` (as returned by
    ``get_uniqueLabelList``), collect the tokens of the first sequence that
    meet three conditions:

    * their tag maps to that name;
    * their score is strictly above ``threshold``;
    * they are not [CLS]/[SEP]/[PAD] tokens.

    The collected tokens are decoded into one string, and their scores are
    averaged.

    Args:
        encoding: Encoding the model ran on. ``encoding.input_ids[0]`` is read.
        tokenizer: Tokenizer behind ``encoding``. Its cls/sep/pad ids identify
            tokens that are not document text.
        labels: Per-token tag strings. ``labels[0][i]`` belongs to token ``i``.
        scores: Per-token confidences. ``scores[0][i]`` belongs to token ``i``.
        images: Unused; kept so existing callers keep working.
        processor: Provides ``decode`` to turn token ids back into text.
        threshold: Minimum (exclusive) token score.

    Returns:
        list: A single-element list holding a list of dicts
        ``{"word": str, "label": str, "score": str}``. ``score`` is the mean
        token score formatted with two decimals, or ``"0.00"`` when no token
        qualified.
    """
    token_ids = encoding.input_ids[0].tolist()
    token_tags = labels[0]
    token_scores = scores[0]

    # Special tokens also receive predictions; with padding="max_length" (as
    # used by predict) most of the sequence can be [PAD]. They are not
    # document text, so keep them out of both the word and the mean score.
    non_text_ids = {
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.pad_token_id,
    }

    # Reduce each token's tag with the same rule that builds the unique-label
    # list, so matching is exact. A substring test would let e.g. "NAME" also
    # claim "B-COMPANY_NAME" tokens.
    token_entities = [get_uniqueLabelList([[tag]])[0] for tag in token_tags]

    entities = []
    for entity in get_uniqueLabelList(labels):
        if entity == "O":
            continue

        selected = [
            ix
            for ix, name in enumerate(token_entities)
            if name == entity
            and token_scores[ix] > threshold
            and token_ids[ix] not in non_text_ids
        ]

        if selected:
            mean_score = sum(token_scores[ix] for ix in selected) / len(selected)
            score = f"{mean_score:.2f}"
        else:
            # Same string type as the success path (previously a bare float).
            score = "0.00"

        entities.append(
            {
                "word": processor.decode([token_ids[ix] for ix in selected]),
                "label": entity,
                "score": score,
            }
        )

    return [entities]
|
|
|
def unnormalize_box(bbox, width, height):
    """Map a 0-1000 normalised box to integer pixel coordinates.

    Each coordinate is scaled by ``width`` (x) or ``height`` (y) divided by
    1000, then truncated toward zero with ``int``.

    Args:
        bbox: Sequence whose first four items are ``x0, y0, x1, y1``.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        list[int]: ``[x0, y0, x1, y1]`` in pixels.
    """
    scaled = []
    for pos, extent in enumerate((width, height, width, height)):
        scaled.append(int(extent * (bbox[pos] / 1000)))
    return scaled
|
def get_image_from_url(Image, f=None, timeout=30):
    """Open an image through a PIL-style ``Image`` module and convert it to RGB.

    Args:
        Image: Module or object exposing ``open`` (e.g. ``PIL.Image``).
        f: Image source. Either anything ``Image.open`` accepts (a path or a
            binary file object), or an ``http://`` / ``https://`` URL string,
            which is downloaded into memory first.
        timeout: Seconds to wait when downloading a URL.

    Returns:
        The opened image, converted to ``"RGB"``.

    Raises:
        TypeError: If no source ``f`` is provided.
    """
    # The source used to be an undefined name, so every call raised
    # NameError. Require it explicitly and fail with a clear message instead.
    if f is None:
        raise TypeError("get_image_from_url() missing required argument: 'f'")

    if isinstance(f, str) and f.lower().startswith(("http://", "https://")):
        from io import BytesIO
        from urllib.request import urlopen

        # NOTE: fetches whatever URL it is given. Validate or allow-list URLs
        # in callers if they can come from untrusted input.
        with urlopen(f, timeout=timeout) as response:
            f = BytesIO(response.read())

    return Image.open(f).convert("RGB")
|
|
|
# Device the endpoint model is placed on: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
class EndpointHandler:
    """Callable inference endpoint around a LayoutLM token-classification model.

    The model and an OCR-enabled ``LayoutLMv2Processor`` are loaded once from
    the same checkpoint path. Each call runs ``predict`` on the request image.
    """

    def __init__(self, path=""):
        # Load the weights, then place them on the module-level ``device``.
        classifier = LayoutLMForTokenClassification.from_pretrained(path)
        self.model = classifier.to(device)
        # apply_ocr=True: the processor extracts words and boxes from the
        # image itself.
        self.processor = LayoutLMv2Processor.from_pretrained(path, apply_ocr=True)

    def __call__(self, data: Dict[str, bytes]) -> Dict[str, List[Any]]:
        """Extract entities from a single request payload.

        Args:
            data: Request payload. The image is popped from its ``"inputs"``
                key; when that key is missing, the payload itself is used.
                Expected to hold a deserialized PIL image.
                NOTE(review): the ``bytes`` annotation does not match that
                description; confirm the real payload type.

        Returns:
            ``{"predictions": ...}`` holding the first element returned by
            ``predict``.
        """
        document = data.pop("inputs", data)
        predictions, _ = predict(document, self.processor, self.model)
        return {"predictions": predictions}