from copy import deepcopy
from typing import Dict, Union, Optional, List, Tuple

import numpy as np
import PIL
from PIL import Image
import torch
from transformers import DonutImageProcessor, DonutProcessor
from transformers.image_processing_utils import BatchFeature
from transformers.image_utils import (
    PILImageResampling,
    ImageInput,
    ChannelDimension,
    make_list_of_images,
    valid_images,
    to_numpy_array,
)
from transformers.utils import TensorType  # the enum BatchFeature expects, not torch's JIT TensorType

from surya.settings import settings


def load_processor(checkpoint=settings.ORDER_MODEL_CHECKPOINT):
    processor = OrderImageProcessor.from_pretrained(checkpoint)
    processor.size = settings.ORDER_IMAGE_SIZE

    # Box coordinates are quantized into a 1024-wide space; the sep/pad token ids
    # are placed just above the combined box + label-token vocabulary
    box_size = 1024
    max_tokens = 256
    processor.token_sep_id = max_tokens + box_size + 1
    processor.token_pad_id = max_tokens + box_size + 2
    processor.max_boxes = settings.ORDER_MAX_BOXES - 1
    processor.box_size = {"height": box_size, "width": box_size}
    return processor
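# Illustrative check of the token layout set up above (a sketch, not part of the
# original module): box coordinates occupy ids 0-1023, label tokens sit above them,
# and the sep/pad ids land just past both ranges.
#
#   processor = load_processor()
#   processor.token_sep_id  # 256 + 1024 + 1 = 1281
#   processor.token_pad_id  # 256 + 1024 + 2 = 1282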


class OrderImageProcessor(DonutImageProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.patch_size = kwargs.get("patch_size", (4, 4))
    def process_inner(self, images: List[np.ndarray]):
        images = [img.transpose(2, 0, 1) for img in images]  # convert to CHW format
        assert images[0].shape[0] == 3  # RGB input images, channel dim now first

        # Convert to float32 for rescale/normalize
        images = [img.astype(np.float32) for img in images]

        # Rescale and normalize
        images = [
            self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]
        images = [
            self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]

        return images
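    # Shape walkthrough for process_inner (illustrative): each input is an HWC uint8
    # array, each output a CHW float32 array rescaled by self.rescale_factor (1/255
    # with stock Donut defaults) and normalized with image_mean/image_std.
    #
    #   [np.zeros((256, 256, 3), dtype=np.uint8)]  ->  [array of shape (3, 256, 256)]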
    def process_boxes(self, boxes):
        padded_boxes = []
        box_masks = []
        box_counts = []
        for b in boxes:
            padded_b = deepcopy(b)
            padded_b.append([self.token_sep_id] * 4)  # Sep token to indicate start of label predictions
            padded_boxes.append(padded_b)

        # Left-pad every sequence to the batch maximum, for generation
        max_boxes = max(len(b) for b in padded_boxes)
        for i in range(len(padded_boxes)):
            pad_len = max_boxes - len(padded_boxes[i])
            box_len = len(padded_boxes[i])
            box_mask = [0] * pad_len + [1] * box_len
            padded_box = [[self.token_pad_id] * 4] * pad_len + padded_boxes[i]
            padded_boxes[i] = padded_box
            box_masks.append(box_mask)
            box_counts.append([pad_len, max_boxes])

        return padded_boxes, box_masks, box_counts
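    # Worked example for process_boxes (illustrative): two images with 1 and 2 boxes.
    # After the sep row is appended the lengths are 2 and 3, so the first sequence is
    # left-padded with one pad row:
    #
    #   padded_boxes[0] = [[pad_id] * 4, box_a, [sep_id] * 4]
    #   box_masks[0]    = [0, 1, 1]
    #   box_counts[0]   = [1, 3]  # [pad_len, max_boxes]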
    def resize_img_and_boxes(self, img, boxes):
        orig_dim = img.size
        new_size = (self.size["width"], self.size["height"])
        img.thumbnail(new_size, Image.Resampling.LANCZOS)  # Shrink largest dimension to fit new size
        img = img.resize(new_size, Image.Resampling.LANCZOS)  # Stretch smaller dimension to fit new size

        img = np.asarray(img, dtype=np.uint8)

        width, height = orig_dim
        box_width, box_height = self.box_size["width"], self.box_size["height"]
        for box in boxes:
            # Rescale from original pixel coordinates to the 0-1024 box space,
            # then clamp to the valid range
            box[0] = box[0] / width * box_width
            box[1] = box[1] / height * box_height
            box[2] = box[2] / width * box_width
            box[3] = box[3] / height * box_height

            box[0] = max(box[0], 0)
            box[1] = max(box[1], 0)
            box[2] = min(box[2], box_width)
            box[3] = min(box[3], box_height)

        return img, boxes
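    # Coordinate example for resize_img_and_boxes (illustrative): with the default
    # 1024x1024 box space, a box [100, 200, 300, 400] on a 2048x1024 page becomes
    #
    #   [100 / 2048 * 1024, 200 / 1024 * 1024, 300 / 2048 * 1024, 400 / 1024 * 1024]
    #   = [50, 200, 150, 400]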
    def preprocess(
        self,
        images: ImageInput,
        boxes: List[List[int]],
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_pad: bool = None,
        random_padding: bool = False,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )
        new_images = []
        new_boxes = []
        for img, box in zip(images, boxes):
            if len(box) > self.max_boxes:
                raise ValueError(f"Too many boxes, max is {self.max_boxes}")
            img, box = self.resize_img_and_boxes(img, box)
            new_images.append(img)
            new_boxes.append(box)

        images = new_images
        boxes = new_boxes

        # Convert to numpy for later processing steps
        images = [np.array(image) for image in images]

        images = self.process_inner(images)
        boxes, box_mask, box_counts = self.process_boxes(boxes)
        data = {
            "pixel_values": images,
            "input_boxes": boxes,
            "input_boxes_mask": box_mask,
            "input_boxes_counts": box_counts,
        }
        return BatchFeature(data=data, tensor_type=return_tensors)
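

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). Builds a blank
    # RGB page with two hypothetical boxes and runs the full pipeline end to end.
    # Assumes surya is installed so settings resolves and ORDER_MODEL_CHECKPOINT is
    # reachable by from_pretrained.
    processor = load_processor()
    page = Image.new("RGB", (1280, 1792), "white")
    boxes = [[100, 200, 600, 260], [100, 300, 600, 360]]  # xyxy pixels on the page
    batch = processor.preprocess([page], [boxes], return_tensors="pt")
    print(batch["pixel_values"].shape)   # (1, 3, H, W) per settings.ORDER_IMAGE_SIZE
    print(batch["input_boxes"].shape)    # (1, n_boxes + 1, 4) -- sep row appended
    print(batch["input_boxes_counts"])   # tensor([[pad_len, max_boxes]]) per image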