from typing import List, Tuple

from concurrent.futures import ProcessPoolExecutor

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

from surya.model.detection.segformer import SegformerForRegressionMask
from surya.postprocessing.heatmap import get_and_clean_boxes
from surya.postprocessing.affinity import get_vertical_lines
from surya.input.processing import prepare_image_detection, split_image, get_total_splits, convert_if_not_rgb
from surya.schema import TextDetectionResult
from surya.settings import settings


def get_batch_size():
    batch_size = settings.DETECTOR_BATCH_SIZE
    if batch_size is None:
        batch_size = 6
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 24
    return batch_size


def batch_detection(images: List, model: SegformerForRegressionMask, processor, batch_size=None) -> Tuple[List[List[np.ndarray]], List[Tuple[int, int]]]:
    assert all([isinstance(image, Image.Image) for image in images])
    if batch_size is None:
        batch_size = get_batch_size()
    heatmap_count = model.config.num_labels

    images = [image.convert("RGB") for image in images]  # also copies the images
    orig_sizes = [image.size for image in images]
    splits_per_image = [get_total_splits(size, processor) for size in orig_sizes]

    # Pack image indices into batches so the total number of vertical splits
    # per batch never exceeds batch_size
    batches = []
    current_batch_size = 0
    current_batch = []
    for i in range(len(images)):
        if current_batch_size + splits_per_image[i] > batch_size:
            if len(current_batch) > 0:
                batches.append(current_batch)
            current_batch = []
            current_batch_size = 0
        current_batch.append(i)
        current_batch_size += splits_per_image[i]

    if len(current_batch) > 0:
        batches.append(current_batch)

    all_preds = []
    for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"):
        batch_image_idxs = batches[batch_idx]
        batch_images = convert_if_not_rgb([images[j] for j in batch_image_idxs])

        split_index = []
        split_heights = []
        image_splits = []
        for image_idx, image in enumerate(batch_images):
            image_parts, split_height = split_image(image, processor)
            image_splits.extend(image_parts)
            split_index.extend([image_idx] * len(image_parts))
            split_heights.extend(split_height)

        image_splits = [prepare_image_detection(image, processor) for image in image_splits]
        # Batch images in dim 0
        batch = torch.stack(image_splits, dim=0).to(model.dtype).to(model.device)

        with torch.inference_mode():
            pred = model(pixel_values=batch)

        logits = pred.logits
        correct_shape = [processor.size["height"], processor.size["width"]]
        current_shape = list(logits.shape[2:])
        if current_shape != correct_shape:
            logits = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False)

        logits = logits.cpu().detach().numpy().astype(np.float32)
        preds = []
        for i, (idx, height) in enumerate(zip(split_index, split_heights)):
            # If our current prediction length is below the image idx, that means we have a new image
            # Otherwise, we need to add to the current image
            if len(preds) <= idx:
                preds.append([logits[i][k] for k in range(heatmap_count)])
            else:
                heatmaps = preds[idx]
                pred_heatmaps = [logits[i][k] for k in range(heatmap_count)]

                if height < processor.size["height"]:
                    # Cut off padding to get original height
                    pred_heatmaps = [pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps]

                # Stitch this split below the heatmaps accumulated so far
                for k in range(heatmap_count):
                    heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]])
                preds[idx] = heatmaps

        all_preds.extend(preds)

    assert len(all_preds) == len(images)
    assert all([len(pred) == heatmap_count for pred in all_preds])
    return all_preds, orig_sizes
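# A usage sketch for batch_detection (illustrative only, not part of the module's
# surface; `load_model` and `load_processor` and their import path are assumptions
# based on this package's loaders and may differ between versions):
#
#     from surya.model.detection.segformer import load_model, load_processor
#     from PIL import Image
#
#     model = load_model()
#     processor = load_processor()
#     preds, orig_sizes = batch_detection([Image.open("page.png")], model, processor)
#     # preds[0] is a list of model.config.num_labels float32 heatmaps,
#     # stitched back to full page height from the vertical splits.
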
def parallel_get_lines(preds, orig_sizes):
    heatmap, affinity_map = preds
    heat_img = Image.fromarray((heatmap * 255).astype(np.uint8))
    aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8))
    affinity_size = list(reversed(affinity_map.shape))
    heatmap_size = list(reversed(heatmap.shape))
    bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)
    vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_sizes)

    result = TextDetectionResult(
        bboxes=bboxes,
        vertical_lines=vertical_lines,
        heatmap=heat_img,
        affinity_map=aff_img,
        image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]]
    )
    return result


def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
    preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
    results = []
    if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH:
        # Ensures we don't parallelize with streamlit, or with very few images
        for i in range(len(images)):
            result = parallel_get_lines(preds[i], orig_sizes[i])
            results.append(result)
    else:
        max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            results = list(executor.map(parallel_get_lines, preds, orig_sizes))

    return results
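
# A minimal end-to-end sketch, under stated assumptions: the `load_model` /
# `load_processor` helpers from surya.model.detection.segformer, a local
# "page.png", and PolygonBox-style results exposing a .bbox property. The
# __main__ guard matters because batch_text_detection may spin up a
# ProcessPoolExecutor, whose workers re-import this module.
if __name__ == "__main__":
    from surya.model.detection.segformer import load_model, load_processor

    det_model = load_model()
    det_processor = load_processor()
    pages = [Image.open("page.png").convert("RGB")]

    for page_result in batch_text_detection(pages, det_model, det_processor):
        print(f"{len(page_result.bboxes)} text lines detected")
        for line in page_result.bboxes[:3]:
            print(line.bbox)  # assumed PolygonBox-style axis-aligned [x1, y1, x2, y2]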