""" |
Copyright (C) 2021 Microsoft Corporation |
""" |
from collections import defaultdict |
from fitz import Rect |
def apply_threshold(objects, threshold):
    """
    Keep only the objects whose 'score' is greater than or equal to `threshold`.

    objects: iterable of dicts, each with a 'score' entry
    threshold: minimum score an object needs in order to be kept
    Returns a new list; the input order is preserved.
    """
    kept = []
    for candidate in objects:
        if candidate['score'] >= threshold:
            kept.append(candidate)
    return kept
def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds):
    """
    Drop every detection whose score is below the threshold configured for its class.

    bboxes, labels, scores: parallel sequences describing the detections
    class_names: lookup from a label to its class name
    class_thresholds: lookup from a class name to its minimum score
    Returns the surviving (bboxes, scores, labels) as three new lists, in that order.
    """
    kept_bboxes, kept_scores, kept_labels = [], [], []
    for idx, (score, label) in enumerate(zip(scores, labels)):
        if score >= class_thresholds[class_names[label]]:
            kept_bboxes.append(bboxes[idx])
            kept_scores.append(scores[idx])
            kept_labels.append(labels[idx])
    return kept_bboxes, kept_scores, kept_labels
def iou(bbox1, bbox2):
    """
    Ratio of the area shared by two bounding boxes to the area of the
    rectangle obtained by merging them with Rect.include_rect.

    NOTE(review): include_rect appears to yield the rectangle enclosing both
    boxes, so the denominator is the hull area rather than the exact union
    area -- confirm against the PyMuPDF documentation.

    Returns 0 when the merged rectangle has no area.
    """
    # Each Rect is built fresh from bbox1 so the two operations stay independent.
    shared_area = Rect(bbox1).intersect(bbox2).get_area()
    merged_area = Rect(bbox1).include_rect(bbox2).get_area()
    if merged_area > 0:
        return shared_area / merged_area
    return 0
def iob(bbox1, bbox2):
    """
    Fraction of bbox1's own area that is covered by bbox2
    (intersection area divided by the area of bbox1).

    Returns 0 when bbox1 has no area.
    """
    box_area = Rect(bbox1).get_area()
    if box_area <= 0:
        return 0
    return Rect(bbox1).intersect(bbox2).get_area() / box_area
def objects_to_cells(table, objects_in_table, tokens_in_table, class_map, class_thresholds):
    """
    Turn the structure model's detections plus the token/word/span boxes
    into table cells.

    Returns (table_structures, cells, confidence_score). The confidence score
    reflects how well the text could be slotted into the detected cells.
    When no column or no row was found, cells is an empty list and the
    confidence score is 0.
    """
    table_structures = objects_to_table_structures(
        table, objects_in_table, tokens_in_table, class_map, class_thresholds)

    # A table needs at least one row and one column to have any cells.
    if not table_structures['columns'] or not table_structures['rows']:
        return table_structures, [], 0

    cells, confidence_score = table_structure_to_cells(
        table_structures, tokens_in_table, table['bbox'])
    return table_structures, cells, confidence_score
def objects_to_table_structures(table_object, objects_in_table, tokens_in_table, class_names, class_thresholds):
    """
    Convert the raw detections of the structure model into a *consistent*
    set of table structures: rows, columns, supercells and headers.

    Rows and columns are de-duplicated and sorted, the table bounding box is
    tightened to the extent of the remaining rows and columns, and rows and
    columns are stretched to that box. When there is at least one row and
    more than one column, headers and supercells are refined as well.

    Side effects: the detection dicts gain 'subheader', 'header' and 'page'
    entries; table_object gains 'row_column_bbox' and its 'bbox' is replaced
    by that same list.

    Returns a dict with the keys 'rows', 'columns', 'headers', 'supercells'.
    """
    def of_class(class_name):
        # Detections whose label maps to the requested class name, in input order.
        return [obj for obj in objects_in_table if class_names[obj['label']] == class_name]

    page_num = table_object['page_num']

    columns = of_class('table column')
    rows = of_class('table row')
    headers = of_class('table column header')

    # Spanning cells and projected row headers share one list; the
    # 'subheader' flag tells them apart later on.
    supercells = of_class('table spanning cell')
    for spanning_cell in supercells:
        spanning_cell['subheader'] = False
    subheaders = of_class('table projected row header')
    for projected_header in subheaders:
        projected_header['subheader'] = True
    supercells.extend(subheaders)

    # A row is a header row when at least half of it lies inside some header box.
    for row in rows:
        row['header'] = any(iob(row['bbox'], header['bbox']) >= 0.5 for header in headers)

    for row in rows:
        row['page'] = page_num
    for column in columns:
        column['page'] = page_num

    rows = refine_rows(rows, tokens_in_table, class_thresholds['table row'])
    columns = refine_columns(columns, tokens_in_table, class_thresholds['table column'])

    # Horizontal extent comes from the columns, vertical extent from the rows.
    row_hull = Rect()
    for row in rows:
        row_hull.include_rect(row['bbox'])
    column_hull = Rect()
    for column in columns:
        column_hull.include_rect(column['bbox'])
    tight_bbox = [column_hull[0], row_hull[1], column_hull[2], row_hull[3]]
    table_object['row_column_bbox'] = tight_bbox
    table_object['bbox'] = tight_bbox

    columns = align_columns(columns, tight_bbox)
    rows = align_rows(rows, tight_bbox)

    table_structures = {
        'rows': rows,
        'columns': columns,
        'headers': headers,
        'supercells': supercells,
    }

    if rows and len(columns) > 1:
        table_structures = refine_table_structures(
            table_object['bbox'], table_structures, tokens_in_table, class_thresholds)

    return table_structures
