from copy import deepcopy
from typing import List

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
from surya.schema import OrderBox, OrderResult
from surya.settings import settings


def get_batch_size():
    batch_size = settings.ORDER_BATCH_SIZE
    if batch_size is None:
        batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "mps":
            batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 32
    return batch_size


def rank_elements(arr):
    enumerated_and_sorted = sorted(enumerate(arr), key=lambda x: x[1])
    rank = [0] * len(arr)

    for rank_value, (original_index, value) in enumerate(enumerated_and_sorted):
        rank[original_index] = rank_value

    return rank


def batch_ordering(images: List, bboxes: List[List[List[float]]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[OrderResult]:
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(bboxes)
    if batch_size is None:
        batch_size = get_batch_size()

    images = [image.convert("RGB") for image in images]  # also copies the images

    output_order = []
    for i in tqdm(range(0, len(images), batch_size), desc="Finding reading order"):
        batch_bboxes = deepcopy(bboxes[i:i+batch_size])
        batch_images = images[i:i+batch_size]
        orig_sizes = [image.size for image in batch_images]
        model_inputs = processor(images=batch_images, boxes=batch_bboxes)

        batch_pixel_values = model_inputs["pixel_values"]
        batch_bboxes = model_inputs["input_boxes"]
        batch_bbox_mask = model_inputs["input_boxes_mask"]
        batch_bbox_counts = model_inputs["input_boxes_counts"]

        batch_bboxes = torch.from_numpy(np.array(batch_bboxes, dtype=np.int32)).to(model.device)
        batch_bbox_mask = torch.from_numpy(np.array(batch_bbox_mask, dtype=np.int32)).to(model.device)
        batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
        batch_bbox_counts = torch.tensor(np.array(batch_bbox_counts), dtype=torch.long).to(model.device)

        token_count = 0
        past_key_values = None
        encoder_outputs = None
        batch_predictions = [[] for _ in range(len(batch_images))]
        done = torch.zeros(len(batch_images), dtype=torch.bool, device=model.device)

        with torch.inference_mode():
            # Autoregressively decode one reading-order label per step. After the
            # first pass, the cached encoder output and past_key_values are reused,
            # so each step only processes the newly appended token.
            while token_count < settings.ORDER_MAX_BOXES:
                return_dict = model(
                    pixel_values=batch_pixel_values,
                    decoder_input_boxes=batch_bboxes,
                    decoder_input_boxes_mask=batch_bbox_mask,
                    decoder_input_boxes_counts=batch_bbox_counts,
                    encoder_outputs=encoder_outputs,
                    past_key_values=past_key_values,
                )
                logits = return_dict["logits"].detach()

                last_tokens = []
                last_token_mask = []
                min_val = torch.finfo(model.dtype).min
                for j in range(logits.shape[0]):
                    label_count = batch_bbox_counts[j, 1] - batch_bbox_counts[j, 0] - 1  # Subtract 1 for the sep token
                    new_logits = logits[j, -1]
                    new_logits[batch_predictions[j]] = min_val  # Mask out already predicted tokens; each token can only be predicted once
                    new_logits[label_count:] = min_val  # Mask out all logit positions beyond the number of bboxes
                    pred = int(torch.argmax(new_logits, dim=-1).item())

                    # Add one to avoid colliding with the 1000 height/width token for bboxes
                    last_tokens.append([[pred + processor.box_size["height"] + 1] * 4])

                    if len(batch_predictions[j]) == label_count - 1:  # Minus one since we're appending the final label
                        last_token_mask.append([0])
                        batch_predictions[j].append(pred)
                        done[j] = True
                    elif len(batch_predictions[j]) < label_count - 1:
                        last_token_mask.append([1])
                        batch_predictions[j].append(pred)  # Get rank prediction for the given position
                    else:
                        last_token_mask.append([0])

                if done.all():
                    break

                past_key_values = return_dict["past_key_values"]
                encoder_outputs = (return_dict["encoder_last_hidden_state"],)
                batch_bboxes = torch.tensor(last_tokens, dtype=torch.long).to(model.device)
                token_bbox_mask = torch.tensor(last_token_mask, dtype=torch.long).to(model.device)
                batch_bbox_mask = torch.cat([batch_bbox_mask, token_bbox_mask], dim=1)
                token_count += 1

        for j, row_pred in enumerate(batch_predictions):
            row_bboxes = bboxes[i+j]
            assert len(row_pred) == len(row_bboxes), f"Mismatch between predictions and bboxes. Predictions: {len(row_pred)}, Bboxes: {len(row_bboxes)}"

            orig_size = orig_sizes[j]
            ranks = [0] * len(row_bboxes)

            # row_pred[t] is the prediction for decode step t; invert the sequence
            # so ranks[k] holds the reading position assigned to box k.
            for box_idx in range(len(row_bboxes)):
                ranks[row_pred[box_idx]] = box_idx

            order_boxes = []
            for row_bbox, rank in zip(row_bboxes, ranks):
                order_box = OrderBox(
                    bbox=row_bbox,
                    position=rank,
                )
                order_boxes.append(order_box)

            result = OrderResult(
                bboxes=order_boxes,
                image_bbox=[0, 0, orig_size[0], orig_size[1]],
            )
            output_order.append(result)
    return output_order
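

# Example usage (a minimal sketch): shows how batch_ordering is typically called.
# The load_model/load_processor imports are assumed companion helpers from the
# same package, and the image path and bounding boxes below are placeholders;
# in practice the boxes come from a layout detection step.
if __name__ == "__main__":
    from surya.model.ordering.model import load_model  # assumed companion helper
    from surya.model.ordering.processor import load_processor  # assumed companion helper

    model = load_model()
    processor = load_processor()

    # One page image plus its layout boxes in [x1, y1, x2, y2] pixel coordinates.
    image = Image.open("page.png")  # placeholder path
    page_bboxes = [[10, 10, 300, 60], [10, 80, 300, 400]]  # placeholder boxes

    order_results = batch_ordering([image], [page_bboxes], model, processor)
    for box in order_results[0].bboxes:
        print(box.position, box.bbox)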