File size: 23,025 Bytes
56df21f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
# coding=utf-8
# Copyright 2024 FIXME copyright?
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for Molmo"""
import pdb
from typing import List, Optional, Union, Mapping

import numpy as np
import torch
import torchvision.transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import convert_image_dtype

from transformers.image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ImageInput,
)
from transformers.processing_utils import ImagesKwargs
from transformers.image_processing_utils import BaseImageProcessor
from transformers.utils import logging


# Module-level logger obtained from the transformers logging utilities.
logger = logging.get_logger(__name__)


def resize_and_pad(
    image,
    desired_output_size,
    resize_method="torch-bilinear",
    pad_value=0,
    normalize=True,
    image_mean=OPENAI_CLIP_MEAN,
    image_std=OPENAI_CLIP_STD,
):
    """Resize an image to fit inside a target size, then pad it to that size.

    The aspect ratio is preserved: the image is scaled by the largest factor
    for which it still fits in `desired_output_size`, and the result is
    centred inside the output (an odd leftover pixel goes to the bottom/right).

    Args:
        image: array indexed as [height, width, channels].
        desired_output_size: (height, width) of the returned image.
        resize_method: "torch-bilinear" or "tensorflow".
        pad_value: constant written into the padded border of the image.
        normalize: accepted for interface compatibility; not used here.
        image_mean: accepted for interface compatibility; not used here.
        image_std: accepted for interface compatibility; not used here.

    Returns:
        A tuple `(image, image_mask)`. `image` is the resized, [0, 1]-clipped
        and padded array of shape [desired_height, desired_width, channels];
        `image_mask` is a boolean [desired_height, desired_width] array that is
        True on pixels coming from the resized image and False on padding.

    Raises:
        NotImplementedError: if `resize_method` is not one of the two
            supported values.
    """
    target_h, target_w = desired_output_size
    src_h, src_w = image.shape[:2]

    # The scale factor is deliberately computed in float32: the training code
    # did the same and the rounding below can (very rarely) differ otherwise.
    src_h_f32 = np.array(src_h, np.float32)
    src_w_f32 = np.array(src_w, np.float32)
    scale_y = np.array(target_h, np.float32) / src_h_f32
    scale_x = np.array(target_w, np.float32) / src_w_f32
    scale = min(scale_x, scale_y)
    new_h = int(src_h_f32 * scale)
    new_w = int(src_w_f32 * scale)

    if resize_method == "torch-bilinear":
        # torchvision works on channel-first tensors, so go HWC -> CHW and back
        chw = torch.from_numpy(image).permute(2, 0, 1)
        chw = convert_image_dtype(chw)  # resize in floating point, like the training code
        resizer = torchvision.transforms.Resize(
            [new_h, new_w], interpolation=InterpolationMode.BILINEAR, antialias=True
        )
        chw = resizer(chw).clip(0.0, 1.0)
        image = chw.permute(1, 2, 0).numpy()
    elif resize_method == "tensorflow":
        # This mirrors the resizing of the original training code; it can give
        # slightly different results from the torch path, so it is kept around.
        import tensorflow as tf

        as_float = tf.image.convert_image_dtype(tf.constant(image), dtype=tf.float32)
        resized = tf.image.resize(
            as_float,
            [new_h, new_w],
            method=tf.image.ResizeMethod.BILINEAR,
            antialias=True,
        )
        image = tf.clip_by_value(resized, 0.0, 1.0).numpy()
    else:
        raise NotImplementedError(resize_method)

    # Centre the resized image inside the target canvas
    extra_h = target_h - new_h
    extra_w = target_w - new_w
    top = extra_h // 2
    left = extra_w // 2
    spatial_padding = [[top, extra_h - top], [left, extra_w - left]]

    # The mask is built from the resized (not yet padded) image: ones where
    # there is image content, zero-padded (False) everywhere else.
    image_mask = np.pad(np.ones_like(image[:, :, 0], dtype=bool), spatial_padding)
    image = np.pad(image, spatial_padding + [[0, 0]], constant_values=pad_value)
    return image, image_mask


def select_tiling(h, w, patch_size, max_num_crops):
    """Choose how to tile an image of size [h, w] with crops of size `patch_size`.

    Every tiling `(rows, cols)` with `rows * cols <= max_num_crops` is a
    candidate. For each candidate we compute the factor by which the image
    would have to be scaled to fit inside `rows * patch_size` by
    `cols * patch_size`:

    * if every candidate requires downscaling, the one that downscales the
      least is chosen;
    * otherwise the candidate that needs the least upscaling (scale >= 1) is
      chosen, so the tiling fits the image as closely as possible.

    Ties are resolved in favour of the tiling with fewer crops, then fewer rows.

    Args:
        h: image height in pixels.
        w: image width in pixels.
        patch_size: side length in pixels covered by one crop.
        max_num_crops: maximum number of crops the tiling may contain.

    Returns:
        An int32 array `[rows, cols]` describing the selected tiling.

    Raises:
        ValueError: if `max_num_crops` is smaller than 1 (no tiling exists).
    """
    if max_num_crops < 1:
        raise ValueError(f"max_num_crops must be at least 1, got {max_num_crops}")

    # Enumerate only the (rows, cols) pairs that respect the crop budget and
    # sort them so argmin/argmax below favour smaller tilings on a tie.
    tilings = sorted(
        (
            (rows, cols)
            for rows in range(1, max_num_crops + 1)
            for cols in range(1, max_num_crops // rows + 1)
        ),
        key=lambda tiling: (tiling[0] * tiling[1], tiling[0]),
    )
    candidate_tilings = np.array(tilings, dtype=np.int32)  # [n_resolutions, 2]
    candidate_resolutions = candidate_tilings * patch_size  # [n_resolutions, 2]

    # How much we would need to scale the image to fit exactly in each tiling
    original_size = np.array([h, w], dtype=np.float32)  # [2]
    required_scale_d = candidate_resolutions.astype(np.float32) / original_size
    required_scale = np.min(required_scale_d, axis=-1, keepdims=True)  # [n_resolutions, 1]

    if np.all(required_scale < 1):
        # We are forced to downscale, so try to minimize the amount of downscaling
        ix = np.argmax(required_scale)
    else:
        # Ignore tilings that would downscale (give them a huge cost) and pick
        # the one requiring the least upscaling
        required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
        ix = np.argmin(required_scale)
    return candidate_tilings[ix]


def pixels_to_patches(array, patch_size):
    """Reshape an image of [h, w, c] -> [n_patches, pixels_per_patch].

    The image is cut into non-overlapping `patch_size` x `patch_size` squares.
    Patches are ordered row by row (left to right, then top to bottom) and the
    pixels inside a patch are flattened in (row, column, channel) order.

    Args:
        array: image array of shape [h, w, c]; `h` and `w` must be multiples
            of `patch_size` (otherwise the reshape raises `ValueError`).
        patch_size: side length of a patch in pixels.

    Returns:
        Array of shape [(h // patch_size) * (w // patch_size),
        patch_size * patch_size * c].
    """
    # Axis 0 is the height and axis 1 the width; mixing them up only goes
    # unnoticed for square images.
    h, w, c = array.shape
    h_patches = h // patch_size
    w_patches = w // patch_size
    # [h_patches, patch_row, w_patches, patch_col, c] -> bring the two patch
    # grid axes together, then flatten grid and per-patch pixels.
    array = np.reshape(array, [h_patches, patch_size, w_patches, patch_size, c])
    array = np.transpose(array, [0, 2, 1, 3, 4])
    return np.reshape(array, [h_patches * w_patches, patch_size * patch_size * c])


def batch_pixels_to_patches(array, patch_size):
    """Reshape images of [n_images, h, w, c] -> [n_images, n_patches, pixels_per_patch].

    A 3-D input of shape [n_images, h, w] (for example a stack of masks) is
    also accepted and is treated as having no channel axis.

    Within every image, patches are ordered row by row (left to right, then
    top to bottom) and the pixels of a patch are flattened in
    (row, column[, channel]) order.

    Args:
        array: array of shape [n_images, h, w] or [n_images, h, w, c]; `h` and
            `w` must be multiples of `patch_size`.
        patch_size: side length of a patch in pixels.

    Returns:
        Array of shape [n_images, (h // patch_size) * (w // patch_size),
        patch_size * patch_size (* c)].
    """
    # Axis 1 is the height and axis 2 the width; mixing them up only goes
    # unnoticed for square images.
    if len(array.shape) == 3:
        n_crops, h, w = array.shape
        channel_dims = []
    else:
        n_crops, h, w, c = array.shape
        channel_dims = [c]

    h_patches = h // patch_size
    w_patches = w // patch_size
    pixels_per_patch = patch_size * patch_size
    for dim in channel_dims:
        pixels_per_patch *= dim

    # [n, h_patches, patch_row, w_patches, patch_col(, c)] -> group the patch
    # grid axes together; any channel axis stays last.
    array = np.reshape(
        array, [n_crops, h_patches, patch_size, w_patches, patch_size] + channel_dims
    )
    axes = [0, 1, 3, 2, 4] + list(range(5, 5 + len(channel_dims)))
    array = np.transpose(array, axes)
    return np.reshape(array, [n_crops, h_patches * w_patches, pixels_per_patch])


class MolmoImagesKwargs(ImagesKwargs, total=False):
    """Optional per-call image keyword arguments for the Molmo image processor.

    The keys have the same names as the constructor arguments of
    `MolmoImageProcessor`; all of them may be omitted (`total=False`).
    """

    # Maximum number of crops the high-resolution tiling may use
    max_crops: Optional[int]
    # (left/top, right/bottom) number of patches discarded from overlapping crops
    overlap_margins: Optional[List[int]]
    # (height, width) in pixels of the global image; the first entry is also the crop side length
    base_image_input_size: Optional[List[int]]
    # Number of tokens per crop along the width
    image_token_length_w: Optional[int]
    # Number of tokens per crop along the height
    image_token_length_h: Optional[int]
    # Side length in pixels of one image patch
    image_patch_size: Optional[int]
    # Whether the per-patch padding mask should be returned
    image_padding_mask: Optional[bool]


class MolmoImageProcessor(BaseImageProcessor):
    """Preprocess images and multi-model inputs"""

    def __init__(
        self,
        max_crops: int = 12,
        overlap_margins: List[int] = (4, 4),
        base_image_input_size: List[int] = (336, 336),
        image_token_length_w: int = 12,
        image_token_length_h: int = 12,
        image_patch_size: int = 14,
        image_padding_mask: bool = True,
        do_normalize: bool = True,
        **kwargs,
    ):
        """Store the default preprocessing configuration.

        Each value is the fallback used when the corresponding argument is not
        supplied to `preprocess` / `multimodal_preprocess`.

        Args:
            max_crops: maximum number of crops used to tile an image.
            overlap_margins: (left/top, right/bottom) number of patches that are
                discarded from crops where they overlap their neighbours.
            base_image_input_size: (height, width) in pixels that the global
                image is resized to; the first entry is also used as the side
                length of every crop.
            image_token_length_w: number of tokens per crop along the width.
            image_token_length_h: number of tokens per crop along the height.
            image_patch_size: side length in pixels of one image patch.
            image_padding_mask: whether `multimodal_preprocess` returns the
                per-patch padding mask.
            do_normalize: whether `_normalize` applies the OpenAI CLIP
                mean/std normalization.
            **kwargs: forwarded unchanged to `BaseImageProcessor.__init__`.
        """
        super().__init__(**kwargs)
        self.max_crops = max_crops
        self.overlap_margins = overlap_margins
        self.base_image_input_size = base_image_input_size
        self.image_token_length_w = image_token_length_w
        self.image_token_length_h = image_token_length_h
        self.image_patch_size = image_patch_size
        self.image_padding_mask = image_padding_mask
        self.do_normalize = do_normalize

    def _normalize(self, image):
        """Apply the OpenAI CLIP mean/std normalization when `do_normalize` is set.

        The arithmetic is done in place: the array passed in is modified and
        that same array is returned. With `do_normalize` disabled the input is
        returned untouched.

        Args:
            image: array whose last axis holds the colour channels.

        Returns:
            The (possibly normalized) input array.
        """
        if not self.do_normalize:
            return image

        # Add leading axes so mean/std line up with the trailing channel axis
        mean = np.array(OPENAI_CLIP_MEAN, dtype=np.float32)[None, None, :]
        image -= mean
        std = np.array(OPENAI_CLIP_STD, dtype=np.float32)[None, None, :]
        image /= std
        return image

    def image_to_patches_and_tokens(
        self,
        image: ImageInput,
        image_patch_token_id: int,
        image_col_token_id: int,
        image_start_token_id: int,
        image_end_token_id: int,
        max_crops: Optional[int] = None,
        overlap_margins: Optional[List[int]] = None,
        base_image_input_size: Optional[Union[int, List[int]]] = None,
        image_token_length_w: Optional[int] = None,
        image_token_length_h: Optional[int] = None,
        image_patch_size: Optional[int] = None,
    ):
        """Turn one image into patch arrays plus the tokens that stand in for them.

        The image is resized and padded to fit a tiling of overlapping crops
        (chosen by `select_tiling`), each crop is split into patches, and a
        low-resolution "global" copy of the whole image is added in front of
        the crops.

        NOTE: the `Optional` configuration arguments are not defaulted here;
        they are used directly in arithmetic, so callers must pass concrete
        values (`preprocess` resolves them from the instance before calling).

        Args:
            image: image array indexed as [height, width, channels].
            image_patch_token_id: token id that marks a slot for patch features.
            image_col_token_id: token id appended after every row of patch tokens.
            image_start_token_id: token id emitted before an image's tokens.
            image_end_token_id: token id emitted after an image's tokens.
            max_crops: maximum number of crops in the tiling.
            overlap_margins: (left/top, right/bottom) patches dropped from crops.
            base_image_input_size: int or (height, width) in pixels.
            image_token_length_w: tokens per crop along the width.
            image_token_length_h: tokens per crop along the height.
            image_patch_size: side length in pixels of a patch.

        Returns:
            A tuple `(patches, joint, patch_ordering, img_mask)`:
            patches: [1 + n_crops, n_patches, pixels_per_patch], global image first.
            joint: 1-D token ids; the global image block followed by the block
                for the tiled high-resolution image.
            patch_ordering: 1-D array; the first `tokens_per_image` entries are
                0..tokens_per_image-1 (global image), the rest give, per crop
                token slot, its position in the token order or -1 if the slot
                is discarded.
            img_mask: [n_crops + 1, n_patches]; for each crop patch the fraction
                of its pixels that come from the image rather than padding,
                with one extra row of -1 appended at the end.
        """
        if isinstance(base_image_input_size, int):
            base_image_input_size = (base_image_input_size, base_image_input_size)

        base_image_input_d = image_patch_size
        tokens_per_image = image_token_length_w * image_token_length_h
        image_base_patch_w = base_image_input_size[1] // base_image_input_d
        image_base_patch_h = base_image_input_size[0] // base_image_input_d

        original_image_h, original_image_w = image.shape[:2]
        # Crops are square with the side taken from the first entry of base_image_input_size
        crop_size = base_image_input_size[0]

        # Discard this many patches from the (left/top, right/bottom) of crops
        left_margin, right_margin = overlap_margins
        assert left_margin % 2 == 0  # Required for compatibility with 2x2 pooling
        total_margin_pixels = base_image_input_d*(right_margin + left_margin)  # pixels removed per dim
        crop_patches = base_image_input_size[0] // base_image_input_d  # patches per crop dim
        crop_window_patches = crop_patches - (right_margin + left_margin)  # usable patches
        crop_window_size = crop_window_patches * base_image_input_d

        # Decide how to tile the image, to account for the overlap margins we compute the tiling
        # as if we had an image without the margins and were using a crop size without the margins
        tiling = select_tiling(
            original_image_h - total_margin_pixels,
            original_image_w - total_margin_pixels,
            crop_window_size,
            max_crops
        )
        # Resize/pad so the image is exactly covered by the tiling plus one set of margins
        src, img_mask = resize_and_pad(
            image,
            [tiling[0]*crop_window_size+total_margin_pixels, tiling[1]*crop_window_size+total_margin_pixels]
        )
        src = self._normalize(src)

        # Now we have to split the image into crops, while keeping track of how each patch in
        # each crop should be ordered in the global image; this requires a lot of tricky bookkeeping
        n_crops = tiling[0] * tiling[1]
        patches_arr = []
        mask_arr = []
        patch_ordering_arr = []

        # We assume 2x2 pooling, but can allow padding the right/bottom with extra
        # patches if the number of patches per side is not even
        assert (crop_patches+1)//2 == image_token_length_h
        assert (crop_patches+1)//2 == image_token_length_w
        on = 0  # running count of pooled tokens handed out so far
        on_patch = 0
        for i in range(tiling[0]):
            # Crops advance by the window size (not the crop size), so neighbours overlap
            y0 = i*crop_window_size
            # Pooled rows to skip at the top of this crop (none for the first row of crops)
            if i == 0:
                crop_y0 = 0
            else:
                crop_y0 = left_margin // 2

            # Patch rows kept from this crop: the margins are only kept at the image border
            crop_h = image_base_patch_h - (right_margin + left_margin)
            if i == 0:
                crop_h += left_margin
            if i == (tiling[0]-1):
                crop_h += right_margin
            for j in range(tiling[1]):
                x0 = j*crop_window_size
                # Pooled columns to skip at the left of this crop
                if j == 0:
                    crop_x0 = 0
                else:
                    crop_x0 = left_margin // 2

                # Patch columns kept from this crop, same rule as for the rows
                crop_w = image_base_patch_w - (right_margin + left_margin)
                if j == 0:
                    crop_w += left_margin
                if j == (tiling[1]-1):
                    crop_w += right_margin

                # Size of the kept region after 2x2 pooling (rounded up)
                pooled_w = (crop_w + 1) // 2
                pooled_h = (crop_h + 1) // 2
                after_padding_width = image_token_length_w - pooled_w - crop_x0
                after_padding_height = image_token_length_h - pooled_h - crop_y0
                # Grid of [image_token_length_h, image_token_length_w] slots for this crop:
                # kept slots get consecutive numbers, discarded slots get -1
                patch_ordering_arr.append(
                    np.pad(
                        np.reshape(
                            np.arange(on, on+pooled_h*pooled_w, dtype=np.int32),
                            (pooled_h, pooled_w)),
                        [[crop_y0, after_padding_height], [crop_x0, after_padding_width]],
                        constant_values=-1, mode='constant'
                    )
                )
                patches_arr.append(src[y0:y0+crop_size, x0:x0+crop_size])
                mask_arr.append(img_mask[y0:y0+crop_size, x0:x0+crop_size])

                on += pooled_h*pooled_w
                on_patch += 1
        patches = np.stack(patches_arr)
        patch_ordering = np.stack(patch_ordering_arr)
        img_mask = np.stack(mask_arr)

        # Switch to [n_crops, n_patches, pixels_per_patch] format
        image_layout_impatch_w, image_layout_impatch_h = tiling[0], tiling[1]

        patches = batch_pixels_to_patches(patches, image_patch_size)
        img_mask = batch_pixels_to_patches(img_mask, image_patch_size)
        # Per patch: fraction of pixels that are image content (mask is True on image pixels)
        img_mask = img_mask.astype(np.float32).mean(axis=-1)
        patch_ordering = np.reshape(patch_ordering, [-1])
        valid = patch_ordering >= 0

        # Patch order numbers the patches crop-by-crop, here we transpose
        # it to get left-to-right order
        patch_ordering_rh = np.reshape(
            patch_ordering,
            [tiling[0], tiling[1], image_token_length_h, image_token_length_w]
        )
        patch_ordering_rh = np.transpose(patch_ordering_rh, [0, 2, 1, 3])
        patch_ordering_rh = np.reshape(patch_ordering_rh, [-1])

        # The transpose will screw up which patches are masked, project the
        # new order into sparse structure of `patch_ordering` to fix it
        patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]

        # Now build the output tokens
        # h, w: size in patches of the whole tiled image (windows plus one set of margins)
        h = tiling[0] * crop_window_patches + (right_margin+left_margin)
        w = tiling[1] * crop_window_patches + (right_margin+left_margin)
        # One row of tokens: a patch token per pooled column, then a column token
        per_row = np.full(
            ((w+1)//2,),
            image_patch_token_id,
        )
        per_row = np.concatenate([per_row, [image_col_token_id]], 0)

        # Repeat the row once per pooled row and wrap it in start/end tokens
        joint = np.tile(per_row, [(h+1)//2])
        joint = [
            [image_start_token_id],
            joint,
            [image_end_token_id]
        ]

        # Finally do the same for the global image
        resized, _ = resize_and_pad(image, base_image_input_size)
        resized = self._normalize(resized)
        resized = pixels_to_patches(resized, image_patch_size)
        patches = np.concatenate([np.expand_dims(resized, 0), patches], 0)

        # Global image goes first, so the order of patches in previous crops gets increased
        patch_ordering = np.where(
            patch_ordering >= 0,
            patch_ordering + tokens_per_image,
            -1
        )
        patch_ordering = np.concatenate([np.arange(0, tokens_per_image), patch_ordering], 0)
        per_row = np.full(
            (image_token_length_w,),
            image_patch_token_id,
        )
        per_row = np.concatenate([per_row, [image_col_token_id]], 0)
        extra_tokens = np.tile(per_row, [image_token_length_h])
        # Token block of the global image is placed before the block of the tiled image
        joint = [
                    [image_start_token_id],
                    extra_tokens,
                    [image_end_token_id],
                ] + joint

        joint = np.concatenate(joint, 0)
        # Add one mask row filled with -1 so img_mask has as many rows as `patches`.
        # NOTE(review): the global image is prepended to `patches` but this row is
        # appended at the end of `img_mask` -- confirm the ordering is intended.
        img_mask = np.pad(img_mask, [[0, 1], [0, 0]], constant_values=-1)
        return patches, joint, patch_ordering, img_mask

    def build_image_input_idx(
        self,
        image_tokens: np.ndarray,
        patch_order: np.ndarray,
        image_patch_token_id: int,
        image_token_length_w: int,
        image_token_length_h: int,
    ):
        """Converts `patch_order` into a mapping of token_id -> patch_id

        Args:
            image_tokens: 1-D array of token ids for one image.
            patch_order: ordering array with negative entries marking slots
                that must be ignored, or None.
            image_patch_token_id: token id that marks a patch-feature slot.
            image_token_length_w: tokens per crop along the width.
            image_token_length_h: tokens per crop along the height.

        Returns:
            If `patch_order` is None, the 1-D int32 positions in `image_tokens`
            that hold `image_patch_token_id`. Otherwise an array reshaped to
            [-1, image_token_length_w * image_token_length_h] with one entry per
            slot of `patch_order`: a position in `image_tokens` for valid slots
            and -100 for slots whose order is negative.

        Raises:
            AssertionError: if the number of patch tokens differs from the
                number of non-negative entries in `patch_order`.
        """

        tokens_per_image = image_token_length_w * image_token_length_h

        # Indices to insert the patches
        image_input_idx = image_tokens == image_patch_token_id
        image_input_idx = np.nonzero(image_input_idx)[0].astype(np.int32)

        if patch_order is not None:
            n_tokens = image_input_idx.shape[0]
            patch_order = np.reshape(patch_order, [-1])
            n_patches = patch_order.shape[0]

            valid = patch_order >= 0
            n_valid_patches = valid.sum()
            # Every patch token must be matched by exactly one valid slot
            assert len(image_input_idx) == n_valid_patches

            # Invert the ordering: entry k holds the rank (among the valid slots)
            # of the slot whose order number is k
            sorted_patch_ixs = np.zeros([n_tokens], np.int32)
            sorted_patch_ixs[patch_order[valid]] = np.arange(n_valid_patches, dtype=np.int32)

            # Project the inverted mapping into same sparse structure
            sorted_patch_ixs_ex = np.full(np.shape(patch_order), -1)
            sorted_patch_ixs_ex[valid] = sorted_patch_ixs

            # Do the gather and then re-masked outputs that were masked in `sorted_patch_ixs`
            valid = (sorted_patch_ixs_ex >= 0).astype(np.int32)
            # Multiplying by `valid` turns the -1 placeholders into index 0 so the gather stays in range
            image_input_idx = image_input_idx[sorted_patch_ixs_ex*valid]
            # Keep gathered positions for valid slots, write -100 into the masked ones
            image_input_idx = image_input_idx*valid - 100*(1 - valid)
            image_input_idx = np.reshape(image_input_idx, [-1, tokens_per_image])
        return image_input_idx

    def preprocess(
        self,
        image: np.ndarray,
        image_patch_token_id: int,
        image_col_token_id: int,
        image_start_token_id: int,
        image_end_token_id: int,
        max_crops: Optional[int] = None,
        overlap_margins: Optional[List[int]] = None,
        base_image_input_size: Optional[Union[int, List[int]]] = None,
        image_token_length_w: Optional[int] = None,
        image_token_length_h: Optional[int] = None,
        image_patch_size: Optional[int] = None,
        **kwargs,
    ):
        """Preprocess a single image into crops, tokens and index arrays.

        Any configuration argument that is omitted (or otherwise falsy) falls
        back to the value stored on the processor. Extra keyword arguments are
        accepted and ignored.

        Args:
            image: image array indexed as [height, width, channels].
            image_patch_token_id: token id that marks a slot for patch features.
            image_col_token_id: token id appended after every row of patch tokens.
            image_start_token_id: token id emitted before an image's tokens.
            image_end_token_id: token id emitted after an image's tokens.
            max_crops: override for `self.max_crops`.
            overlap_margins: override for `self.overlap_margins`.
            base_image_input_size: override for `self.base_image_input_size`.
            image_token_length_w: override for `self.image_token_length_w`.
            image_token_length_h: override for `self.image_token_length_h`.
            image_patch_size: override for `self.image_patch_size`.

        Returns:
            A tuple `(crops, image_tokens, patch_idx, img_mask)`:
            crops: (n_crops, n_patches, patch_dim) patches of the global image
                and of every crop; `n_crops` varies between images.
            image_tokens: (n_tokens,) token ids standing in for the image; patch
                tokens mark where patch features are inserted.
            patch_idx: (n_crops, tokens_per_image) index into `image_tokens` for
                the pooled patch features; negative values mark features to
                exclude.
            img_mask: (n_crops, n_patches) per-patch mask values as produced by
                `image_to_patches_and_tokens`.
        """
        # Resolve every setting against the instance defaults
        if not max_crops:
            max_crops = self.max_crops
        if not overlap_margins:
            overlap_margins = self.overlap_margins
        if not base_image_input_size:
            base_image_input_size = self.base_image_input_size
        if not image_token_length_w:
            image_token_length_w = self.image_token_length_w
        if not image_token_length_h:
            image_token_length_h = self.image_token_length_h
        if not image_patch_size:
            image_patch_size = self.image_patch_size

        special_token_ids = (
            image_patch_token_id,
            image_col_token_id,
            image_start_token_id,
            image_end_token_id,
        )
        crops, image_tokens, patch_ordering, img_mask = self.image_to_patches_and_tokens(
            image,
            *special_token_ids,
            max_crops,
            overlap_margins,
            base_image_input_size,
            image_token_length_w,
            image_token_length_h,
            image_patch_size,
        )

        # Translate the patch ordering into positions inside `image_tokens`
        patch_idx = self.build_image_input_idx(
            image_tokens,
            patch_ordering,
            image_patch_token_id,
            image_token_length_w=image_token_length_w,
            image_token_length_h=image_token_length_h,
        )
        return crops, image_tokens, patch_idx, img_mask

    def multimodal_preprocess(
        self,
        images: np.ndarray,
        tokens: List[int],
        image_idx: np.ndarray,
        sequence_length: int,
        image_patch_token_id: int,
        image_col_token_id: int,
        image_start_token_id: int,
        image_end_token_id: int,
        **kwargs,
    ):
        """Merge images and text tokens into multi-modal features for the model

        :param images: images to use as input; `None` or an empty sequence
            means there is nothing to merge
        :param tokens: input text tokens
        :param image_idx: where to insert the images into `tokens`; the token at
            that position is replaced by the image tokens, and -1 inserts the
            image at the very start without replacing anything
        :param sequence_length: accepted for interface compatibility; not used
        :param image_patch_token_id: id to use of tokens that will contain image features
        :param image_col_token_id: token id for image column special tokens
        :param image_start_token_id: token id for image start special tokens
        :param image_end_token_id: token id for image end special tokens
        :param kwargs: override preprocessor default args; forwarded to
            `preprocess`. `image_padding_mask` is also read here.
        :return: dict with `input_ids` only when there are no images, otherwise
            `input_ids` (text and image tokens merged), `images` (patches of all
            images concatenated), `image_input_idx` (for every pooled patch, its
            position in `input_ids`, negative for patches to ignore) and, when
            the padding mask is enabled, `image_masks`.
        """
        if images is None or len(images) == 0:
            return {"input_ids": tokens}

        # Only fall back to the instance default when the caller did not supply
        # a value: an explicit `False` must be able to switch the mask off.
        image_padding_mask = kwargs.get("image_padding_mask")
        if image_padding_mask is None:
            image_padding_mask = self.image_padding_mask

        n = len(images)
        all_crops = []
        all_image_idx = []
        out_tokens = []
        all_crop_masks = []
        n_written = 0  # number of tokens already placed in the output sequence

        for ix in range(n):
            token_ix = image_idx[ix]
            crops, image_tokens, patch_idx, img_mask = self.preprocess(
                images[ix],
                image_patch_token_id,
                image_col_token_id,
                image_start_token_id,
                image_end_token_id,
                **kwargs,
            )

            if token_ix == -1:  # -1 is an image inserted at the very start
                start = 0
                token_ix = 0
                end = 0
            else:
                start = 0 if ix == 0 else image_idx[ix-1] + 1
                end = token_ix + 1

            text_before = tokens[start:token_ix]
            # Position in the *output* sequence where this image's tokens begin.
            # It must account for everything emitted so far, including the
            # tokens of earlier images, not just the position in `tokens`.
            image_offset = n_written + len(text_before)

            # Shift only real indices; negative entries are "ignore" markers
            # and must stay negative whatever the offset is.
            all_image_idx.append(
                np.where(patch_idx >= 0, patch_idx + image_offset, patch_idx)
            )
            all_crops.append(crops)
            out_tokens.append(text_before)
            out_tokens.append(image_tokens)
            n_written = image_offset + len(image_tokens)

            if ix == (n - 1):
                # Remaining text after the last image
                out_tokens.append(tokens[end:])
            if image_padding_mask:
                all_crop_masks.append(img_mask)

        out = {
            "input_ids": np.concatenate(out_tokens, 0),
            "images": np.concatenate(all_crops, 0),
            "image_input_idx": np.concatenate(all_image_idx, 0),
        }
        if image_padding_mask:
            out["image_masks"] = np.concatenate(all_crop_masks, 0)
        return out


# Register the processor with the transformers auto-class machinery at import
# time (presumably so it can be resolved as custom code -- see the
# `register_for_auto_class` documentation).
MolmoImageProcessor.register_for_auto_class()