def refine_rows(rows, page_spans, score_threshold):
    """
    De-duplicate the detected rows with containment-based NMS and, when more
    than one row remains, order them from top to bottom.

    score_threshold is accepted for interface compatibility but is not used here.
    """
    surviving = nms_by_containment(rows, page_spans, overlap_threshold=0.5)
    if len(surviving) <= 1:
        return surviving
    return sort_objects_top_to_bottom(surviving)
def refine_columns(columns, page_spans, score_threshold):
    """
    De-duplicate the detected columns with containment-based NMS and, when
    more than one column remains, order them from left to right.

    score_threshold is accepted for interface compatibility but is not used here.
    """
    surviving = nms_by_containment(columns, page_spans, overlap_threshold=0.5)
    if len(surviving) <= 1:
        return surviving
    return sort_objects_left_to_right(surviving)
def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5):
    """
    Non-maxima suppression driven by which packages each container holds.

    Containers are ranked from highest to lowest score and every package is
    slotted into a container. A container other than the top-ranked one is
    dropped when it holds no package, or when it shares a package with a
    higher-ranked container that has not itself been dropped.

    Returns the surviving containers in score order.
    """
    ranked = sort_objects_by_score(container_objects)
    assigned, _, _ = slot_into_containers(ranked, package_objects, overlap_threshold=overlap_threshold,
                                          unique_assignment=True, forced_assignment=False)
    package_sets = [set(package_nums) for package_nums in assigned]

    dropped = [False] * len(ranked)
    # Index 0 (the best-scoring container) is never examined, so it always survives.
    for idx in range(1, len(ranked)):
        held = package_sets[idx]
        dropped[idx] = (len(held) == 0
                        or any(not dropped[better] and len(held & package_sets[better]) > 0
                               for better in range(idx)))

    return [obj for obj, is_dropped in zip(ranked, dropped) if not is_dropped]
def slot_into_containers(container_objects, package_objects, overlap_threshold=0.5,
                         unique_assignment=True, forced_assignment=False):
    """
    Slot a collection of objects into the container they occupy most
    (the container which holds the largest fraction of the object).

    container_objects: sequence of dicts with a 'bbox' entry
    package_objects: sequence of dicts with a 'bbox' entry
    overlap_threshold: minimum fraction of a package's area that must lie
        inside a container for the package to be assigned to it
    unique_assignment: if True, a package goes to its best container only;
        otherwise also to every other container meeting the threshold
    forced_assignment: if True, the best container is used even when it is
        below the threshold

    Returns (container_assignments, package_assignments, best_match_scores):
    package indices per container, container indices per package, and the
    best overlap fraction of each package. When either input is empty the
    score list is empty and all assignment lists are empty.

    A package with zero area overlaps every container with fraction 0
    instead of raising ZeroDivisionError.
    """
    best_match_scores = []

    container_assignments = [[] for container in container_objects]
    package_assignments = [[] for package in package_objects]

    if len(container_objects) == 0 or len(package_objects) == 0:
        return container_assignments, package_assignments, best_match_scores

    for package_num, package in enumerate(package_objects):
        match_scores = []
        package_area = Rect(package['bbox']).get_area()
        for container_num, container in enumerate(container_objects):
            # A fresh Rect is built for every pair before intersecting, as in
            # the original code, so no container box is reused across packages.
            container_rect = Rect(container['bbox'])
            intersect_area = container_rect.intersect(package['bbox']).get_area()
            # Degenerate (zero-area) packages cannot overlap anything.
            if package_area > 0:
                overlap_fraction = intersect_area / package_area
            else:
                overlap_fraction = 0
            match_scores.append({'container': container, 'container_num': container_num,
                                 'score': overlap_fraction})

        sorted_match_scores = sort_objects_by_score(match_scores)

        best_match_score = sorted_match_scores[0]
        best_match_scores.append(best_match_score['score'])
        if forced_assignment or best_match_score['score'] >= overlap_threshold:
            container_assignments[best_match_score['container_num']].append(package_num)
            package_assignments[package_num].append(best_match_score['container_num'])

        if not unique_assignment:
            # Scores are in descending order, so stop at the first one below the threshold.
            for match_score in sorted_match_scores[1:]:
                if match_score['score'] >= overlap_threshold:
                    container_assignments[match_score['container_num']].append(package_num)
                    package_assignments[package_num].append(match_score['container_num'])
                else:
                    break

    return container_assignments, package_assignments, best_match_scores
