|
from typing import Dict, Union, Optional, List, Iterable

import cv2
import numpy as np
from transformers import DonutImageProcessor, DonutProcessor
from transformers.image_processing_utils import BatchFeature
from transformers.image_transforms import pad, normalize
from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, get_image_size
from transformers.utils import TensorType

from surya.model.recognition.tokenizer import Byt5LangTokenizer
from surya.settings import settings
|
|
def load_processor():
    processor = SuryaProcessor()
    processor.image_processor.train = False
    processor.image_processor.max_size = settings.RECOGNITION_IMAGE_SIZE
    processor.tokenizer.model_max_length = settings.RECOGNITION_MAX_TOKENS
    return processor
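
# Hedged usage sketch: the attribute names come from the code above; the
# settings values are project-specific and not shown here.
#
#     processor = load_processor()
#     processor.image_processor.max_size        # settings.RECOGNITION_IMAGE_SIZE
#     processor.tokenizer.model_max_length      # settings.RECOGNITION_MAX_TOKENS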
|
|
|
|
|
class SuryaImageProcessor(DonutImageProcessor):
    def __init__(self, *args, max_size=None, train=False, **kwargs):
        super().__init__(*args, **kwargs)

        self.patch_size = kwargs.get("patch_size", (4, 4))
        self.max_size = max_size
        self.train = train
|
    @classmethod
    def numpy_resize(cls, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):
        height, width = image.shape[:2]
        max_width, max_height = size["width"], size["height"]

        # Image already fits the target box along its binding axis: skip the
        # resize and just move channels first (HWC -> CHW).
        if (height == max_height and width <= max_width) or (width == max_width and height <= max_height):
            return image.transpose(2, 0, 1)

        # Scale to fit within (max_width, max_height), preserving aspect ratio.
        scale = min(max_width / width, max_height / height)
        new_width = int(width * scale)
        new_height = int(height * scale)

        resized_image = cv2.resize(image, (new_width, new_height), interpolation=interpolation)
        return resized_image.transpose(2, 0, 1)  # HWC -> CHW
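
    # Worked example (illustrative numbers): with a 100x400 HxW input and
    # size = {"height": 256, "width": 896}, scale = min(896/400, 256/100) = 2.24,
    # so the image is resized to 224x896 and returned as a (3, 224, 896) array.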
|
|
|
    def process_inner(self, images: List[np.ndarray]):
        # Expects a list of HWC RGB arrays.
        assert images[0].shape[2] == 3

        # Rotate images whose long axis disagrees with the target's long axis.
        images = [SuryaImageProcessor.align_long_axis(image, size=self.max_size, input_data_format=ChannelDimension.LAST) for image in images]

        for img in images:
            assert img.shape[1] >= img.shape[0]  # width >= height after alignment

        # Resize to fit inside max_size; numpy_resize also moves channels first.
        # self.resample is a PIL resampling constant whose integer values do not
        # match cv2 interpolation flags, so the cv2 default (INTER_LANCZOS4) is
        # used rather than forwarding it.
        images = [SuryaImageProcessor.numpy_resize(img, self.max_size) for img in images]
        assert images[0].shape[0] == 3  # channels-first from here on

        images = [img.astype(np.float32) for img in images]

        # Pad each image out to exactly max_size.
        max_size = self.max_size
        images = [
            SuryaImageProcessor.pad_image(
                image=image,
                size=max_size,
                input_data_format=ChannelDimension.FIRST,
                pad_value=settings.RECOGNITION_PAD_VALUE
            )
            for image in images
        ]

        # Rescale to [0, 1], then normalize with the processor's mean/std.
        for idx in range(len(images)):
            images[idx] = images[idx] * self.rescale_factor
        images = [
            SuryaImageProcessor.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]

        return images
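
    # Shape walkthrough (hedged; max_size = {"height": 256, "width": 896} is
    # illustrative): a (100, 400, 3) uint8 input stays landscape after
    # alignment, becomes (3, 224, 896) after resizing, is padded to
    # (3, 256, 896), then rescaled (typically by 1/255) and normalized, so all
    # outputs in the list share one shape.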
|
|
|
    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_pad: bool = None,
        random_padding: bool = False,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        images = make_list_of_images(images)

        # The per-step flags above are accepted for signature compatibility with
        # DonutImageProcessor but ignored; the fixed pipeline in process_inner
        # always runs.
        images = [np.array(img) for img in images]
        images = self.process_inner(images)

        data = {"pixel_values": images}
        return BatchFeature(data=data, tensor_type=return_tensors)
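
    # Example call (hedged; `proc` is a SuryaImageProcessor with max_size set,
    # e.g. via load_processor(), and `pil_image` is a PIL RGB image):
    #
    #     features = proc.preprocess(pil_image, return_tensors="pt")
    #     features["pixel_values"].shape  # (1, 3, max_height, max_width)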
|
|
|
    @classmethod
    def pad_image(
        cls,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        pad_value: float = 0.0,
    ) -> np.ndarray:
        output_height, output_width = size["height"], size["width"]
        input_height, input_width = get_image_size(image, channel_dim=input_data_format)

        delta_width = output_width - input_width
        delta_height = output_height - input_height
        assert delta_width >= 0 and delta_height >= 0  # image must already fit inside the target size

        # Center the image, pushing any odd pixel of padding to the bottom/right.
        pad_top = delta_height // 2
        pad_left = delta_width // 2
        pad_bottom = delta_height - pad_top
        pad_right = delta_width - pad_left

        padding = ((pad_top, pad_bottom), (pad_left, pad_right))
        return pad(image, padding, data_format=data_format, input_data_format=input_data_format, constant_values=pad_value)
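
    # Padding arithmetic (illustrative): growing a 224-pixel height to 257
    # gives delta_height = 33, so pad_top = 16 and pad_bottom = 17; the odd
    # pixel always lands on the bottom/right edge.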
|
|
|
    @classmethod
    def align_long_axis(
        cls,
        image: np.ndarray,
        size: Dict[str, int],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        input_height, input_width = image.shape[:2]
        output_height, output_width = size["height"], size["width"]

        # Rotate 270 degrees when the input's long axis disagrees with the
        # target's (portrait input headed for a landscape canvas, or vice versa).
        if (output_width < output_height and input_width > input_height) or (
            output_width > output_height and input_width < input_height
        ):
            image = np.rot90(image, 3)

        return image
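
    # Example (illustrative): a 400x100 HxW portrait crop aimed at a 256x896
    # landscape canvas satisfies the second condition (896 > 256 and
    # 100 < 400), so it is rotated to 100x400 via np.rot90(image, 3).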
|
|
|
    @classmethod
    def normalize(
        cls,
        image: np.ndarray,
        mean: Union[float, Iterable[float]],
        std: Union[float, Iterable[float]],
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> np.ndarray:
        return normalize(
            image, mean=mean, std=std, data_format=data_format, input_data_format=input_data_format, **kwargs
        )
|
|
class SuryaProcessor(DonutProcessor):
    def __init__(self, image_processor=None, tokenizer=None, train=False, **kwargs):
        # Fall back to the checkpoint processor and ByT5 tokenizer only when
        # the caller supplies none.
        if image_processor is None:
            image_processor = SuryaImageProcessor.from_pretrained(settings.RECOGNITION_MODEL_CHECKPOINT)
        if tokenizer is None:
            tokenizer = Byt5LangTokenizer()

        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor
        self._in_target_context_manager = False
|
    def __call__(self, *args, **kwargs):
        images = kwargs.pop("images", None)
        text = kwargs.pop("text", None)
        lang = kwargs.pop("lang", None)

        # Images may also be passed positionally as the first argument.
        if len(args) > 0:
            images = args[0]
            args = args[1:]

        if images is None and text is None:
            raise ValueError("You need to specify either an `images` or `text` input to process.")

        if images is not None:
            inputs = self.image_processor(images, *args, **kwargs)

        if text is not None:
            encodings = self.tokenizer(text, lang, **kwargs)

        if text is None:
            return inputs
        elif images is None:
            return encodings
        else:
            # Attach the tokenized text and language codes to the image features.
            inputs["labels"] = encodings["input_ids"]
            inputs["langs"] = encodings["langs"]
            return inputs
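
# End-to-end sketch (hedged): assumes `pil_image` is a PIL RGB image and that
# Byt5LangTokenizer takes a list of texts plus per-text language lists, as
# __call__ above implies.
#
#     processor = load_processor()
#     batch = processor(images=[pil_image], text=[""], lang=[["en"]])
#     batch["pixel_values"]  # padded, normalized (3, H, W) arrays
#     batch["labels"]        # token ids for the text
#     batch["langs"]         # language ids aligned with labels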