pierreguillou committed
Commit: b050ba1
Parent(s): f43f6f8
Update files/functions.py

files/functions.py CHANGED (+44 -23)
@@ -51,22 +51,13 @@ label2color = {
 
 # bounding boxes start and end of a sequence
 cls_box = [0, 0, 0, 0]
-sep_box =
+sep_box = [1000, 1000, 1000, 1000]
 
 # model
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model_id = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
 
-
-
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForTokenClassification.from_pretrained(model_id);
-model.to(device);
-
-# get labels
-id2label = model.config.id2label
-label2id = model.config.label2id
-num_labels = len(id2label)
+# tokenizer
+tokenizer_id = "xlm-roberta-base"
 
 # (tokenization) The maximum length of a feature (sequence)
 if str(384) in model_id:
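
Note: this hunk replaces the in-file model loading with two identifiers, the fine-tuned LayoutXLM checkpoint and the base XLM-RoBERTa tokenizer. The loading itself now happens outside this hunk; the sketch below shows how the two IDs would typically be consumed, modelled on the removed lines above (the device and label bookkeeping are assumptions, not part of this diff):

    import torch
    from transformers import AutoTokenizer, AutoModelForTokenClassification

    model_id = "pierreguillou/layout-xlm-base-finetuned-with-DocLayNet-base-at-linelevel-ml384"
    tokenizer_id = "xlm-roberta-base"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # the tokenizer comes from the base XLM-RoBERTa vocabulary, not from the model repo
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    # LayoutXLM token-classification head fine-tuned on DocLayNet at line level
    model = AutoModelForTokenClassification.from_pretrained(model_id)
    model.to(device)

    # label mappings stored in the model config
    id2label = model.config.id2label
    label2id = model.config.label2id
    num_labels = len(id2label)
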
@@ -81,7 +72,21 @@ doc_stride = 128 # The authorized overlap between two part of the context when s
 
 # max PDF page images that will be displayed
 max_imgboxes = 2
+
+# get files
 examples_dir = 'files/'
+Path(examples_dir).mkdir(parents=True, exist_ok=True)
+from huggingface_hub import hf_hub_download
+files = ["example.pdf", "blank.pdf", "blank.png", "languages_iso.csv", "languages_tesseract.csv", "wo_content.png"]
+for file_name in files:
+    path_to_file = hf_hub_download(
+        repo_id = "pierreguillou/Inference-APP-Document-Understanding-at-linelevel-v2",
+        filename = "files/" + file_name,
+        repo_type = "space"
+    )
+    shutil.copy(path_to_file,examples_dir)
+
+# path to files
 image_wo_content = examples_dir + "wo_content.png" # image without content
 pdf_blank = examples_dir + "blank.pdf" # blank PDF
 image_blank = examples_dir + "blank.png" # blank image
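
Note: the example assets are no longer assumed to be present locally; each one is pulled from the Space repo with hf_hub_download, which returns a path inside the local Hugging Face cache, and is then copied into examples_dir. A self-contained sketch of the same pattern for a single file (Path and shutil are imported here for completeness; functions.py imports them elsewhere):

    import shutil
    from pathlib import Path
    from huggingface_hub import hf_hub_download

    examples_dir = "files/"
    Path(examples_dir).mkdir(parents=True, exist_ok=True)

    # download one asset from the Space repo into the local HF cache
    path_to_file = hf_hub_download(
        repo_id="pierreguillou/Inference-APP-Document-Understanding-at-linelevel-v2",
        filename="files/blank.pdf",
        repo_type="space",  # the assets live in a Space, not in a model repo
    )

    # hf_hub_download returns the cached path; copy the file next to the app
    shutil.copy(path_to_file, examples_dir)
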
@@ -368,8 +373,8 @@ def extraction_data_from_image(images):
 
     # https://pyimagesearch.com/2021/11/15/tesseract-page-segmentation-modes-psms-explained-how-to-improve-your-ocr-accuracy/
     custom_config = r'--oem 3 --psm 3 -l eng' # default config PyTesseract: --oem 3 --psm 3 -l eng+deu+fra+jpn+por+spa+rus+hin+chi_sim
-    results, lines, row_indexes, par_boxes, line_boxes = dict(), dict(), dict(), dict(), dict()
-    images_ids_list, lines_list, par_boxes_list, line_boxes_list, images_list, page_no_list, num_pages_list = list(), list(), list(), list(), list(), list(), list()
+    results, lines, row_indexes, par_boxes, line_boxes, images_pixels = dict(), dict(), dict(), dict(), dict(), dict()
+    images_ids_list, lines_list, par_boxes_list, line_boxes_list, images_list, images_pixels_list, page_no_list, num_pages_list = list(), list(), list(), list(), list(), list(), list(), list()
 
     try:
         for i,image in enumerate(images):
@@ -401,11 +406,15 @@ def extraction_data_from_image(images):
             results[i] = pytesseract.image_to_data(img, config=custom_config, output_type=pytesseract.Output.DICT)
             # results[i] = os.popen(f'tesseract {img_filepath} - {custom_config}').read()
 
+            # get image pixels
+            images_pixels[i] = feature_extractor(images[i], return_tensors="pt").pixel_values
+
             lines[i], row_indexes[i], par_boxes[i], line_boxes[i] = get_data(results[i], factor, conf_min=0)
             lines_list.append(lines[i])
             par_boxes_list.append(par_boxes[i])
             line_boxes_list.append(line_boxes[i])
             images_ids_list.append(i)
+            images_pixels_list.append(images_pixels[i])
             images_list.append(images[i])
             page_no_list.append(i)
             num_pages_list.append(num_imgs)
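
Note: images_pixels[i] comes from a feature_extractor defined elsewhere in the app. For LayoutXLM/LayoutLMv2 checkpoints this is typically a LayoutLMv2FeatureExtractor created with apply_ocr=False (the OCR is already done by PyTesseract above), which resizes each page image into the fixed-size pixel tensor expected by the visual backbone. A sketch under that assumption:

    from PIL import Image
    from transformers import LayoutLMv2FeatureExtractor

    # no built-in OCR pass: PyTesseract already extracted the words and boxes
    feature_extractor = LayoutLMv2FeatureExtractor(apply_ocr=False)

    image = Image.open("files/wo_content.png").convert("RGB")
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    print(pixel_values.shape)  # torch.Size([1, 3, 224, 224])
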
@@ -414,7 +423,7 @@ def extraction_data_from_image(images):
         print(f"There was an error within the extraction of PDF text by the OCR!")
     else:
         from datasets import Dataset
-        dataset = Dataset.from_dict({"images_ids": images_ids_list, "images": images_list, "page_no": page_no_list, "num_pages": num_pages_list, "texts": lines_list, "bboxes_line": line_boxes_list})
+        dataset = Dataset.from_dict({"images_ids": images_ids_list, "images": images_list, "images_pixels": images_pixels_list, "page_no": page_no_list, "num_pages": num_pages_list, "texts": lines_list, "bboxes_line": line_boxes_list})
 
         # print(f"The text data was successfully extracted by the OCR!")
 
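
Note: Dataset.from_dict builds a columnar dataset with one row per page image, so every list passed in (including the new images_pixels column) must have the same length. A toy illustration with made-up values:

    from datasets import Dataset

    # two pages: every column holds exactly two entries
    dataset = Dataset.from_dict({
        "images_ids": [0, 1],
        "texts": [["first line of page 1"], ["first line of page 2"]],
        "bboxes_line": [[[10, 10, 200, 30]], [[12, 14, 180, 28]]],
        "num_pages": [2, 2],
    })
    print(dataset.num_rows)     # 2
    print(dataset[0]["texts"])  # ['first line of page 1']
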
@@ -424,11 +433,12 @@ def extraction_data_from_image(images):
 
 def prepare_inference_features(example, cls_box = cls_box, sep_box = sep_box):
 
-    images_ids_list, chunks_ids_list, input_ids_list, attention_mask_list, bb_list = list(), list(), list(), list(), list()
+    images_ids_list, chunks_ids_list, input_ids_list, attention_mask_list, bb_list, images_pixels_list = list(), list(), list(), list(), list(), list()
 
     # get batch
     batch_images_ids = example["images_ids"]
     batch_images = example["images"]
+    batch_images_pixels = example["images_pixels"]
     batch_bboxes_line = example["bboxes_line"]
     batch_texts = example["texts"]
     batch_images_size = [image.size for image in batch_images]
@@ -439,12 +449,13 @@ def prepare_inference_features(example, cls_box = cls_box, sep_box = sep_box):
     if not isinstance(batch_images_ids, list):
         batch_images_ids = [batch_images_ids]
         batch_images = [batch_images]
+        batch_images_pixels = [batch_images_pixels]
         batch_bboxes_line = [batch_bboxes_line]
         batch_texts = [batch_texts]
         batch_width, batch_height = [batch_width], [batch_height]
 
     # process all images of the batch
-    for num_batch, (image_id, boxes, texts, width, height) in enumerate(zip(batch_images_ids, batch_bboxes_line, batch_texts, batch_width, batch_height)):
+    for num_batch, (image_id, image_pixels, boxes, texts, width, height) in enumerate(zip(batch_images_ids, batch_images_pixels, batch_bboxes_line, batch_texts, batch_width, batch_height)):
         tokens_list = []
         bboxes_list = []
 
@@ -506,6 +517,7 @@ def prepare_inference_features(example, cls_box = cls_box, sep_box = sep_box):
             bb_list.append(bb)
             images_ids_list.append(image_id)
             chunks_ids_list.append(i)
+            images_pixels_list.append(image_pixels)
 
     return {
         "images_ids": images_ids_list,
@@ -513,6 +525,7 @@ def prepare_inference_features(example, cls_box = cls_box, sep_box = sep_box):
         "input_ids": input_ids_list,
         "attention_mask": attention_mask_list,
         "normalized_bboxes": bb_list,
+        "images_pixels": images_pixels_list
     }
 
 from torch.utils.data import Dataset
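
Note: with images_pixels threaded through, every chunk returned by prepare_inference_features keeps the pixel tensor of the page it came from. The function is written for the batched map of the datasets library; the actual call sits outside this diff, but it would look roughly like this (batch size and column handling are assumptions):

    # chunk the OCR dataset into model-ready features;
    # batched=True hands the function whole lists of pages at once
    encoded_dataset = dataset.map(
        prepare_inference_features,
        batched=True,
        batch_size=64,                         # hypothetical batch size
        remove_columns=dataset.column_names,   # keep only the returned columns
    )
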
@@ -534,18 +547,21 @@ class CustomDataset(Dataset):
         encoding["input_ids"] = example["input_ids"]
         encoding["attention_mask"] = example["attention_mask"]
         encoding["bbox"] = example["normalized_bboxes"]
+        encoding["images_pixels"] = example["images_pixels"]
 
         return encoding
 
 import torch.nn.functional as F
 
+import torch.nn.functional as F
+
 # get predictions at token level
 def predictions_token_level(images, custom_encoded_dataset):
 
     num_imgs = len(images)
     if num_imgs > 0:
 
-        chunk_ids, input_ids, bboxes, outputs, token_predictions = dict(), dict(), dict(), dict(), dict()
+        chunk_ids, input_ids, bboxes, pixels_values, outputs, token_predictions = dict(), dict(), dict(), dict(), dict(), dict()
         images_ids_list = list()
 
         for i,encoding in enumerate(custom_encoded_dataset):
@@ -556,6 +572,7 @@ def predictions_token_level(images, custom_encoded_dataset):
             input_id = torch.tensor(encoding['input_ids'])[None]
             attention_mask = torch.tensor(encoding['attention_mask'])[None]
             bbox = torch.tensor(encoding['bbox'])[None]
+            pixel_values = torch.tensor(encoding["images_pixels"])
 
             # save data in dictionnaries
             if image_id not in images_ids_list: images_ids_list.append(image_id)
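
Note: the datasets library stores tensors as nested Python lists, so each field is rebuilt with torch.tensor at prediction time; [None] adds the batch dimension for the text inputs, while images_pixels already carries one because it was saved as the (1, 3, 224, 224) output of the feature extractor. A small illustration of the indexing trick:

    import torch

    input_ids = [0, 6, 2389, 2]                # hypothetical token ids of one chunk
    input_id = torch.tensor(input_ids)[None]   # shape (1, 4): ready for the model
    print(input_id.shape)
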
@@ -569,14 +586,18 @@ def predictions_token_level(images, custom_encoded_dataset):
             if image_id in bboxes: bboxes[image_id].append(bbox)
             else: bboxes[image_id] = [bbox]
 
+            if image_id in pixels_values: pixels_values[image_id].append(pixel_values)
+            else: pixels_values[image_id] = [pixel_values]
+
             # get prediction with forward pass
             with torch.no_grad():
                 output = model(
-                    input_ids=input_id,
-                    attention_mask=attention_mask,
-                    bbox=bbox
+                    input_ids=input_id.to(device),
+                    attention_mask=attention_mask.to(device),
+                    bbox=bbox.to(device),
+                    image=pixel_values.to(device)
                 )
-
+
             # save probabilities of predictions in dictionnary
             if image_id in outputs: outputs[image_id].append(F.softmax(output.logits.squeeze(), dim=-1))
             else: outputs[image_id] = [F.softmax(output.logits.squeeze(), dim=-1)]
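
Note: the forward pass now feeds the page image through the image argument of the LayoutLMv2/LayoutXLM signature, and every tensor is moved to device since the model lives on the GPU when one is available. After the softmax, per-token labels can be read off with argmax; a sketch of that next step, continuing with the variables above and the id2label mapping from the model config (an assumption about the surrounding code):

    # probabilities for one chunk: shape (seq_len, num_labels)
    probs = F.softmax(output.logits.squeeze(), dim=-1)

    # most likely label id per token, mapped back to label names
    predicted_ids = probs.argmax(dim=-1).tolist()
    predicted_labels = [model.config.id2label[idx] for idx in predicted_ids]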