def sort_objects_by_score(objects, reverse=True):
    """
    Return the objects as a new list ordered by their 'score'.

    reverse=True (the default) puts the highest score first; reverse=False
    puts the lowest score first. Objects with equal scores keep their
    relative input order.
    """
    direction = -1 if reverse else 1

    def ranking_key(obj):
        return direction * obj['score']

    return sorted(objects, key=ranking_key)
def remove_objects_without_content(page_spans, objects):
    """
    Delete, in place, every object (row, column, supercell, ...) whose
    bounding box contains no text other than whitespace.

    Nothing is returned; `objects` itself is modified.
    """
    # Walk over a snapshot because `objects` shrinks while we go.
    for candidate in list(objects):
        text, _ = extract_text_inside_bbox(page_spans, candidate['bbox'])
        if not text.strip():
            objects.remove(candidate)
def extract_text_inside_bbox(spans, bbox):
    """
    Collect the spans that fall inside `bbox` and join them into one string.

    Returns (text, spans_inside_bbox).
    """
    spans_inside = get_bbox_span_subset(spans, bbox)
    text = extract_text_from_spans(spans_inside, remove_integer_superscripts=True)
    return text, spans_inside
def get_bbox_span_subset(spans, bbox, threshold=0.5):
    """
    Select the spans that lie within a bounding box.

    threshold: the fraction of a span's area that must overlap with the bbox
    Returns a new list; the input order is preserved.
    """
    return [span for span in spans if overlaps(span['bbox'], bbox, threshold)]
def overlaps(bbox1, bbox2, threshold=0.5):
    """
    Report whether at least a `threshold` fraction of bbox1's area lies inside bbox2.

    A bbox1 without area never overlaps anything.
    """
    first_rect = Rect(list(bbox1))
    # The area has to be read before intersecting, as the original code does.
    first_area = first_rect.get_area()
    if first_area == 0:
        return False
    shared_area = first_rect.intersect(list(bbox2)).get_area()
    return shared_area / first_area >= threshold
def extract_text_from_spans(spans, join_with_space=True, remove_integer_superscripts=True):
    """
    Merge a collection of page tokens/words/spans into a single text string.

    spans: list of dicts with 'text', 'flags', 'block_num', 'line_num' and
        'span_num' entries; the list itself is not reordered
    join_with_space: join neighbouring pieces with a space rather than nothing
    remove_integer_superscripts: leave out spans whose lowest 'flags' bit is
        set and whose text passes is_int; other spans with that bit set are
        tagged with span['superscript'] = True

    Returns the joined text with leading and trailing whitespace removed,
    or "" when no span remains.
    """
    separator = " " if join_with_space else ""
    ordered = spans[:]

    if remove_integer_superscripts:
        for span in spans:
            if not span['flags'] & 2**0:
                continue
            if is_int(span['text']):
                ordered.remove(span)
            else:
                span['superscript'] = True

    if not ordered:
        return ""

    # Reading order: block first, then line within the block, then span within the line.
    ordered.sort(key=lambda span: (span['block_num'], span['line_num'], span['span_num']))

    finished_lines = []
    current_line = [ordered[0]['text']]
    for previous, following in zip(ordered, ordered[1:]):
        same_line = (previous['block_num'] == following['block_num']
                     and previous['line_num'] == following['line_num'])
        if same_line:
            current_line.append(following['text'])
            continue

        line_text = separator.join(current_line).strip()
        # A finished line gets a trailing space unless it is empty, already
        # ends in a space, or ends in a hyphen that directly follows a non-space.
        ends_in_joining_hyphen = (len(line_text) > 1
                                  and line_text[-1] == "-"
                                  and line_text[-2] != ' ')
        wants_trailing_space = (len(line_text) > 0
                                and line_text[-1] != ' '
                                and not ends_in_joining_hyphen)
        if wants_trailing_space and not join_with_space:
            line_text += ' '
        finished_lines.append(line_text)
        current_line = [following['text']]

    finished_lines.append(separator.join(current_line))

    return separator.join(finished_lines).strip()
