ketanmore's picture
Upload folder using huggingface_hub
2720487 verified
raw
history blame
8.72 kB
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from typing import List, Optional
from PIL import Image
import numpy as np
from surya.detection import batch_detection
from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes
from surya.schema import LayoutResult, LayoutBox, TextDetectionResult
from surya.settings import settings
def _snap_box_to_lines(box, covered_lines):
    """Resize ``box.polygon`` in place to the rectangle spanned by ``covered_lines``.

    ``covered_lines`` must be a non-empty list of ``[x1, y1, x2, y2]`` boxes.
    """
    min_x = min(line[0] for line in covered_lines)
    min_y = min(line[1] for line in covered_lines)
    max_x = max(line[2] for line in covered_lines)
    max_y = max(line[3] for line in covered_lines)

    # Tables and formulas can contain text, but text isn't the whole area,
    # so they may only grow beyond their detected extent, never shrink.
    if box.label in ["Table", "Formula"]:
        min_x = min(min_x, min(point[0] for point in box.polygon))
        min_y = min(min_y, min(point[1] for point in box.polygon))
        max_x = max(max_x, max(point[0] for point in box.polygon))
        max_y = max(max_y, max(point[1] for point in box.polygon))

    # Corners are written element-wise: top-left, top-right, bottom-right, bottom-left
    box.polygon[0][0] = min_x
    box.polygon[0][1] = min_y
    box.polygon[1][0] = max_x
    box.polygon[1][1] = min_y
    box.polygon[2][0] = max_x
    box.polygon[2][1] = max_y
    box.polygon[3][0] = min_x
    box.polygon[3][1] = max_y


def _merge_overlapping_tables(boxes, max_rounds=5):
    """Merge "Table" boxes that overlap each other, for up to ``max_rounds`` passes."""
    for _ in range(max_rounds):
        absorbed = set()
        for idx, box in enumerate(boxes):
            if box.label != "Table" or idx in absorbed:
                continue

            for other_idx, other in enumerate(boxes):
                if other.label != "Table" or other_idx in absorbed or other_idx == idx:
                    continue

                if box.intersection_pct(other) > 0:
                    box.merge(other)
                    absorbed.add(other_idx)

        if not absorbed:
            break  # Nothing merged this round, so another round cannot change anything

        boxes = [box for idx, box in enumerate(boxes) if idx not in absorbed]
    return boxes


def _drop_contained_boxes(boxes, thresh=.95):
    """Return ``boxes`` without the non-caption boxes that sit almost entirely inside another box."""
    contained = set()
    for i, outer in enumerate(boxes):
        for j, inner in enumerate(boxes):
            if i == j or inner.label in ["Caption"]:
                continue

            if inner.intersection_pct(outer) >= thresh:
                # Near-duplicate boxes contain each other. Dropping both would lose the
                # region, so when `outer` is itself going to be dropped because of
                # `inner`, keep the earlier of the two.
                outer_also_dropped = outer.label not in ["Caption"] and outer.intersection_pct(inner) >= thresh
                if outer_also_dropped and j < i:
                    continue
                contained.add(j)

    return [box for idx, box in enumerate(boxes) if idx not in contained]


