from copy import deepcopy
from typing import List
import torch
from PIL import Image
from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel
from surya.schema import OrderBox, OrderResult
from surya.settings import settings
from tqdm import tqdm
import numpy as np


def get_batch_size():
    # Respect an explicitly configured batch size; otherwise fall back to a
    # per-device default.
    batch_size = settings.ORDER_BATCH_SIZE
    if batch_size is None:
        batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "mps":
            batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 32
    return batch_size
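
# Note: `settings` is assumed (as elsewhere in this package) to be an
# env-configurable settings object, so setting ORDER_BATCH_SIZE in the
# environment overrides the device heuristic above; callers can also pass
# batch_size directly to batch_ordering below.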


def rank_elements(arr):
    # rank[i] is the position of arr[i] in the ascending sort of arr.
    enumerated_and_sorted = sorted(enumerate(arr), key=lambda x: x[1])
    rank = [0] * len(arr)
    for rank_value, (original_index, value) in enumerate(enumerated_and_sorted):
        rank[original_index] = rank_value
    return rank
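
# Illustrative example (a sketch; not executed at import time):
#   rank_elements([30, 10, 20]) -> [2, 0, 1]
# since 30 is the largest (rank 2), 10 the smallest (rank 0), and 20 sits
# in between (rank 1).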


def batch_ordering(images: List, bboxes: List[List[List[float]]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[OrderResult]:
    assert all([isinstance(image, Image.Image) for image in images])
    assert len(images) == len(bboxes)
    if batch_size is None:
        batch_size = get_batch_size()

    images = [image.convert("RGB") for image in images]  # also copies the images

    output_order = []
    for i in tqdm(range(0, len(images), batch_size), desc="Finding reading order"):
        batch_bboxes = deepcopy(bboxes[i:i+batch_size])
        batch_images = images[i:i+batch_size]
        orig_sizes = [image.size for image in batch_images]
        model_inputs = processor(images=batch_images, boxes=batch_bboxes)

        batch_pixel_values = model_inputs["pixel_values"]
        batch_bboxes = model_inputs["input_boxes"]
        batch_bbox_mask = model_inputs["input_boxes_mask"]
        batch_bbox_counts = model_inputs["input_boxes_counts"]

        # Move the processor outputs onto the model's device and dtype.
        batch_bboxes = torch.from_numpy(np.array(batch_bboxes, dtype=np.int32)).to(model.device)
        batch_bbox_mask = torch.from_numpy(np.array(batch_bbox_mask, dtype=np.int32)).to(model.device)
        batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
        batch_bbox_counts = torch.tensor(np.array(batch_bbox_counts), dtype=torch.long).to(model.device)
        # Autoregressive decode state: one running list of predicted box
        # indices per image in the batch.
        token_count = 0
        past_key_values = None
        encoder_outputs = None
        batch_predictions = [[] for _ in range(len(batch_images))]
        done = torch.zeros(len(batch_images), dtype=torch.bool, device=model.device)
        with torch.inference_mode():
            while token_count < settings.ORDER_MAX_BOXES:
                return_dict = model(
                    pixel_values=batch_pixel_values,
                    decoder_input_boxes=batch_bboxes,
                    decoder_input_boxes_mask=batch_bbox_mask,
                    decoder_input_boxes_counts=batch_bbox_counts,
                    encoder_outputs=encoder_outputs,
                    past_key_values=past_key_values,
                )
                logits = return_dict["logits"].detach()

                last_tokens = []
                last_token_mask = []
                min_val = torch.finfo(model.dtype).min
                for j in range(logits.shape[0]):
                    label_count = batch_bbox_counts[j, 1] - batch_bbox_counts[j, 0] - 1  # Subtract 1 for the sep token
                    new_logits = logits[j, -1]
                    new_logits[batch_predictions[j]] = min_val  # Mask out already predicted tokens; each position can only be predicted once
                    new_logits[label_count:] = min_val  # Mask out all logit positions above the number of bboxes
                    pred = int(torch.argmax(new_logits, dim=-1).item())

                    # Add one to avoid colliding with the 1000 height/width token for bboxes
                    last_tokens.append([[pred + processor.box_size["height"] + 1] * 4])
                    if len(batch_predictions[j]) == label_count - 1:  # Minus one since we're appending the final label
                        last_token_mask.append([0])
                        batch_predictions[j].append(pred)
                        done[j] = True
                    elif len(batch_predictions[j]) < label_count - 1:
                        last_token_mask.append([1])
                        batch_predictions[j].append(pred)  # Rank prediction for the current position
                    else:
                        last_token_mask.append([0])

                if done.all():
                    break

                # Feed the predicted tokens back in as degenerate boxes, and
                # reuse the KV cache and encoder outputs on subsequent steps.
                past_key_values = return_dict["past_key_values"]
                encoder_outputs = (return_dict["encoder_last_hidden_state"],)
                batch_bboxes = torch.tensor(last_tokens, dtype=torch.long).to(model.device)
                token_bbox_mask = torch.tensor(last_token_mask, dtype=torch.long).to(model.device)
                batch_bbox_mask = torch.cat([batch_bbox_mask, token_bbox_mask], dim=1)
                token_count += 1
        # Convert the per-image predictions back into OrderResult objects.
        for j, row_pred in enumerate(batch_predictions):
            row_bboxes = bboxes[i+j]
            assert len(row_pred) == len(row_bboxes), f"Mismatch between predictions and bboxes. Predictions: {len(row_pred)}, Bboxes: {len(row_bboxes)}"

            orig_size = orig_sizes[j]
            ranks = [0] * len(row_bboxes)
            # row_pred[k] is the index of the box predicted at reading
            # position k; invert it to get each box's reading-order position.
            for box_idx in range(len(row_bboxes)):
                ranks[row_pred[box_idx]] = box_idx

            order_boxes = []
            for row_bbox, rank in zip(row_bboxes, ranks):
                order_box = OrderBox(
                    bbox=row_bbox,
                    position=rank,
                )
                order_boxes.append(order_box)

            result = OrderResult(
                bboxes=order_boxes,
                image_bbox=[0, 0, orig_size[0], orig_size[1]],
            )
            output_order.append(result)
    return output_order
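

if __name__ == "__main__":
    # Minimal usage sketch, under assumptions: load_model/load_processor are
    # assumed to be the companion helpers shipped alongside this module,
    # "page.png" is a placeholder path, and the bboxes are illustrative
    # [x1, y1, x2, y2] layout boxes in pixel coordinates (one list per image).
    from surya.model.ordering.model import load_model
    from surya.model.ordering.processor import load_processor

    model = load_model()
    processor = load_processor()
    image = Image.open("page.png")
    bboxes = [[[10, 10, 300, 60], [10, 80, 300, 400]]]

    order_predictions = batch_ordering([image], bboxes, model, processor)
    for result in order_predictions:
        print([(box.position, box.bbox) for box in result.bboxes])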