def sort_objects_left_to_right(objs):
    """
    Return the objects as a new list ordered from left to right.

    The ordering key is the sum of the left and right x-coordinates of each
    object's 'bbox'; ties keep their input order.
    """
    def horizontal_key(obj):
        left, right = obj['bbox'][0], obj['bbox'][2]
        return left + right

    return sorted(objs, key=horizontal_key)
def sort_objects_top_to_bottom(objs):
    """
    Return the objects as a new list ordered from top to bottom.

    The ordering key is the sum of the top and bottom y-coordinates of each
    object's 'bbox'; ties keep their input order.
    """
    def vertical_key(obj):
        top, bottom = obj['bbox'][1], obj['bbox'][3]
        return top + bottom

    return sorted(objs, key=vertical_key)
def align_columns(columns, bbox):
    """
    Stretch every column vertically so that its top and bottom edges match
    those of the final table bounding box.

    The column dicts are modified in place and the same list is returned.
    Any error is reported on stdout and the columns are returned as they
    are at that point.
    """
    try:
        for table_column in columns:
            table_column['bbox'][1] = bbox[1]
            table_column['bbox'][3] = bbox[3]
    except Exception as err:
        print("Could not align columns: {}".format(err))
    return columns
def align_rows(rows, bbox):
    """
    Stretch every row horizontally so that its left and right edges match
    those of the final table bounding box.

    The row dicts are modified in place and the same list is returned.
    Any error is reported on stdout and the rows are returned as they are
    at that point.
    """
    try:
        for table_row in rows:
            table_row['bbox'][0] = bbox[0]
            table_row['bbox'][2] = bbox[2]
    except Exception as err:
        print("Could not align rows: {}".format(err))
    return rows
def refine_table_structures(table_bbox, table_structures, page_spans, class_thresholds):
    """
    Clean up the detected headers and supercells of a table.

    Headers are thresholded, de-duplicated with NMS and snapped to rows.
    Spanning cells and projected row headers are thresholded separately
    (each with its own class threshold), snapped to the row/column grid,
    made disjoint, and checked for a valid header hierarchy.

    table_bbox and page_spans are accepted but not used here.
    Returns the same table_structures dict with its entries updated.
    """
    rows = table_structures["rows"]
    columns = table_structures['columns']

    headers = apply_threshold(table_structures['headers'], class_thresholds["table column header"])
    headers = align_headers(nms(headers), rows)

    # Split the combined list back into its two classes, keeping input order.
    spanning_cells, projected_headers = [], []
    for elem in table_structures['supercells']:
        if elem['subheader']:
            projected_headers.append(elem)
        else:
            spanning_cells.append(elem)
    spanning_cells = apply_threshold(spanning_cells, class_thresholds["table spanning cell"])
    projected_headers = apply_threshold(projected_headers, class_thresholds["table projected row header"])

    supercells = align_supercells(spanning_cells + projected_headers, rows, columns)
    supercells = nms_supercells(supercells)
    header_supercell_tree(supercells)

    table_structures['columns'] = columns
    table_structures['rows'] = rows
    table_structures['supercells'] = supercells
    table_structures['headers'] = headers

    return table_structures
def nms(objects, match_criteria="object2_overlap", match_threshold=0.05, keep_metric="score", keep_higher=True):
    """
    A customizable version of non-maxima suppression (NMS).

    Default behavior: If a lower-confidence object overlaps at least 5% of its area
    with a higher-confidence object, remove the lower-confidence object.

    objects: set of dicts; each object dict must have a 'bbox' and a 'score' field
    match_criteria: how to measure how much two objects "overlap";
        one of "object1_overlap", "object2_overlap" or "iou"
    match_threshold: the cutoff for determining that overlap requires suppression of one object
    keep_metric: which metric to use to determine the object to keep
        ("score" or "area"; any other value leaves the input order untouched)
    keep_higher: if True, keep the object with the higher metric; otherwise, keep the lower

    Returns the surviving objects in ranked order. A pair whose overlap
    measure has a zero denominator, or an unrecognised match_criteria, never
    causes suppression. Other errors (for example a non-numeric
    match_threshold) are no longer silently swallowed.
    """
    if len(objects) == 0:
        return []

    if keep_metric == "score":
        objects = sort_objects_by_score(objects, reverse=keep_higher)
    elif keep_metric == "area":
        # sort_objects_by_area is expected to be defined elsewhere in this module.
        objects = sort_objects_by_area(objects, reverse=keep_higher)

    num_objects = len(objects)
    suppression = [False for obj in objects]

    for object2_num in range(1, num_objects):
        object2_rect = Rect(objects[object2_num]['bbox'])
        object2_area = object2_rect.get_area()
        for object1_num in range(object2_num):
            if suppression[object1_num]:
                continue
            # Built fresh for every pair; its area is read before intersecting.
            object1_rect = Rect(objects[object1_num]['bbox'])
            object1_area = object1_rect.get_area()
            intersect_area = object1_rect.intersect(object2_rect).get_area()

            if match_criteria == "object1_overlap":
                denominator = object1_area
            elif match_criteria == "object2_overlap":
                denominator = object2_area
            elif match_criteria == "iou":
                denominator = object1_area + object2_area - intersect_area
            else:
                # Unknown criterion: nothing can be measured, so nothing matches.
                denominator = 0

            if denominator == 0:
                continue
            if intersect_area / denominator >= match_threshold:
                suppression[object2_num] = True
                break

    return [obj for idx, obj in enumerate(objects) if not suppression[idx]]
