|
from huggingface_hub import hf_hub_download |
|
from shapely.validation import make_valid |
|
from shapely.geometry import Polygon |
|
from ultralytics import YOLO |
|
from PIL import Image |
|
import numpy as np |
|
import os |
|
|
|
from reading_order import OrderPolygons |
|
|
|
class SegmentImage:
    """Segment document images into text regions and ordered text lines.

    Text lines are predicted with a YOLO segmentation model fetched from the
    Hugging Face Hub. Text regions are predicted with an optional second YOLO
    model; when no region model is configured, or when it predicts nothing,
    the whole image is used as a single default region. Every line is then
    attached to one region, and lines (and optionally regions) are ordered
    with ``OrderPolygons``.
    """

    # Default checkpoint file names inside the Hugging Face Hub repositories.
    DEFAULT_LINE_MODEL_FILENAME = "lines_20240827.pt"
    DEFAULT_REGION_MODEL_FILENAME = "tuomiokirja_regions_04122023.pt"

    def __init__(self,
                 line_model_path,
                 device,
                 line_iou=0.5,
                 region_iou=0.5,
                 line_overlap=0.5,
                 line_nms_iou=0.7,
                 region_nms_iou=0.3,
                 line_conf_threshold=0.25,
                 region_conf_threshold=0.25,
                 region_model_path=None,
                 order_regions=True,
                 region_half_precision=False,
                 line_half_precision=False,
                 line_model_filename=None,
                 region_model_filename=None):
        """Store the settings and load the detection models.

        Args:
            line_model_path: Hugging Face Hub repo id of the line model.
            device: Device passed to ``YOLO.predict``.
            line_iou: IoU above which two predicted line polygons are merged.
            region_iou: IoU above which two predicted region polygons are merged.
            line_overlap: Two lines are also merged when their intersection
                exceeds this fraction of the smaller polygon's area.
            line_nms_iou: NMS IoU passed to the line model's ``predict``.
            region_nms_iou: NMS IoU passed to the region model's ``predict``.
            line_conf_threshold: Confidence threshold for line predictions.
            region_conf_threshold: Confidence threshold for region predictions.
            region_model_path: Hugging Face Hub repo id of the region model;
                when falsy no region model is loaded and the whole image is
                used as one region.
            order_regions: Whether ``order_regions_lines`` orders the regions.
            region_half_precision: Passed as ``half`` to the region model.
            line_half_precision: Passed as ``half`` to the line model.
            line_model_filename: Checkpoint file name in the line model repo;
                defaults to ``DEFAULT_LINE_MODEL_FILENAME``.
            region_model_filename: Checkpoint file name in the region model
                repo; defaults to ``DEFAULT_REGION_MODEL_FILENAME``.
        """
        self.line_model_path = line_model_path
        self.region_model_path = region_model_path
        self.line_model_filename = line_model_filename or self.DEFAULT_LINE_MODEL_FILENAME
        self.region_model_filename = region_model_filename or self.DEFAULT_REGION_MODEL_FILENAME

        self.line_nms_iou = line_nms_iou
        self.region_nms_iou = region_nms_iou
        self.line_iou = line_iou
        self.region_iou = region_iou
        self.line_overlap = line_overlap
        self.line_conf_threshold = line_conf_threshold
        self.region_conf_threshold = region_conf_threshold
        self.device = device
        self.order_regions = order_regions
        self.region_half_precision = region_half_precision
        self.line_half_precision = line_half_precision

        self.order_poly = OrderPolygons()

        self.line_model = self.init_line_model()
        # Always define the attribute: get_segmentation() calls
        # get_region_preds() unconditionally, which previously raised
        # AttributeError when no region model path was given.
        self.region_model = self.init_region_model() if self.region_model_path else None

    def init_line_model(self):
        """Download (or reuse the cached copy of) the line model and load it.

        Returns:
            The loaded ``YOLO`` model, or ``None`` if downloading or loading
            failed. The failure is reported with ``print``; ``get_line_preds``
            raises ``RuntimeError`` if it is later called without a model.
        """
        try:
            cached_model_path = hf_hub_download(repo_id=self.line_model_path,
                                                filename=self.line_model_filename)
            return YOLO(cached_model_path)
        except Exception as e:
            print('Failed to load the line detection model: %s' % e)
            return None

    def init_region_model(self):
        """Download (or reuse the cached copy of) the region model and load it.

        Returns:
            The loaded ``YOLO`` model, or ``None`` if downloading or loading
            failed (reported with ``print``). Without a model,
            ``get_region_preds`` returns ``None`` and segmentation falls back
            to a single whole-image region.
        """
        try:
            cached_model_path = hf_hub_download(repo_id=self.region_model_path,
                                                filename=self.region_model_filename)
            return YOLO(cached_model_path)
        except Exception as e:
            print('Failed to load the region detection model: %s' % e)
            return None

    def get_region_ids(self, coords, max_min, classes, names, box_confs, img_shape):
        """Build one region dict, with a unique string id, per detected region.

        Args:
            coords: Region polygons.
            max_min: Per-polygon ``[max_x, min_x, max_y, min_y]`` rows.
            classes: Class id of each polygon; ids may be floats (e.g. from
                ``tensor.tolist()``). ``0.0`` and ``0`` hash equal, so
                int-keyed ``names`` dicts still match.
            names: Mapping from class id to class name.
            box_confs: Confidence of each polygon.
            img_shape: Original image shape, stored on every region.

        Returns:
            A list of dicts with keys ``coords``, ``max_min``, ``class``,
            ``name``, ``conf``, ``id`` and ``img_shape``. Only the first
            ``min(len(classes), len(coords))`` polygons are used.
        """
        n = min(len(classes), len(coords))
        res = []
        for i in range(n):
            res.append({'coords': coords[i],
                        'max_min': max_min[i],
                        'class': str(classes[i]),
                        'name': names[classes[i]],
                        'conf': box_confs[i],
                        'id': str(i),
                        'img_shape': img_shape})
        return res

    def get_max_min(self, polygons):
        """Return an ``(n, 4)`` array of ``[max_x, min_x, max_y, min_y]`` rows.

        Rows for polygons without points are left as zeros.
        """
        xy_array = np.zeros([len(polygons), 4])
        for i, poly in enumerate(polygons):
            x = [point[0] for point in poly]
            y = [point[1] for point in poly]
            if x:
                xy_array[i, 0] = max(x)
                xy_array[i, 1] = min(x)
            if y:
                xy_array[i, 2] = max(y)
                xy_array[i, 3] = min(y)
        return xy_array

    def validate_polygon(self, polygon):
        """Convert a point sequence into a valid shapely geometry.

        Returns:
            A ``Polygon`` (repaired with ``make_valid`` when invalid, which
            may yield another geometry type), or ``None`` when the input has
            fewer than three points.
        """
        if len(polygon) > 2:
            polygon = Polygon(polygon)
            if not polygon.is_valid:
                polygon = make_valid(polygon)
            return polygon
        return None

    @staticmethod
    def _is_usable(geom):
        """Return True for a geometry that exists and is not empty."""
        return geom is not None and not geom.is_empty

    @classmethod
    def _geometry_iou(cls, geom1, geom2):
        """Return the IoU of two already-validated geometries (0 if unusable)."""
        if not (cls._is_usable(geom1) and cls._is_usable(geom2)):
            return 0
        if not geom1.intersects(geom2):
            return 0
        union_area = geom1.union(geom2).area
        # Zero-area (degenerate) geometries previously raised ZeroDivisionError.
        if union_area <= 0:
            return 0
        return geom1.intersection(geom2).area / union_area

    def get_iou(self, poly1, poly2):
        """Return the Intersection over Union of two point sequences.

        Returns 0 when either polygon cannot be built, the polygons do not
        intersect, or their union has no area.
        """
        return self._geometry_iou(self.validate_polygon(poly1),
                                  self.validate_polygon(poly2))

    @staticmethod
    def _union_if_similar(geom1, geom2, iou_threshold, overlap_threshold=None):
        """Return the union of two geometries if they should be merged, else None.

        A pair is merged when its IoU exceeds ``iou_threshold`` or, if
        ``overlap_threshold`` is truthy, when the intersection area exceeds
        that fraction of the smaller geometry's area.
        """
        if not geom1.intersects(geom2):
            return None
        intersect_area = geom1.intersection(geom2).area
        union = geom1.union(geom2)
        iou = intersect_area / union.area if union.area > 0 else 0
        overlap = False
        if overlap_threshold:
            overlap = intersect_area > overlap_threshold * min(geom1.area, geom2.area)
        if iou > iou_threshold or overlap:
            return union
        return None

    @classmethod
    def _polygon_parts(cls, geom):
        """Return the ``Polygon`` members of a possibly multi-part geometry."""
        if geom.geom_type == 'Polygon':
            return [geom]
        if geom.geom_type in ('MultiPolygon', 'GeometryCollection'):
            parts = []
            for part in geom.geoms:
                parts.extend(cls._polygon_parts(part))
            return parts
        return []

    def _merge_polygons_with_sources(self, polygons, iou_threshold, overlap_threshold=None):
        """Merge similar polygons and report which inputs each output came from.

        Each not-yet-consumed polygon ``i`` is merged with every later,
        not-yet-consumed polygon that passes ``_union_if_similar`` against it.
        Each input therefore ends up in exactly one output. Previously,
        already-merged inputs were merged again, producing duplicate
        outputs.

        Returns:
            ``(merged, sources)``. Unmerged input polygons come first, in
            input order, followed by the merged polygons as lists of exterior
            coordinates. ``sources[k]`` lists the input indices that
            ``merged[k]`` was built from.
        """
        geoms = [self.validate_polygon(polygon) for polygon in polygons]
        consumed = set()
        new_polygons, new_sources = [], []

        for i, geom1 in enumerate(geoms):
            if i in consumed or not self._is_usable(geom1):
                continue
            merged, group = None, [i]
            for j in range(i + 1, len(geoms)):
                geom2 = geoms[j]
                if j in consumed or not self._is_usable(geom2):
                    continue
                union = self._union_if_similar(geom1, geom2, iou_threshold, overlap_threshold)
                if union is None:
                    continue
                merged = union if merged is None else union.union(merged)
                group.append(j)
            if merged is None or merged.is_empty:
                continue
            consumed.update(group)
            # A union can split into several polygons; all share the same sources.
            for part in self._polygon_parts(merged):
                new_polygons.append(list(part.exterior.coords))
                new_sources.append(group)

        kept = [polygon for k, polygon in enumerate(polygons) if k not in consumed]
        kept_sources = [[k] for k in range(len(polygons)) if k not in consumed]
        return kept + new_polygons, kept_sources + new_sources

    def merge_polygons(self, polygons, iou_threshold, overlap_threshold=None):
        """Merge polygons whose IoU (or optional overlap ratio) exceeds a threshold.

        Returns:
            Unmerged input polygons (in input order) followed by the merged
            polygons as lists of exterior coordinates.
        """
        merged, _ = self._merge_polygons_with_sources(polygons, iou_threshold, overlap_threshold)
        return merged

    @staticmethod
    def _best_source(group, confs):
        """Return the index in ``group`` with the highest confidence.

        Indices without a matching entry in ``confs`` are ignored; returns
        ``None`` if none remain.
        """
        candidates = [k for k in group if k < len(confs)]
        if not candidates:
            return None
        return max(candidates, key=lambda k: confs[k])

    def get_region_preds(self, img):
        """Predict text region polygons.

        Returns:
            A list of region dicts (see ``get_region_ids``), or ``None`` when
            no region model is loaded or no region masks were predicted.
            Class and confidence of a merged region are taken from its
            highest-confidence source detection. Previously they were taken
            by position, which misaligned them after merging.
        """
        region_model = getattr(self, 'region_model', None)
        if region_model is None:
            return None
        results = region_model.predict(source=img,
                                       device=self.device,
                                       conf=self.region_conf_threshold,
                                       half=bool(self.region_half_precision),
                                       iou=self.region_nms_iou)
        results = results[0].cpu()
        if not results.masks:
            return None

        coords, sources = self._merge_polygons_with_sources(results.masks.xy, self.region_iou)
        box_classes = results.boxes.cls.tolist()
        box_confs = results.boxes.conf.tolist()

        region_coords, classes, confs = [], [], []
        for polygon, group in zip(coords, sources):
            best = self._best_source(group, box_confs)
            if best is None or best >= len(box_classes):
                # No matching box for this mask; it cannot be given a class.
                continue
            region_coords.append(polygon)
            classes.append(box_classes[best])
            confs.append(box_confs[best])

        max_min = self.get_max_min(region_coords).tolist()
        return self.get_region_ids(region_coords, max_min, classes, results.names,
                                   confs, results.orig_shape)

    def get_line_preds(self, img):
        """Predict text line polygons.

        Returns:
            ``{'coords': [...], 'max_min': [...], 'confs': [...]}`` with one
            entry per (merged) line, or ``None`` when no line masks were
            predicted. A merged line's confidence is the highest confidence
            among its sources (0.0 if none is available).

        Raises:
            RuntimeError: if the line model failed to load.
        """
        line_model = getattr(self, 'line_model', None)
        if line_model is None:
            raise RuntimeError('Line detection model is not loaded; see the earlier load error.')
        results = line_model.predict(source=img,
                                     device=self.device,
                                     conf=self.line_conf_threshold,
                                     half=bool(self.line_half_precision),
                                     iou=self.line_nms_iou)
        results = results[0].cpu()
        if not results.masks:
            return None

        coords, sources = self._merge_polygons_with_sources(results.masks.xy, self.line_iou,
                                                            self.line_overlap)
        box_confs = results.boxes.conf.tolist()
        confs = []
        for group in sources:
            best = self._best_source(group, box_confs)
            confs.append(box_confs[best] if best is not None else 0.0)

        max_min = self.get_max_min(coords).tolist()
        return {'coords': list(coords), 'max_min': max_min, 'confs': confs}

    @classmethod
    def _closest_region(cls, line_geom, valid_regions):
        """Return the id of the region geometry nearest to ``line_geom``.

        ``valid_regions`` is a list of ``(region_id, geometry)`` pairs.
        Returns ``None`` if the line or every region geometry is unusable.
        """
        if not cls._is_usable(line_geom):
            return None
        dist, reg_id = float('inf'), None
        for region_id, region_geom in valid_regions:
            if not cls._is_usable(region_geom):
                continue
            line_reg_dist = line_geom.distance(region_geom)
            if line_reg_dist < dist:
                dist, reg_id = line_reg_dist, region_id
        return reg_id

    def get_dist(self, line_polygon, regions):
        """Return the id of the region closest to the text line (or ``None``)."""
        valid_regions = [(region['id'], self.validate_polygon(region['coords']))
                         for region in regions]
        return self._closest_region(self.validate_polygon(line_polygon), valid_regions)

    def get_line_regions(self, lines, regions):
        """Attach every text line to one region.

        A line goes to the region with the highest IoU; if it overlaps no
        region, it goes to the nearest region instead. Region polygons are
        validated once rather than once per line.

        Returns:
            A list of ``{'polygon', 'reg_id', 'max_min', 'conf'}`` dicts, one
            per line. ``reg_id`` is ``None`` when no region can be assigned.
        """
        valid_regions = [(region['id'], self.validate_polygon(region['coords']))
                         for region in regions]
        line_max_mins = lines['max_min']
        line_confs = lines['confs']

        lines_list = []
        for i, polygon in enumerate(lines['coords']):
            line_geom = self.validate_polygon(polygon)
            iou, reg_id = 0, ''
            for region_id, region_geom in valid_regions:
                line_reg_iou = self._geometry_iou(line_geom, region_geom)
                if line_reg_iou > iou:
                    iou, reg_id = line_reg_iou, region_id
            if iou == 0:
                reg_id = self._closest_region(line_geom, valid_regions)

            max_min = line_max_mins[i] if i < len(line_max_mins) else [0.0, 0.0, 0.0, 0.0]
            conf = line_confs[i] if i < len(line_confs) else 0.0
            lines_list.append({'polygon': polygon, 'reg_id': reg_id,
                               'max_min': max_min, 'conf': conf})
        return lines_list

    def order_regions_lines(self, lines, regions):
        """Group lines by region, order them, and optionally order the regions.

        Regions that received no lines are omitted.

        Returns:
            A list of dicts with keys ``region_coords``, ``region_name``,
            ``lines``, ``line_confs``, ``region_conf`` and ``img_shape``.
        """
        regions_with_rows = []
        region_max_mins = []
        for region in regions:
            region_lines = [line for line in lines if line['reg_id'] == region['id']]
            if not region_lines:
                continue
            line_order = self.order_poly.order([line['max_min'] for line in region_lines])
            ordered = [region_lines[k] for k in line_order]
            regions_with_rows.append({'region_coords': region['coords'],
                                      'region_name': region['name'],
                                      'lines': [line['polygon'] for line in ordered],
                                      'line_confs': [line['conf'] for line in ordered],
                                      'region_conf': region['conf'],
                                      'img_shape': region['img_shape']})
            region_max_mins.append(region['max_min'])

        # Skip ordering when no region got lines (e.g. all line polygons invalid).
        if self.order_regions and regions_with_rows:
            region_order = self.order_poly.order(region_max_mins)
            regions_with_rows = [regions_with_rows[k] for k in region_order]
        return regions_with_rows

    def get_default_region(self, image):
        """Return a single region covering the whole image.

        ``image`` may be a PIL-style image (with ``size == (w, h)``), a numpy
        array (``shape[:2] == (h, w)``) or a path to an image file, which is
        opened with PIL.
        """
        if isinstance(image, np.ndarray):
            h, w = image.shape[:2]
        elif isinstance(image, (str, os.PathLike)):
            with Image.open(image) as opened:
                w, h = opened.size
        else:
            w, h = image.size
        region = {'coords': [[0.0, 0.0], [w, 0.0], [w, h], [0.0, h]],
                  'max_min': [w, 0.0, h, 0.0],
                  'class': '0',
                  'name': "paragraph",
                  'conf': 0.0,
                  'id': '0',
                  'img_shape': (h, w)}
        return [region]

    def get_segmentation(self, image):
        """Segment an image into ordered regions containing ordered text lines.

        Returns:
            The output of ``order_regions_lines``, or ``None`` when no text
            lines are detected. Without region predictions the whole image is
            used as one region.
        """
        line_preds = self.get_line_preds(image)
        if not line_preds:
            print(f'No text lines detected from image {image}')
            return None

        region_preds = self.get_region_preds(image)
        if not region_preds:
            region_preds = self.get_default_region(image)
            # Only report a miss when a region model actually ran.
            if getattr(self, 'region_model', None) is not None:
                print(f'No regions detected from image {image}')

        lines_with_regions = self.get_line_regions(line_preds, region_preds)
        return self.order_regions_lines(lines_with_regions, region_preds)
|
|
|