from copy import deepcopy
from typing import List

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from surya.input.processing import convert_if_not_rgb
from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
from surya.schema import OrderBox, OrderResult
from surya.settings import settings


def get_batch_size():
    # Honor an explicitly configured batch size; otherwise pick a default
    # based on the device the model runs on.
    batch_size = settings.ORDER_BATCH_SIZE
    if batch_size is None:
        batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "mps":
            batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 32
    return batch_size
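
# Note (an assumption based on surya's settings docs, not verified here): the
# ORDER_BATCH_SIZE environment variable can override the defaults above, e.g.
#   ORDER_BATCH_SIZE=16 python your_script.py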


def rank_elements(arr):
    # Rank of each element: rank[i] is the position of arr[i] in sorted order.
    enumerated_and_sorted = sorted(enumerate(arr), key=lambda x: x[1])
    rank = [0] * len(arr)

    for rank_value, (original_index, _value) in enumerate(enumerated_and_sorted):
        rank[original_index] = rank_value

    return rank
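
# A quick illustration (not part of the library) of what rank_elements returns:
#   rank_elements([30, 10, 20]) -> [2, 0, 1]
# i.e. 30 is the largest (rank 2), 10 the smallest (rank 0), 20 in between.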


def batch_ordering(images: List, bboxes: List[List[List[float]]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[OrderResult]:
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(bboxes)
    if batch_size is None:
        batch_size = get_batch_size()

    # Make sure every image is RGB before it reaches the processor.
    images = convert_if_not_rgb(images)

    output_order = []
    for i in tqdm(range(0, len(images), batch_size), desc="Finding reading order"):
        # Deepcopy the boxes so any in-place changes made during preprocessing
        # don't touch the caller's bboxes (the originals are reused below).
        batch_bboxes = deepcopy(bboxes[i:i+batch_size])
        batch_images = images[i:i+batch_size]
        orig_sizes = [image.size for image in batch_images]
        model_inputs = processor(images=batch_images, boxes=batch_bboxes)

        batch_pixel_values = model_inputs["pixel_values"]
        batch_bboxes = model_inputs["input_boxes"]
        batch_bbox_mask = model_inputs["input_boxes_mask"]
        batch_bbox_counts = model_inputs["input_boxes_counts"]

        # Move everything onto the model's device with the dtypes it expects.
        batch_bboxes = torch.from_numpy(np.array(batch_bboxes, dtype=np.int32)).to(model.device)
        batch_bbox_mask = torch.from_numpy(np.array(batch_bbox_mask, dtype=np.int32)).to(model.device)
        batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
        batch_bbox_counts = torch.tensor(np.array(batch_bbox_counts), dtype=torch.long).to(model.device)

        # Decoding state: one prediction list per image, plus a done flag so
        # rows that have ordered all of their boxes stop accumulating tokens.
        token_count = 0
        past_key_values = None
        encoder_outputs = None
        batch_predictions = [[] for _ in range(len(batch_images))]
        done = torch.zeros(len(batch_images), dtype=torch.bool, device=model.device)

        with torch.inference_mode():
            while token_count < settings.ORDER_MAX_BOXES:
                return_dict = model(
                    pixel_values=batch_pixel_values,
                    decoder_input_boxes=batch_bboxes,
                    decoder_input_boxes_mask=batch_bbox_mask,
                    decoder_input_boxes_counts=batch_bbox_counts,
                    encoder_outputs=encoder_outputs,
                    past_key_values=past_key_values,
                )
                logits = return_dict["logits"].detach()

                last_tokens = []
                last_token_mask = []
                min_val = torch.finfo(model.dtype).min
                for j in range(logits.shape[0]):
                    # Number of real boxes for this image; the stored count
                    # range includes one special box, hence the -1.
                    label_count = batch_bbox_counts[j, 1] - batch_bbox_counts[j, 0] - 1
                    new_logits = logits[j, -1]
                    # Constrained greedy decoding: mask out boxes that already
                    # have a position, and padding slots past the real boxes.
                    new_logits[batch_predictions[j]] = min_val
                    new_logits[label_count:] = min_val
                    pred = int(torch.argmax(new_logits, dim=-1).item())

                    # Re-encode the prediction as a special box (all four
                    # coordinates share one value above the coordinate range)
                    # to feed back in as the next decoder input.
                    last_tokens.append([[pred + processor.box_size["height"] + 1] * 4])
                    if len(batch_predictions[j]) == label_count - 1:
                        # Final box for this image: record it, mark the row done.
                        last_token_mask.append([0])
                        batch_predictions[j].append(pred)
                        done[j] = True
                    elif len(batch_predictions[j]) < label_count - 1:
                        last_token_mask.append([1])
                        batch_predictions[j].append(pred)
                    else:
                        # Row is already done; keep its new token masked out.
                        last_token_mask.append([0])

                if done.all():
                    break

                # Reuse the encoder output and decoder KV cache so later steps
                # only process the newly appended token.
                past_key_values = return_dict["past_key_values"]
                encoder_outputs = (return_dict["encoder_last_hidden_state"],)

                batch_bboxes = torch.tensor(last_tokens, dtype=torch.long).to(model.device)
                token_bbox_mask = torch.tensor(last_token_mask, dtype=torch.long).to(model.device)
                batch_bbox_mask = torch.cat([batch_bbox_mask, token_bbox_mask], dim=1)
                token_count += 1

        for j, row_pred in enumerate(batch_predictions):
            row_bboxes = bboxes[i+j]
            assert len(row_pred) == len(row_bboxes), f"Mismatch between predictions and bboxes. Predictions: {len(row_pred)}, Bboxes: {len(row_bboxes)}"

            orig_size = orig_sizes[j]
            ranks = [0] * len(row_bboxes)

            # row_pred[k] is the index of the box chosen at reading position k;
            # invert that so ranks[box_idx] is the position of each box.
            for box_idx in range(len(row_bboxes)):
                ranks[row_pred[box_idx]] = box_idx

            order_boxes = []
            for row_bbox, rank in zip(row_bboxes, ranks):
                order_box = OrderBox(
                    bbox=row_bbox,
                    position=rank,
                )
                order_boxes.append(order_box)

            result = OrderResult(
                bboxes=order_boxes,
                image_bbox=[0, 0, orig_size[0], orig_size[1]],
            )
            output_order.append(result)
    return output_order
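

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of this module. It assumes the model and
# processor loaders live at surya.model.ordering.model / .processor (as in the
# surya README); "page.png" and the layout boxes are placeholder inputs.
#
#   from PIL import Image
#   from surya.model.ordering.model import load_model
#   from surya.model.ordering.processor import load_processor
#   from surya.ordering import batch_ordering
#
#   image = Image.open("page.png")
#   # Layout bboxes for the page, in [x1, y1, x2, y2] pixel coordinates.
#   bboxes = [[15, 20, 410, 95], [15, 110, 410, 380]]
#
#   model = load_model()
#   processor = load_processor()
#   order_predictions = batch_ordering([image], [bboxes], model, processor)
#   for box in order_predictions[0].bboxes:
#       print(box.position, box.bbox)
# ---------------------------------------------------------------------------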