def align_headers(headers, rows):
    """
    Adjust the header boundary to be the convex hull of the rows it intersects
    at least 50% of the height of.

    For now, we are not supporting tables with multiple headers, so we need to
    eliminate anything besides the top-most header.

    headers: list of dicts with a 'bbox' entry
    rows: list of row dicts with a 'bbox' entry; each row's 'header' flag is
        reset to False and then set to True for the rows that end up in the header

    Returns a list holding one {'bbox': [...]} header, or an empty list when
    no row is covered by any header.
    """
    aligned_headers = []

    for row in rows:
        row['header'] = False

    header_row_nums = []
    for header in headers:
        for row_num, row in enumerate(rows):
            row_height = row['bbox'][3] - row['bbox'][1]
            if row_height <= 0:
                # A degenerate row cannot be covered by anything (and would divide by zero).
                continue
            min_row_overlap = max(row['bbox'][1], header['bbox'][1])
            max_row_overlap = min(row['bbox'][3], header['bbox'][3])
            overlap_height = max_row_overlap - min_row_overlap
            if overlap_height / row_height >= 0.5:
                header_row_nums.append(row_num)

    if len(header_row_nums) == 0:
        return aligned_headers

    header_rect = Rect()
    if header_row_nums[0] > 0:
        # The header must start at the top of the table: prepend the rows above
        # the first detected header row. The first detected row itself is already
        # in the list; repeating it would break the contiguity check below and
        # cut the header short.
        header_row_nums = list(range(header_row_nums[0])) + header_row_nums

    last_row_num = -1
    for row_num in header_row_nums:
        if row_num == last_row_num + 1:
            row = rows[row_num]
            row['header'] = True
            header_rect = header_rect.include_rect(row['bbox'])
            last_row_num = row_num
        else:
            # Stop at the first gap; only one contiguous header starting at
            # row 0 is supported.
            break

    header = {'bbox': list(header_rect)}
    aligned_headers.append(header)

    return aligned_headers
