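"""Custom handler for a Hugging Face Inference Endpoints-style deployment.

Runs OCR + LayoutLM token classification on a document image and aggregates
per-token predictions into labeled entities with confidence scores.
"""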
from typing import Any, Dict, List

import requests
import torch
from PIL import Image
from transformers import LayoutLMForTokenClassification, LayoutLMv2Processor

# The processor's built-in OCR (apply_ocr=True) relies on Tesseract; the
# endpoint image needs both the binary and the Python bindings, e.g.:
#   apt install -y tesseract-ocr
#   pip install pytesseract
# Helper to map LayoutLM's 0-1000 normalized bounding boxes back to pixel
# coordinates (useful for drawing boxes onto the original image).
def unnormalize_box(bbox, width, height):
    return [
        int(width * (bbox[0] / 1000)),
        int(height * (bbox[1] / 1000)),
        int(width * (bbox[2] / 1000)),
        int(height * (bbox[3] / 1000)),
    ]
def predict(image, processor, model):
    """Run OCR and LayoutLM inference on a single document image.

    Args:
        image (PIL.Image.Image): Document image in RGB format.

    Returns:
        Tuple of (extracted entities, the processor encoding).
    """
    # For URL batches: images = [get_image_from_url(url) for url in urls]
    encoding = processor(
        images=image,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
    )
    del encoding["image"]  # LayoutLM (v1) doesn't take the image tensor
    # Move the encoding to the model's device before the forward pass.
    encoding = encoding.to(device)
    with torch.no_grad():
        outputs = model(**encoding)
    results = process_outputs(
        outputs,
        encoding=encoding,
        model=model,
        processor=processor,
        threshold=0.75,
    )
    return results, encoding
def get_uniqueLabelList(labels):
    """Collect unique entity types from BIO tags, stripping the B-/I- prefixes.

    Only the first (and only) batch element is inspected.
    """
    unique_labels = []
    for label in labels[0]:
        # "B-TOTAL" / "I-TOTAL" -> "TOTAL"; labels without a prefix ("O") stay as-is.
        label_short = label.split("-", 1)[1] if "-" in label else label
        if label_short not in unique_labels:
            unique_labels.append(label_short)
    return unique_labels
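# For example (hypothetical tags):
#   get_uniqueLabelList([["B-TOTAL", "I-TOTAL", "O", "B-DATE"]])
#   -> ["TOTAL", "O", "DATE"]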
def process_outputs(outputs, encoding, model, processor, threshold):
    # Per-token confidence is the max softmax probability; the predicted
    # label id is the argmax over the same logits.
    scores, _ = torch.max(outputs.logits.softmax(dim=-1), dim=-1)
    scores = scores.tolist()
    predictions = outputs.logits.argmax(-1)
    labels = [
        [model.config.id2label[pred.item()] for pred in prediction]
        for prediction in predictions
    ]
    return _process_outputs(
        encoding=encoding,
        processor=processor,
        labels=labels,
        scores=scores,
        threshold=threshold,
    )
def _process_outputs(encoding, labels, scores, processor, threshold):
    results = []
    entities = []
    unique_labels = get_uniqueLabelList(labels)
    for label in unique_labels:
        # "O" marks tokens outside any entity; skip it.
        if label == "O":
            continue
        score_sum = 0.0
        entity_word_idxs = []
        # Collect every token position predicted as this entity type with a
        # confidence above the threshold.
        for ix, pred in enumerate(labels[0]):
            if scores[0][ix] > threshold and label in pred:
                score_sum += scores[0][ix]
                entity_word_idxs.append(ix)
        if not entity_word_idxs:
            continue  # no confident tokens for this entity type
        # Decode the matched tokens back to text and average their scores.
        entities.append(
            {
                "word": processor.decode(encoding.input_ids[0][entity_word_idxs].tolist()),
                "label": label,
                "score": f"{score_sum / len(entity_word_idxs):.2f}",
            }
        )
    results.append(entities)
    return results
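# Output is batch-shaped: one entity list per image, e.g. (hypothetical values):
#   [[{"word": "12,00 EUR", "label": "TOTAL", "score": "0.87"},
#     {"word": "2023-01-01", "label": "DATE", "score": "0.91"}]]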
def get_image_from_url(url):
    # Fetch a document image over HTTP; LayoutLMv2Processor requires RGB format.
    return Image.open(requests.get(url, stream=True, timeout=30).raw).convert("RGB")
# set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class EndpointHandler:
    def __init__(self, path=""):
        # Load the fine-tuned model and its processor from the endpoint path.
        self.model = LayoutLMForTokenClassification.from_pretrained(path).to(device)
        self.model.eval()
        self.processor = LayoutLMv2Processor.from_pretrained(path, apply_ocr=True)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
        """
        Args:
            data (Dict): request payload; the "inputs" key holds the
                deserialized document image as a PIL.Image.

        Returns:
            Dict with a "predictions" key holding the extracted entities.
        """
        image = data.pop("inputs", data)
        result, _encoding = predict(image, self.processor, self.model)
        return {"predictions": result}