from copy import deepcopy
from typing import Dict, Union, Optional, List

from transformers import DonutImageProcessor
from transformers.image_processing_utils import BatchFeature
from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, \
    valid_images
from transformers.utils import TensorType
import numpy as np
from PIL import Image
from surya.settings import settings


def load_processor(checkpoint=settings.ORDER_MODEL_CHECKPOINT):
    processor = OrderImageProcessor.from_pretrained(checkpoint)
    processor.size = settings.ORDER_IMAGE_SIZE
    box_size = 1024
    max_tokens = 256
    processor.token_sep_id = max_tokens + box_size + 1
    processor.token_pad_id = max_tokens + box_size + 2
    processor.max_boxes = settings.ORDER_MAX_BOXES - 1
    processor.box_size = {"height": box_size, "width": box_size}
    return processor
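
# A quick sketch of what load_processor derives (illustrative; the ids follow
# from max_tokens = 256 and box_size = 1024 above):
#
#     processor = load_processor()
#     processor.token_sep_id  # 256 + 1024 + 1 = 1281
#     processor.token_pad_id  # 256 + 1024 + 2 = 1282
#     processor.box_size      # {"height": 1024, "width": 1024}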


class OrderImageProcessor(DonutImageProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.patch_size = kwargs.get("patch_size", (4, 4))

    def process_inner(self, images: List[np.ndarray]):
        images = [img.transpose(2, 0, 1) for img in images]  # convert from HWC to CHW format

        assert images[0].shape[0] == 3  # RGB input images, channel dim now first

        # Convert to float32 for rescale/normalize
        images = [img.astype(np.float32) for img in images]

        # Rescale and normalize
        images = [
            self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]
        images = [
            self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
            for img in images
        ]

        return images
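
    # Shape flow through process_inner (illustrative): each input HWC uint8
    # array of shape (H, W, 3) becomes a float32 CHW array of shape (3, H, W),
    # rescaled by rescale_factor (1/255 in the default Donut config) and then
    # normalized per channel with image_mean and image_std.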

    def process_boxes(self, boxes):
        padded_boxes = []
        box_masks = []
        box_counts = []
        for b in boxes:
            # Append a separator row to mark the start of label predictions
            padded_b = deepcopy(b)
            padded_b.append([self.token_sep_id] * 4)
            padded_boxes.append(padded_b)

        # Left-pad each sequence to the longest in the batch (left padding for generation)
        max_boxes = max(len(b) for b in padded_boxes)
        for i in range(len(padded_boxes)):
            pad_len = max_boxes - len(padded_boxes[i])
            box_len = len(padded_boxes[i])
            box_mask = [0] * pad_len + [1] * box_len
            padded_box = [[self.token_pad_id] * 4] * pad_len + padded_boxes[i]
            padded_boxes[i] = padded_box
            box_masks.append(box_mask)
            box_counts.append([pad_len, max_boxes])

        return padded_boxes, box_masks, box_counts
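
    # Worked example for process_boxes (illustrative values; assumes
    # token_sep_id = 1281 and token_pad_id = 1282 as set in load_processor):
    #
    #     boxes  = [[[10, 10, 20, 20]],
    #               [[10, 10, 20, 20], [30, 30, 40, 40]]]
    #     padded -> [[[1282]*4, [10, 10, 20, 20], [1281]*4],
    #                [[10, 10, 20, 20], [30, 30, 40, 40], [1281]*4]]
    #     masks  -> [[0, 1, 1], [1, 1, 1]]   # 0 marks left padding
    #     counts -> [[1, 3], [0, 3]]         # [pad_len, max_boxes] per image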

    def resize_img_and_boxes(self, img, boxes):
        orig_dim = img.size
        new_size = (self.size["width"], self.size["height"])
        img.thumbnail(new_size, Image.Resampling.LANCZOS)  # Shrink largest dimension to fit new size
        img = img.resize(new_size, Image.Resampling.LANCZOS)  # Stretch smaller dimension to fit new size

        img = np.asarray(img, dtype=np.uint8)

        width, height = orig_dim
        box_width, box_height = self.box_size["width"], self.box_size["height"]
        for box in boxes:
            # Rescale from original image coordinates to the 0-1024 box space
            box[0] = box[0] / width * box_width
            box[1] = box[1] / height * box_height
            box[2] = box[2] / width * box_width
            box[3] = box[3] / height * box_height

            # Clamp coordinates to the box space bounds
            box[0] = max(box[0], 0)
            box[1] = max(box[1], 0)
            box[2] = min(box[2], box_width)
            box[3] = min(box[3], box_height)

        return img, boxes
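
    # Example of the coordinate mapping above (illustrative): on an 800x600
    # page with a 1024x1024 box space, the box [100, 50, 400, 300] maps to
    # [100/800*1024, 50/600*1024, 400/800*1024, 300/600*1024]
    # ≈ [128.0, 85.3, 512.0, 512.0], then is clamped to the 0-1024 bounds.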

    def preprocess(
        self,
        images: ImageInput,
        boxes: List[List[int]],
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_pad: bool = None,
        random_padding: bool = False,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ) -> BatchFeature:
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        new_images = []
        new_boxes = []
        for img, box in zip(images, boxes):
            if len(box) > self.max_boxes:
                raise ValueError(f"Too many boxes, max is {self.max_boxes}")
            img, box = self.resize_img_and_boxes(img, box)
            new_images.append(img)
            new_boxes.append(box)

        images = new_images
        boxes = new_boxes

        # Ensure numpy arrays for the steps below (resize_img_and_boxes already returns ndarrays)
        images = [np.array(image) for image in images]

        images = self.process_inner(images)
        boxes, box_mask, box_counts = self.process_boxes(boxes)
        data = {
            "pixel_values": images,
            "input_boxes": boxes,
            "input_boxes_mask": box_mask,
            "input_boxes_counts": box_counts,
        }
        return BatchFeature(data=data, tensor_type=return_tensors)
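

if __name__ == "__main__":
    # A minimal end-to-end sketch (illustrative only: the synthetic image and
    # box coordinates below are hypothetical inputs, not part of the library).
    image = Image.new("RGB", (800, 600), color="white")
    boxes = [[100, 50, 400, 300], [120, 320, 700, 560]]

    processor = load_processor()
    batch = processor.preprocess(images=[image], boxes=[boxes], return_tensors="pt")

    print(batch["pixel_values"].shape)      # (1, 3, H, W) per settings.ORDER_IMAGE_SIZE
    print(batch["input_boxes"].shape)       # (1, num_boxes + 1, 4) after the sep row
    print(batch["input_boxes_mask"].shape)  # (1, num_boxes + 1)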