def align_supercells(supercells, rows, columns):
    """
    For each supercell, align it to the rows it intersects 50% of the height of,
    and the columns it intersects 50% of the width of.
    Eliminate supercells for which there are no rows and columns it intersects 50% with.

    supercells: list of dicts with a 'bbox' entry; each one is modified in
        place ('header', 'bbox', and for kept supercells 'row_numbers' and
        'column_numbers')
    rows: list of row dicts with a 'bbox' entry and optionally a 'header' flag
    columns: list of column dicts with a 'bbox' entry

    Returns the list of supercells that span more than one row or more than
    one column after alignment.
    """
    aligned_supercells = []

    for supercell in supercells:
        supercell['header'] = False
        row_bbox_rect = None
        col_bbox_rect = None
        intersecting_header_rows = set()
        intersecting_data_rows = set()
        # Collect the rows this supercell covers, split into header and data rows.
        for row_num, row in enumerate(rows):
            row_height = row['bbox'][3] - row['bbox'][1]
            supercell_height = supercell['bbox'][3] - supercell['bbox'][1]
            min_row_overlap = max(row['bbox'][1], supercell['bbox'][1])
            max_row_overlap = min(row['bbox'][3], supercell['bbox'][3])
            overlap_height = max_row_overlap - min_row_overlap
            # Supercells carrying a 'span' key are measured against the larger
            # of the two fractions (of the row, of the supercell itself).
            if 'span' in supercell:
                overlap_fraction = max(overlap_height/row_height,
                                       overlap_height/supercell_height)
            else:
                overlap_fraction = overlap_height / row_height
            if overlap_fraction >= 0.5:
                if 'header' in row and row['header']:
                    intersecting_header_rows.add(row_num)
                else:
                    intersecting_data_rows.add(row_num)

        # A supercell may not straddle header and data rows: keep only the
        # larger group (header rows win a tie).
        supercell['header'] = False
        if len(intersecting_data_rows) > 0 and len(intersecting_header_rows) > 0:
            if len(intersecting_data_rows) > len(intersecting_header_rows):
                intersecting_header_rows = set()
            else:
                intersecting_data_rows = set()
        if len(intersecting_header_rows) > 0:
            supercell['header'] = True
        elif 'span' in supercell:
            # 'span' supercells are only kept when they lie in the header.
            continue
        intersecting_rows = intersecting_data_rows.union(intersecting_header_rows)
        # Vertical extent: the hull of all covered rows.
        for row_num in intersecting_rows:
            if row_bbox_rect is None:
                row_bbox_rect = Rect(rows[row_num]['bbox'])
            else:
                row_bbox_rect = row_bbox_rect.include_rect(rows[row_num]['bbox'])
        if row_bbox_rect is None:
            continue

        # Horizontal extent: the hull of all covered columns.
        intersecting_cols = []
        for col_num, col in enumerate(columns):
            col_width = col['bbox'][2] - col['bbox'][0]
            supercell_width = supercell['bbox'][2] - supercell['bbox'][0]
            min_col_overlap = max(col['bbox'][0], supercell['bbox'][0])
            max_col_overlap = min(col['bbox'][2], supercell['bbox'][2])
            overlap_width = max_col_overlap - min_col_overlap
            if 'span' in supercell:
                overlap_fraction = max(overlap_width/col_width,
                                       overlap_width/supercell_width)
                # Doubling the fraction effectively halves the 0.5 cutoff
                # below for 'span' supercells in the header.
                if supercell['header']:
                    overlap_fraction = overlap_fraction * 2
            else:
                overlap_fraction = overlap_width / col_width
            if overlap_fraction >= 0.5:
                intersecting_cols.append(col_num)
                if col_bbox_rect is None:
                    col_bbox_rect = Rect(col['bbox'])
                else:
                    col_bbox_rect = col_bbox_rect.include_rect(col['bbox'])
        if col_bbox_rect is None:
            continue

        # The aligned box is where the row hull and the column hull meet.
        supercell_bbox = list(row_bbox_rect.intersect(col_bbox_rect))
        supercell['bbox'] = supercell_bbox

        # Only keep it if it really joins more than one row or more than one column.
        if (len(intersecting_rows) > 0 and len(intersecting_cols) > 0
                and (len(intersecting_rows) > 1 or len(intersecting_cols) > 1)):
            supercell['row_numbers'] = list(intersecting_rows)
            supercell['column_numbers'] = intersecting_cols
            aligned_supercells.append(supercell)

            # For a multi-column 'span' supercell in the header, add a
            # 'propagated' supercell over the same columns for every row above it.
            if 'span' in supercell and supercell['header'] and len(supercell['column_numbers']) > 1:
                for row_num in range(0, min(supercell['row_numbers'])):
                    # NOTE(review): the new supercell shares the original's
                    # 'column_numbers' list object rather than copying it.
                    new_supercell = {'row_numbers': [row_num], 'column_numbers': supercell['column_numbers'],
                                     'score': supercell['score'], 'propagated': True}
                    new_supercell_columns = [columns[idx] for idx in supercell['column_numbers']]
                    # NOTE(review): the bbox is built from the original
                    # supercell's rows, not from row_num -- confirm this is intended.
                    new_supercell_rows = [rows[idx] for idx in supercell['row_numbers']]
                    bbox = [min([column['bbox'][0] for column in new_supercell_columns]),
                            min([row['bbox'][1] for row in new_supercell_rows]),
                            max([column['bbox'][2] for column in new_supercell_columns]),
                            max([row['bbox'][3] for row in new_supercell_rows])]
                    new_supercell['bbox'] = bbox
                    aligned_supercells.append(new_supercell)

    return aligned_supercells
def nms_supercells(supercells):
    """
    Make supercells disjoint, shrinking before discarding.

    Supercells are ranked by score. Each lower-ranked supercell is shrunk
    against every higher-ranked one so they no longer share a grid cell.
    A shrunk supercell is discarded when it has no rows, no columns, or
    covers fewer than two rows and fewer than two columns.

    Returns the surviving supercells in score order.
    """
    ranked = sort_objects_by_score(supercells)

    # The best-scoring supercell (if any) is never shrunk or discarded.
    survivors = ranked[:1]
    for position in range(1, len(ranked)):
        weaker = ranked[position]
        for stronger in ranked[:position]:
            remove_supercell_overlap(stronger, weaker)

        num_rows = len(weaker['row_numbers'])
        num_columns = len(weaker['column_numbers'])
        is_single_cell = num_rows < 2 and num_columns < 2
        if is_single_cell or num_rows == 0 or num_columns == 0:
            continue
        survivors.append(weaker)

    return survivors
