from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from typing import List, Optional
from PIL import Image
import numpy as np

from surya.detection import batch_detection
from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes
from surya.schema import LayoutResult, LayoutBox, TextDetectionResult
from surya.settings import settings


def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
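    """Convert per-class layout heatmaps into labeled layout boxes, guided by text detection.

    Masks out detected column separators and the blank class, extracts boxes from each
    class heatmap, snaps non-picture boxes to the text lines they cover, merges
    intersecting tables, adds unassigned lines as Text boxes, and drops boxes almost
    fully contained in others (captions excepted).
    """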
    logits = np.stack(heatmaps, axis=0)
    vertical_line_bboxes = list(detection_result.vertical_lines)
    line_bboxes = detection_result.bboxes

    # Scale back to processor size
    for line in vertical_line_bboxes:
        line.rescale_bbox(orig_size, list(reversed(heatmaps[0].shape)))

    for line in line_bboxes:
        line.rescale(orig_size, list(reversed(heatmaps[0].shape)))

    for bbox in vertical_line_bboxes:
        # Give some width to the vertical lines
        vert_bbox = list(bbox.bbox)
        vert_bbox[2] = min(heatmaps[0].shape[1], vert_bbox[2] + vertical_line_width)  # clamp x to the heatmap width

        logits[:, vert_bbox[1]:vert_bbox[3], vert_bbox[0]:vert_bbox[2]] = 0  # zero out where the column lines are

    logits[:, logits[0] >= .5] = 0 # zero out where blanks are

    # Zero out where other segments are
    for i in range(logits.shape[0]):
        logits[i, segment_assignment != i] = 0

    detected_boxes = []
    for heatmap_idx in range(1, len(id2label)):  # Skip the blank class
        heatmap = logits[heatmap_idx]
        bboxes = get_detected_boxes(heatmap)
        bboxes = [bbox for bbox in bboxes if bbox.area > 25]
        for bb in bboxes:
            bb.fit_to_bounds([0, 0, heatmap.shape[1] - 1, heatmap.shape[0] - 1])

        for bbox in bboxes:
            detected_boxes.append(LayoutBox(polygon=bbox.polygon, label=id2label[heatmap_idx], confidence=1))

    detected_boxes = sorted(detected_boxes, key=lambda x: x.confidence, reverse=True)
    # Expand bbox to cover intersecting lines
    box_lines = defaultdict(list)
    used_lines = set()

    # We try 2 rounds of identifying the correct lines to snap to
    # First round is majority intersection, second lowers the threshold
    for thresh in [.5, .4]:
        for bbox_idx, bbox in enumerate(detected_boxes):
            for line_idx, line_bbox in enumerate(line_bboxes):
                if line_bbox.intersection_pct(bbox) > thresh and line_idx not in used_lines:
                    box_lines[bbox_idx].append(line_bbox.bbox)
                    used_lines.add(line_idx)

    new_boxes = []
    for bbox_idx, bbox in enumerate(detected_boxes):
        if bbox.label == "Picture" and bbox.area < 200: # Remove very small figures
            continue

        # Skip if we didn't find any lines to snap to, except for Pictures and Formulas
        if bbox_idx not in box_lines and bbox.label not in ["Picture", "Formula"]:
            continue

        covered_lines = box_lines[bbox_idx]
        # Snap non-picture layout boxes to correct text boundaries
        if len(covered_lines) > 0 and bbox.label not in ["Picture"]:
            min_x = min([line[0] for line in covered_lines])
            min_y = min([line[1] for line in covered_lines])
            max_x = max([line[2] for line in covered_lines])
            max_y = max([line[3] for line in covered_lines])

            # Tables and formulas can contain text, but text isn't the whole area
            if bbox.label in ["Table", "Formula"]:
                min_x_box = min([b[0] for b in bbox.polygon])
                min_y_box = min([b[1] for b in bbox.polygon])
                max_x_box = max([b[0] for b in bbox.polygon])
                max_y_box = max([b[1] for b in bbox.polygon])

                min_x = min(min_x, min_x_box)
                min_y = min(min_y, min_y_box)
                max_x = max(max_x, max_x_box)
                max_y = max(max_y, max_y_box)

            bbox.polygon[0][0] = min_x
            bbox.polygon[0][1] = min_y
            bbox.polygon[1][0] = max_x
            bbox.polygon[1][1] = min_y
            bbox.polygon[2][0] = max_x
            bbox.polygon[2][1] = max_y
            bbox.polygon[3][0] = min_x
            bbox.polygon[3][1] = max_y

        if bbox_idx in box_lines and bbox.label in ["Picture"]:
            bbox.label = "Figure"

        new_boxes.append(bbox)

    # Merge tables together (sometimes one column is detected as a separate table)
    for _ in range(5):  # Up to 5 rounds of merging
        to_remove = set()
        for bbox_idx, bbox in enumerate(new_boxes):
            if bbox.label != "Table" or bbox_idx in to_remove:
                continue

            for bbox_idx2, bbox2 in enumerate(new_boxes):
                if bbox2.label != "Table" or bbox_idx2 in to_remove or bbox_idx == bbox_idx2:
                    continue

                if bbox.intersection_pct(bbox2) > 0:
                    bbox.merge(bbox2)
                    to_remove.add(bbox_idx2)

        new_boxes = [bbox for idx, bbox in enumerate(new_boxes) if idx not in to_remove]

    # Ensure we account for all text lines in the layout
    unused_lines = [line for idx, line in enumerate(line_bboxes) if idx not in used_lines]
    for bbox in unused_lines:
        new_boxes.append(LayoutBox(polygon=bbox.polygon, label="Text", confidence=.5))

    for bbox in new_boxes:
        bbox.rescale(list(reversed(heatmaps[0].shape)), orig_size)

    detected_boxes = [bbox for bbox in new_boxes if bbox.area > 16]

    # Remove bboxes contained inside others, unless they're captions
    contained_bbox = []
    for i, bbox in enumerate(detected_boxes):
        for j, bbox2 in enumerate(detected_boxes):
            if i == j:
                continue

            if bbox2.intersection_pct(bbox) >= .95 and bbox2.label not in ["Caption"]:
                contained_bbox.append(j)

    detected_boxes = [bbox for idx, bbox in enumerate(detected_boxes) if idx not in contained_bbox]

    return detected_boxes


def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment) -> List[LayoutBox]:
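    """Extract labeled layout boxes directly from the class heatmaps.

    Fallback used when no text detection result is available: each non-blank class
    heatmap is masked to its argmax segment and converted to boxes, then the boxes
    are filtered with keep_largest_boxes.
    """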
    bboxes = []
    for i in range(1, len(id2label)):  # Skip the blank class
        heatmap = heatmaps[i]
        assert heatmap.shape == segment_assignment.shape
        heatmap[segment_assignment != i] = 0  # zero out where another segment is
        cleaned_boxes = get_and_clean_boxes(heatmap, list(reversed(heatmap.shape)), orig_size)
        for bb in cleaned_boxes:
            bboxes.append(LayoutBox(polygon=bb.polygon, label=id2label[i]))
        heatmaps.append(heatmap)

    bboxes = keep_largest_boxes(bboxes)
    return bboxes


def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detection_results=None) -> LayoutResult:
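    """Build a LayoutResult for a single image from its class heatmaps.

    Computes a per-pixel segment assignment via argmax over the heatmaps, then
    extracts layout boxes with or without text detection guidance.
    """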
    logits = np.stack(heatmaps, axis=0)
    segment_assignment = logits.argmax(axis=0)
    if detection_results is not None:
        bboxes = get_regions_from_detection_result(detection_results, heatmaps, orig_size, id2label,
                                                   segment_assignment)
    else:
        bboxes = get_regions(heatmaps, orig_size, id2label, segment_assignment)

    segmentation_img = Image.fromarray(segment_assignment.astype(np.uint8))

    result = LayoutResult(
        bboxes=bboxes,
        segmentation_map=segmentation_img,
        heatmaps=heatmaps,
        image_bbox=[0, 0, orig_size[0], orig_size[1]]
    )

    return result


def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]:
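    """Run layout detection on a batch of images and post-process into LayoutResults.

    Post-processing runs in-process for small batches or under Streamlit, otherwise
    in a ProcessPoolExecutor with up to DETECTOR_POSTPROCESSING_CPU_WORKERS workers.
    """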
    preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
    id2label = model.config.id2label

    results = []
    if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit or too few images
        for i in range(len(images)):
            result = parallel_get_regions(preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None)
            results.append(result)
    else:
        futures = []
        max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            for i in range(len(images)):
                future = executor.submit(parallel_get_regions, preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None)
                futures.append(future)

            for future in futures:
                results.append(future.result())

    return results
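

if __name__ == "__main__":
    # Minimal usage sketch (assumed entry point, not part of the library API).
    # Assumptions: the segformer loaders below and settings.LAYOUT_MODEL_CHECKPOINT
    # are available in the installed surya version, and "page.png" is a local test image.
    from surya.model.detection.segformer import load_model, load_processor

    layout_model = load_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
    layout_processor = load_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)

    page = Image.open("page.png").convert("RGB")
    layout_results = batch_layout_detection([page], layout_model, layout_processor)

    # Print the label and bounding box of each detected layout region
    for box in layout_results[0].bboxes:
        print(box.label, box.bbox)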