def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]:
    """Build layout boxes from class heatmaps, refined with detected text lines.

    Steps, all carried out in heatmap ("processor") coordinates:

    1. Zero the logits along detected vertical (column) lines, where the blank
       class scores >= .5, and wherever a pixel is assigned to another class.
    2. Extract candidate boxes per non-blank class from the masked logits.
    3. Assign every text line to at most one candidate box (two passes, with
       intersection thresholds .5 then .4) and snap boxes to their lines.
    4. Merge overlapping tables, add every unassigned text line as a "Text"
       box, rescale to ``orig_size`` and drop boxes contained in other boxes.

    Args:
        detection_result: Text detection output; its ``vertical_lines`` and
            ``bboxes`` are read. It is not modified: working copies are rescaled.
        heatmaps: One array per class, index 0 being the blank class.
            All arrays are expected to share the same 2D shape.
        orig_size: Size of the original image, in the same axis order as the
            reversed heatmap shape (presumably ``(width, height)`` - confirm
            against the caller).
        id2label: Mapping from class index to label name.
        segment_assignment: Per-pixel index of the winning class, same shape
            as a single heatmap.
        vertical_line_width: Extra width, in heatmap pixels, added to the right
            of each vertical line before it is zeroed out.

    Returns:
        Layout boxes in ``orig_size`` coordinates.
    """
    import copy  # Function-scope import so this block does not depend on file-level changes

    logits = np.stack(heatmaps, axis=0)
    heatmap_width = heatmaps[0].shape[1]
    processor_size = list(reversed(heatmaps[0].shape))

    # Work on copies: the lines are rescaled in place below, and the caller's
    # detection result must keep its original-image coordinates.
    vertical_line_bboxes = [copy.deepcopy(line) for line in detection_result.vertical_lines]
    line_bboxes = [copy.deepcopy(line) for line in detection_result.bboxes]

    # Scale back to processor size
    for line in vertical_line_bboxes:
        line.rescale_bbox(orig_size, processor_size)

    for line in line_bboxes:
        line.rescale(orig_size, processor_size)

    for line in vertical_line_bboxes:
        # Give some width to the vertical lines. Index 2 is an x coordinate
        # (it slices the last axis below), so it is clamped to the heatmap width.
        vert_bbox = list(line.bbox)
        vert_bbox[2] = min(heatmap_width, vert_bbox[2] + vertical_line_width)

        logits[:, vert_bbox[1]:vert_bbox[3], vert_bbox[0]:vert_bbox[2]] = 0  # zero out where the column lines are

    logits[:, logits[0] >= .5] = 0  # zero out where blanks are

    # Zero out where other segments are
    for class_idx in range(logits.shape[0]):
        logits[class_idx, segment_assignment != class_idx] = 0

    detected_boxes = []
    for heatmap_idx in range(1, len(id2label)):  # Skip the blank class
        heatmap = logits[heatmap_idx]
        for candidate in get_detected_boxes(heatmap):
            if candidate.area > 25:  # Ignore tiny detections
                candidate.fit_to_bounds([0, 0, heatmap.shape[1] - 1, heatmap.shape[0] - 1])
                detected_boxes.append(LayoutBox(polygon=candidate.polygon, label=id2label[heatmap_idx], confidence=1))

    detected_boxes.sort(key=lambda box: box.confidence, reverse=True)

    # Expand bbox to cover intersecting lines
    box_lines = defaultdict(list)
    used_lines = set()

    # We try 2 rounds of identifying the correct lines to snap to
    # First round is majority intersection, second lowers the threshold
    for thresh in [.5, .4]:
        for bbox_idx, bbox in enumerate(detected_boxes):
            for line_idx, line_bbox in enumerate(line_bboxes):
                if line_idx in used_lines:
                    continue  # Each line belongs to at most one box

                if line_bbox.intersection_pct(bbox) > thresh:
                    box_lines[bbox_idx].append(line_bbox.bbox)
                    used_lines.add(line_idx)

    new_boxes = []
    for bbox_idx, bbox in enumerate(detected_boxes):
        is_picture = bbox.label == "Picture"
        if is_picture and bbox.area < 200:  # Remove very small figures
            continue

        # .get() instead of indexing: indexing a defaultdict would insert the key
        # and make every box look as if it had lines assigned.
        covered_lines = box_lines.get(bbox_idx, [])

        # Skip if we didn't find any lines to snap to, except for Pictures and Formulas
        if not covered_lines and bbox.label not in ["Picture", "Formula"]:
            continue

        if covered_lines:
            if is_picture:
                # A picture that contains text lines is reported as a figure
                bbox.label = "Figure"
            else:
                # Snap non-picture layout boxes to correct text boundaries
                _snap_box_to_lines(bbox, covered_lines)

        new_boxes.append(bbox)

    # Merge tables together (sometimes one column is detected as a separate table)
    new_boxes = _merge_overlapping_tables(new_boxes, max_rounds=5)

    # Ensure we account for all text lines in the layout
    for line_idx, line_bbox in enumerate(line_bboxes):
        if line_idx not in used_lines:
            new_boxes.append(LayoutBox(polygon=line_bbox.polygon, label="Text", confidence=.5))

    # Back to original image coordinates
    for bbox in new_boxes:
        bbox.rescale(processor_size, orig_size)

    detected_boxes = [bbox for bbox in new_boxes if bbox.area > 16]

    # Remove bboxes contained inside others, unless they're captions
    return _drop_contained_boxes(detected_boxes, thresh=.95)