def header_supercell_tree(supercells):
    """
    Enforce a tree structure among the header supercells.

    Every header supercell must sit below exactly one header supercell in
    each row above it; a different count would mean a cell has no parent or
    more than one parent. Offending supercells are removed from `supercells`
    in place. Nothing is returned.
    """
    in_header = sort_objects_by_score(
        [supercell for supercell in supercells if supercell.get('header')])

    for candidate in list(in_header):
        candidate_columns = set(candidate['column_numbers'])
        first_row = min(candidate['row_numbers'])

        # Count, per row, the header supercells lying entirely above the
        # candidate whose columns include all of the candidate's columns.
        ancestors_by_row = defaultdict(int)
        for other in in_header:
            if max(other['row_numbers']) >= first_row:
                continue
            if candidate_columns <= set(other['column_numbers']):
                for covered_row in other['row_numbers']:
                    ancestors_by_row[covered_row] += 1

        if any(ancestors_by_row[row] != 1 for row in range(first_row)):
            supercells.remove(candidate)
def table_structure_to_cells(table_structures, table_spans, table_bbox):
    """
    Assuming the row, column, supercell, and header bounding boxes have
    been refined into a set of consistent table structures, process these
    table structures into table cells. This is a universal representation
    format for the table, which can later be exported to Pandas or CSV formats.
    Classify the cells as header/access cells or data cells
    based on if they intersect with the header bounding box.

    table_structures: dict with 'columns', 'rows' and 'supercells' lists
    table_spans: list of token/word/span dicts with a 'bbox' entry
    table_bbox: accepted for interface compatibility; not used here

    Returns (cells, confidence_score). Each cell is a dict with 'bbox',
    'column_nums', 'row_nums', 'header', 'subheader' and 'spans' entries.
    The confidence score is the average of the mean and the minimum best
    overlap of the spans with the cells, or 0 when there is nothing to score.

    Side effect: the 'bbox' of every row and column is tightened in place
    to the text assigned to it.
    """
    columns = table_structures['columns']
    rows = table_structures['rows']
    supercells = table_structures['supercells']
    cells = []
    subcells = []

    # Build the plain row x column grid; grid cells mostly covered by a
    # supercell are set aside as subcells.
    for column_num, column in enumerate(columns):
        for row_num, row in enumerate(rows):
            column_rect = Rect(list(column['bbox']))
            row_rect = Rect(list(row['bbox']))
            cell_rect = row_rect.intersect(column_rect)
            header = 'header' in row and row['header']
            cell = {'bbox': list(cell_rect), 'column_nums': [column_num], 'row_nums': [row_num],
                    'header': header}

            cell['subcell'] = False
            cell_area = cell_rect.get_area()
            # A degenerate (zero-area) grid cell cannot be covered by a
            # supercell; skipping the test also avoids a division by zero.
            if cell_area > 0:
                for supercell in supercells:
                    supercell_rect = Rect(list(supercell['bbox']))
                    if (supercell_rect.intersect(cell_rect).get_area()
                            / cell_area) > 0.5:
                        cell['subcell'] = True
                        break

            if cell['subcell']:
                subcells.append(cell)
            else:
                cell['subheader'] = False
                cells.append(cell)

    # Merge the subcells of every supercell into one cell.
    for supercell in supercells:
        supercell_rect = Rect(list(supercell['bbox']))
        cell_columns = set()
        cell_rows = set()
        cell_rect = None
        header = True
        for subcell in subcells:
            subcell_rect = Rect(list(subcell['bbox']))
            # Subcells passed the coverage test above, so their area is non-zero.
            subcell_rect_area = subcell_rect.get_area()
            if (subcell_rect.intersect(supercell_rect).get_area()
                    / subcell_rect_area) > 0.5:
                if cell_rect is None:
                    cell_rect = Rect(list(subcell['bbox']))
                else:
                    cell_rect.include_rect(Rect(list(subcell['bbox'])))
                cell_rows = cell_rows.union(set(subcell['row_nums']))
                cell_columns = cell_columns.union(set(subcell['column_nums']))
                # The merged cell is a header cell only if all of its subcells are.
                header = header and 'header' in subcell and subcell['header']
        if len(cell_rows) > 0 and len(cell_columns) > 0:
            cell = {'bbox': list(cell_rect), 'column_nums': list(cell_columns), 'row_nums': list(cell_rows),
                    'header': header, 'subheader': supercell['subheader']}
            cells.append(cell)

    # Confidence: how well the spans slot into the cells found so far.
    _, _, cell_match_scores = slot_into_containers(cells, table_spans)
    if len(cell_match_scores) > 0:
        mean_match_score = sum(cell_match_scores) / len(cell_match_scores)
        min_match_score = min(cell_match_scores)
        confidence_score = (mean_match_score + min_match_score)/2
    else:
        confidence_score = 0

    # Recompute every cell box from the rows and columns it covers.
    for cell in cells:
        column_rect = Rect()
        for column_num in cell['column_nums']:
            column_rect.include_rect(list(columns[column_num]['bbox']))
        row_rect = Rect()
        for row_num in cell['row_nums']:
            row_rect.include_rect(list(rows[row_num]['bbox']))
        cell_rect = column_rect.intersect(row_rect)
        cell['bbox'] = list(cell_rect)

    # Assign each span to the one cell it overlaps most (any overlap counts).
    span_nums_by_cell, _, _ = slot_into_containers(cells, table_spans, overlap_threshold=0.001,
                                                   unique_assignment=True, forced_assignment=False)

    for cell, cell_span_nums in zip(cells, span_nums_by_cell):
        cell_spans = [table_spans[num] for num in cell_span_nums]
        cell['spans'] = cell_spans

    # Tighten the row, column and cell boxes to the text they contain.
    num_rows = len(rows)
    rows = sort_objects_top_to_bottom(rows)
    num_columns = len(columns)
    columns = sort_objects_left_to_right(columns)
    min_y_values_by_row = defaultdict(list)
    max_y_values_by_row = defaultdict(list)
    min_x_values_by_column = defaultdict(list)
    max_x_values_by_column = defaultdict(list)
    for cell in cells:
        min_row = min(cell["row_nums"])
        max_row = max(cell["row_nums"])
        min_column = min(cell["column_nums"])
        max_column = max(cell["column_nums"])
        for span in cell['spans']:
            min_x_values_by_column[min_column].append(span['bbox'][0])
            min_y_values_by_row[min_row].append(span['bbox'][1])
            max_x_values_by_column[max_column].append(span['bbox'][2])
            max_y_values_by_row[max_row].append(span['bbox'][3])
    # Rows: horizontal limits from the first/last column, vertical from their own text.
    for row_num, row in enumerate(rows):
        if len(min_x_values_by_column[0]) > 0:
            row['bbox'][0] = min(min_x_values_by_column[0])
        if len(min_y_values_by_row[row_num]) > 0:
            row['bbox'][1] = min(min_y_values_by_row[row_num])
        if len(max_x_values_by_column[num_columns-1]) > 0:
            row['bbox'][2] = max(max_x_values_by_column[num_columns-1])
        if len(max_y_values_by_row[row_num]) > 0:
            row['bbox'][3] = max(max_y_values_by_row[row_num])
    # Columns: vertical limits from the first/last row, horizontal from their own text.
    for column_num, column in enumerate(columns):
        if len(min_x_values_by_column[column_num]) > 0:
            column['bbox'][0] = min(min_x_values_by_column[column_num])
        if len(min_y_values_by_row[0]) > 0:
            column['bbox'][1] = min(min_y_values_by_row[0])
        if len(max_x_values_by_column[column_num]) > 0:
            column['bbox'][2] = max(max_x_values_by_column[column_num])
        if len(max_y_values_by_row[num_rows-1]) > 0:
            column['bbox'][3] = max(max_y_values_by_row[num_rows-1])
    # Cells: keep the previous box when the tightened one would be empty.
    for cell in cells:
        row_rect = Rect()
        column_rect = Rect()
        for row_num in cell['row_nums']:
            row_rect.include_rect(list(rows[row_num]['bbox']))
        for column_num in cell['column_nums']:
            column_rect.include_rect(list(columns[column_num]['bbox']))
        cell_rect = row_rect.intersect(column_rect)
        if cell_rect.get_area() > 0:
            cell['bbox'] = list(cell_rect)

    return cells, confidence_score
