from copy import deepcopy
from typing import Dict, Union, Optional, List, Tuple
import torch
from torch import TensorType
from transformers import DonutImageProcessor, DonutProcessor
from transformers.image_processing_utils import BatchFeature
from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, \
    valid_images, to_numpy_array
import numpy as np
from PIL import Image
import PIL
from surya.settings import settings


def load_processor(checkpoint=settings.ORDER_MODEL_CHECKPOINT):
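    """Load the reading-order OrderImageProcessor from a pretrained checkpoint
    and attach the box/token configuration the ordering model expects.

    Boxes are rescaled into a 1024-wide coordinate grid, and the sep/pad token
    ids are placed just past that range: max_tokens + box_size + 1 = 1281 for
    sep and 1282 for pad.
    """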
    processor = OrderImageProcessor.from_pretrained(checkpoint)
    processor.size = settings.ORDER_IMAGE_SIZE
    box_size = 1024
    max_tokens = 256
    processor.token_sep_id = max_tokens + box_size + 1
    processor.token_pad_id = max_tokens + box_size + 2
    processor.max_boxes = settings.ORDER_MAX_BOXES - 1
    processor.box_size = {"height": box_size, "width": box_size}
    return processor


class OrderImageProcessor(DonutImageProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.patch_size = kwargs.get("patch_size", (4, 4))

    def process_inner(self, images: List[np.ndarray]):
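        """Rescale and normalize a list of HWC uint8 arrays, returning CHW
        float32 arrays ready to batch into pixel_values."""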
        images = [img.transpose(2, 0, 1) for img in images]  # convert to CHW format
        assert images[0].shape[0] == 3  # RGB input images, channel dim now first

        # Convert to float32 for rescale/normalize
        images = [img.astype(np.float32) for img in images]

        # Rescale and normalize
        images = [
            self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]
        images = [
            self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]

        return images

    def process_boxes(self, boxes):
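        """Append a sep token to every box sequence, then left-pad each
        sequence to the longest in the batch.

        Returns (padded_boxes, box_masks, box_counts): masks are 0 for pad
        entries and 1 for real ones; counts are [pad_len, max_boxes] per item.
        """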
        padded_boxes = []
        box_masks = []
        box_counts = []
        for b in boxes:
            padded_b = deepcopy(b)
            padded_b.append([self.token_sep_id] * 4)  # Sep token to indicate start of label predictions
            padded_boxes.append(padded_b)

        # Left pad for generation
        max_boxes = max(len(b) for b in padded_boxes)
        for i in range(len(padded_boxes)):
            pad_len = max_boxes - len(padded_boxes[i])
            box_len = len(padded_boxes[i])
            box_mask = [0] * pad_len + [1] * box_len
            padded_box = [[self.token_pad_id] * 4] * pad_len + padded_boxes[i]
            padded_boxes[i] = padded_box
            box_masks.append(box_mask)
            box_counts.append([pad_len, max_boxes])

        return padded_boxes, box_masks, box_counts

    def resize_img_and_boxes(self, img, boxes):
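        """Resize a PIL image to the processor's target size (shrink, then
        stretch) and rescale its boxes into the 0-1024 grid, clamping any
        out-of-range coordinates. The box lists are modified in place."""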
        orig_dim = img.size
        new_size = (self.size["width"], self.size["height"])

        img.thumbnail(new_size, Image.Resampling.LANCZOS)  # Shrink largest dimension to fit new size
        img = img.resize(new_size, Image.Resampling.LANCZOS)  # Stretch smaller dimension to fit new size
        img = np.asarray(img, dtype=np.uint8)

        width, height = orig_dim
        box_width, box_height = self.box_size["width"], self.box_size["height"]
        for box in boxes:
            # Rescale to 0-1024
            box[0] = box[0] / width * box_width
            box[1] = box[1] / height * box_height
            box[2] = box[2] / width * box_width
            box[3] = box[3] / height * box_height

            if box[0] < 0:
                box[0] = 0
            if box[1] < 0:
                box[1] = 0
            if box[2] > box_width:
                box[2] = box_width
            if box[3] > box_height:
                box[3] = box_height

        return img, boxes

    def preprocess(
        self,
        images: ImageInput,
        boxes: List[List[List[int]]],
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_pad: bool = None,
        random_padding: bool = False,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
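        """Resize each image, rescale and pad its boxes, and normalize pixels.

        Returns a BatchFeature with pixel_values, input_boxes,
        input_boxes_mask, and input_boxes_counts. Note that the box lists are
        modified in place during resizing.
        """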
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        new_images = []
        new_boxes = []
        for img, box in zip(images, boxes):
            if len(box) > self.max_boxes:
                raise ValueError(f"Too many boxes, max is {self.max_boxes}")
            img, box = self.resize_img_and_boxes(img, box)
            new_images.append(img)
            new_boxes.append(box)

        images = new_images
        boxes = new_boxes

        # Convert to numpy for later processing steps
        images = [np.array(image) for image in images]
        images = self.process_inner(images)

        boxes, box_mask, box_counts = self.process_boxes(boxes)

        data = {
            "pixel_values": images,
            "input_boxes": boxes,
            "input_boxes_mask": box_mask,
            "input_boxes_counts": box_counts,
        }

        return BatchFeature(data=data, tensor_type=return_tensors)
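

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module proper: the checkpoint named
    # by settings.ORDER_MODEL_CHECKPOINT is downloaded on first use, and the
    # blank page plus the two boxes below are made-up example inputs. Boxes are
    # assumed to be [x1, y1, x2, y2] in original-image pixel coordinates.
    processor = load_processor()

    page = Image.new("RGB", (1200, 1600), color="white")
    page_boxes = [[100, 100, 600, 200], [100, 250, 600, 400]]

    # preprocess mutates the box lists in place, so pass a copy if the
    # originals are still needed afterwards.
    batch = processor.preprocess(
        images=[page],
        boxes=[deepcopy(page_boxes)],
        return_tensors="pt",
    )
    print(batch["pixel_values"].shape)  # (1, 3, H, W) per settings.ORDER_IMAGE_SIZE
    print(batch["input_boxes"].shape)   # (1, num_boxes + 1, 4) -- +1 for the sep token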