def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment) -> List[LayoutBox]:
    """Build layout boxes from class heatmaps alone, without text line information.

    For every non-blank class, the class heatmap is zeroed wherever another
    class wins the pixel, boxes are extracted from what remains, and the
    combined list is passed through ``keep_largest_boxes``.

    Args:
        heatmaps: One array per class, index 0 being the blank class. The
            arrays for non-blank classes are masked IN PLACE; the list itself
            is left with the same length and order.
        orig_size: Size of the original image, forwarded to
            ``get_and_clean_boxes`` as the target size.
        id2label: Mapping from class index to label name.
        segment_assignment: Per-pixel index of the winning class; must have the
            same shape as each heatmap.

    Returns:
        The layout boxes kept by ``keep_largest_boxes``.
    """
    bboxes = []
    for class_idx in range(1, len(id2label)):  # Skip the blank class
        heatmap = heatmaps[class_idx]
        assert heatmap.shape == segment_assignment.shape
        heatmap[segment_assignment != class_idx] = 0  # zero out where another segment is

        class_boxes = get_and_clean_boxes(heatmap, list(reversed(heatmap.shape)), orig_size)
        bboxes.extend(LayoutBox(polygon=bb.polygon, label=id2label[class_idx]) for bb in class_boxes)
        # The masked array is already the one stored in `heatmaps`, so nothing
        # is appended to the caller's list here.

    return keep_largest_boxes(bboxes)
def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detection_results=None) -> LayoutResult:
    """Post-process the class heatmaps of a single image into a ``LayoutResult``.

    Args:
        heatmaps: One array per class for this image.
        orig_size: Size of the original image; its first two entries become the
            far corner of ``image_bbox``.
        id2label: Mapping from class index to label name.
        detection_results: Optional text detection result for this image. When
            given, boxes are refined with the detected lines; otherwise they
            come from the heatmaps alone.

    Returns:
        A ``LayoutResult`` holding the boxes, the per-pixel class map as an
        8-bit image, the heatmaps and the image bounding box.
    """
    # Every pixel is assigned to the class with the highest score at that pixel
    segment_assignment = np.stack(heatmaps, axis=0).argmax(axis=0)

    if detection_results is None:
        bboxes = get_regions(heatmaps, orig_size, id2label, segment_assignment)
    else:
        bboxes = get_regions_from_detection_result(
            detection_results, heatmaps, orig_size, id2label, segment_assignment
        )

    return LayoutResult(
        bboxes=bboxes,
        segmentation_map=Image.fromarray(segment_assignment.astype(np.uint8)),
        heatmaps=heatmaps,
        image_bbox=[0, 0, orig_size[0], orig_size[1]],
    )
def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]:
    """Run layout detection on a batch of images.

    ``batch_detection`` produces the predictions and original sizes; each image
    is then post-processed with ``parallel_get_regions``, either in this
    process or in a process pool.

    Args:
        images: Images to analyse.
        model: Layout model; ``model.config.id2label`` supplies the label names.
        processor: Processor handed to ``batch_detection`` together with the model.
        detection_results: Optional text detection results, indexed like
            ``images``. When empty or ``None``, no line information is used.
        batch_size: Forwarded to ``batch_detection``.

    Returns:
        One ``LayoutResult`` per image, in input order.
    """
    preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
    id2label = model.config.id2label
    image_count = len(images)

    def job_args(idx):
        # Positional arguments of parallel_get_regions for image `idx`
        return (
            preds[idx],
            orig_sizes[idx],
            id2label,
            detection_results[idx] if detection_results else None,
        )

    # Ensures we don't parallelize with streamlit or too few images
    if settings.IN_STREAMLIT or image_count < settings.DETECTOR_MIN_PARALLEL_THRESH:
        return [parallel_get_regions(*job_args(idx)) for idx in range(image_count)]

    worker_count = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, image_count)
    with ProcessPoolExecutor(max_workers=worker_count) as executor:
        # Submit everything first, then collect in submission order
        pending = [executor.submit(parallel_get_regions, *job_args(idx)) for idx in range(image_count)]
        return [job.result() for job in pending]