def remove_supercell_overlap(supercell1, supercell2):
    """
    Make two supercells disjoint by shrinking supercell2 (supercell1 is only read).

    While the two supercells still share a grid cell, one row or one column
    is taken away from supercell2 per step. If supercell2 has fewer rows
    than columns, a shared column is removed; otherwise a shared row is
    removed. Only an outermost index is removed (the largest one if it is
    shared, else the smallest one if it is shared). When neither outermost
    index is shared, the overlap lies strictly inside supercell2 and all of
    its rows (or columns) are cleared.

    supercell2's 'row_numbers' / 'column_numbers' are modified in place.
    """
    shared = {
        'row_numbers': set(supercell1['row_numbers']) & set(supercell2['row_numbers']),
        'column_numbers': set(supercell1['column_numbers']) & set(supercell2['column_numbers']),
    }

    while len(shared['row_numbers']) > 0 and len(shared['column_numbers']) > 0:
        # Shrink along the axis that costs supercell2 the fewest grid cells.
        if len(supercell2['row_numbers']) < len(supercell2['column_numbers']):
            axis = 'column_numbers'
        else:
            axis = 'row_numbers'

        indices = supercell2[axis]
        lowest, highest = min(indices), max(indices)
        if highest in shared[axis]:
            shared[axis].remove(highest)
            indices.remove(highest)
        elif lowest in shared[axis]:
            shared[axis].remove(lowest)
            indices.remove(lowest)
        else:
            supercell2[axis] = []
            shared[axis] = set()