from typing import List, Tuple
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

from surya.model.detection.segformer import SegformerForRegressionMask
from surya.postprocessing.heatmap import get_and_clean_boxes
from surya.postprocessing.affinity import get_vertical_lines
from surya.input.processing import prepare_image_detection, split_image, get_total_splits, convert_if_not_rgb
from surya.schema import TextDetectionResult
from surya.settings import settings


def get_batch_size():
    """Return the detector batch size: the configured value if set, else a device-based default."""
    batch_size = settings.DETECTOR_BATCH_SIZE
    if batch_size is None:
        batch_size = 6  # Conservative default for CPU/MPS
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 24  # GPUs can handle larger batches
    return batch_size
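
# The defaults above can be overridden without code changes. As an assumption
# (surya's Settings is a pydantic settings class that reads environment
# variables), setting DETECTOR_BATCH_SIZE in the environment takes precedence:
#
#   DETECTOR_BATCH_SIZE=32 python your_script.py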


def batch_detection(images: List, model: SegformerForRegressionMask, processor, batch_size=None) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]:
    assert all(isinstance(image, Image.Image) for image in images)
    if batch_size is None:
        batch_size = get_batch_size()
    heatmap_count = model.config.num_labels

    images = convert_if_not_rgb(images)  # The model expects 3-channel RGB input

    orig_sizes = [image.size for image in images]
    # Tall images are split into multiple vertical chunks; count splits up front
    # so batches can be budgeted by model inputs rather than by image count.
    splits_per_image = [get_total_splits(size, processor) for size in orig_sizes]

    # Greedily pack images into batches so the total split count per batch
    # never exceeds batch_size.
    batches = []
    current_batch_size = 0
    current_batch = []
    for i in range(len(images)):
        if current_batch_size + splits_per_image[i] > batch_size:
            if len(current_batch) > 0:
                batches.append(current_batch)
            current_batch = []
            current_batch_size = 0
        current_batch.append(i)
        current_batch_size += splits_per_image[i]

    if len(current_batch) > 0:
        batches.append(current_batch)

    all_preds = []
    for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"):
        batch_image_idxs = batches[batch_idx]
        batch_images = [images[j] for j in batch_image_idxs]

        # Split each image into chunks the model can handle, remembering which
        # image each chunk came from and the valid (unpadded) height of each chunk.
        split_index = []
        split_heights = []
        image_splits = []
        for image_idx, image in enumerate(batch_images):
            image_parts, image_split_heights = split_image(image, processor)
            image_splits.extend(image_parts)
            split_index.extend([image_idx] * len(image_parts))
            split_heights.extend(image_split_heights)

        image_splits = [prepare_image_detection(image, processor) for image in image_splits]

        # Stack the preprocessed splits into one tensor on the model's dtype/device
        batch = torch.stack(image_splits, dim=0).to(model.dtype).to(model.device)

        with torch.inference_mode():
            pred = model(pixel_values=batch)

        logits = pred.logits
        correct_shape = [processor.size["height"], processor.size["width"]]
        current_shape = list(logits.shape[2:])
        if current_shape != correct_shape:
            # Upsample the segmentation logits back to the processor's input size
            logits = F.interpolate(logits, size=correct_shape, mode="bilinear", align_corners=False)

        logits = logits.cpu().detach().numpy().astype(np.float32)

        # Reassemble per-image heatmaps from the split predictions. Splits arrive
        # in order, so the first split for an image creates its heatmap list and
        # each later split is stacked below what has accumulated so far.
        preds = []
        for i, (idx, height) in enumerate(zip(split_index, split_heights)):
            if len(preds) <= idx:
                # First split for this image
                preds.append([logits[i][k] for k in range(heatmap_count)])
            else:
                heatmaps = preds[idx]
                pred_heatmaps = [logits[i][k] for k in range(heatmap_count)]

                if height < processor.size["height"]:
                    # The final split is usually padded; crop back to the valid height
                    pred_heatmaps = [pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps]

                for k in range(heatmap_count):
                    heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]])
                preds[idx] = heatmaps

        all_preds.extend(preds)

    assert len(all_preds) == len(images)
    assert all(len(pred) == heatmap_count for pred in all_preds)
    return all_preds, orig_sizes


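# parallel_get_lines is a module-level function (not a closure) so that it can
# be pickled and dispatched to worker processes by the ProcessPoolExecutor below.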
def parallel_get_lines(preds, orig_size):
    """Convert one image's raw heatmaps into a TextDetectionResult."""
    heatmap, affinity_map = preds
    heat_img = Image.fromarray((heatmap * 255).astype(np.uint8))
    aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8))
    # numpy arrays are (height, width); the postprocessors expect (width, height)
    affinity_size = list(reversed(affinity_map.shape))
    heatmap_size = list(reversed(heatmap.shape))
    bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_size)
    vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_size)

    return TextDetectionResult(
        bboxes=bboxes,
        vertical_lines=vertical_lines,
        heatmap=heat_img,
        affinity_map=aff_img,
        image_bbox=[0, 0, orig_size[0], orig_size[1]]
    )


def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
    preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
    if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH:
        # Skip worker processes under Streamlit and for small jobs, where the
        # process-spawning overhead outweighs any speedup.
        results = [parallel_get_lines(preds[i], orig_sizes[i]) for i in range(len(images))]
    else:
        max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            results = list(executor.map(parallel_get_lines, preds, orig_sizes))

    return results
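

# Minimal usage sketch. Assumption (not defined in this module): load_model and
# load_processor are the detection-model loaders shipped in
# surya.model.detection.segformer.
#
#   from PIL import Image
#   from surya.model.detection.segformer import load_model, load_processor
#
#   model = load_model()
#   processor = load_processor()
#   image = Image.open("page.png")
#   [result] = batch_text_detection([image], model, processor)
#   print(result.bboxes)  # detected text boxes in original image coordinates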