besiktas committed on
Commit
d11ed8e
·
verified ·
1 Parent(s): f2492ec

Upload processor

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<image>": 257152
3
+ }
paligemma_model.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from functools import cache
3
+
4
+ import torch
5
+ from PIL.Image import Image
6
+ from transformers import PaliGemmaConfig as HFPaliGemmaConfig
7
+ from transformers import PaliGemmaForConditionalGeneration as HFPaliGemmaForConditionalGeneration
8
+ from transformers import PaliGemmaProcessor as HFPaliGemmaProcessor
9
+ from transformers.utils import TensorType
10
+
11
+ from pretrain_mm.processor.processor import ProcessorMixin, TextProcessorMixin
12
+ from pretrain_mm.processor.tokenizer_constants import SetConstants, TokenizerConstants
13
+ from pretrain_mm.utils.token_tag_utils import TagType, box_pattern, point_pattern, segment_str
14
+
15
+
16
+ """
17
+ Note: to patch this model at the moment, the casting in PaliGemmaForConditionalGeneration._merge_input_ids_with_image_features needs to be fixed.
18
+
19
+ I am not sure exactly which devices trigger the issue or where it arises, so I ended up casting several things, because otherwise
20
+ debugging/fixing it on borah is a pain.
21
+
22
+ if edited in place on borah, location @
23
+ `/bsuhome/gannett/mambaforge/envs/pt/lib/python3.11/site-packages/transformers/models/paligemma/modeling_paligemma.py`
24
+ then remove the cast
25
+
26
+ spaces with actual working implementation of seg/detect/etc
27
+
28
+
29
+ - https://huggingface.co/spaces/big-vision/paligemma-hf
30
+ - this one has the VAE for decoding to mask
31
+ - https://huggingface.co/spaces/big-vision/paligemma
32
+ - this one uses the big-vision stuff which is not ideal
33
+
34
+
35
+ models:
36
+ https://huggingface.co/google/paligemma-3b-ft-docvqa-896
37
+ https://huggingface.co/google/paligemma-3b-ft-ocrvqa-896
38
+ """
39
+
40
+
41
# Checkpoint this module is written against.
MODEL_ID: str = "google/paligemma-3b-ft-docvqa-896"
# Size of the coordinate grid that loc/seg values are scaled onto.
PROCESSOR_IMAGE_MAX_SIZE: int = 1024

# One location token: "<loc" followed by exactly four digits (captured) and ">".
_r_loc = r"<loc(\d{4})>"
re_loc = re.compile(_r_loc)
# A point is two consecutive location tokens, a box is four.
re_loc_point = re.compile(_r_loc + _r_loc)
re_loc_box = re.compile("".join([_r_loc] * 4))

# One segmentation token: "<seg" followed by exactly three digits (captured) and ">".
re_seg = re.compile(r"<seg(\d{3})>")
50
+
51
+
52
def _scale_val(val: int, dim_scale_factor: float, max_size: int = PROCESSOR_IMAGE_MAX_SIZE):
    """Scale a pixel coordinate onto the location-token grid.

    Args:
        val: Coordinate in pixels along one image axis.
        dim_scale_factor: Ratio ``max_size / image_dim`` for that axis.
        max_size: Number of bins in the grid; valid results are
            ``0 .. max_size - 1``.

    Returns:
        The truncated, scaled coordinate clamped into ``[0, max_size - 1]``.
    """
    scaled = int(val * dim_scale_factor)
    # Location tokens are indexed 0..max_size-1, so a coordinate lying exactly
    # on the far image edge must map to the last bin rather than to max_size
    # (which would render as an out-of-vocabulary "<loc1024>"). Negative
    # inputs are pinned to the first bin for the same reason.
    return max(0, min(scaled, max_size - 1))
54
+
55
+
56
+ def _make_seg_text(val: int, tag: str = "loc", digits: int = 4):
57
+ return f"<{tag}{val:0>{digits}}>"
58
+
59
+
60
@cache
def _make_scale_dim_func(image_dim: int, max_size: int = PROCESSOR_IMAGE_MAX_SIZE):
    """Build (and memoize) a converter from grid values back to pixels.

    ``image_dim`` is the image height or width. The returned callable accepts
    any number of grid values (ints or digit strings) and returns a list with
    each one mapped to ``round(value / max_size * image_dim)``.
    """

    def func(*vals: int):
        pixels = []
        for raw in vals:
            fraction = int(raw) / max_size
            pixels.append(round(fraction * image_dim))
        return pixels

    return func
67
+
68
+
69
class PaliGemmaConfig(HFPaliGemmaConfig):
    """Project-local alias of the Hugging Face PaliGemma config; adds nothing."""
71
+
72
+
73
class PaliGemmaConstantsClass(TokenizerConstants):
    # Token/tag strings used by the PaliGemma processor in this module.

    # from the processor.tokenizer
    bos_token: str = "<bos>"
    eos_token: str = "<eos>"
    image_placeholder_token: str = "<image>"

    # opening/closing text that wraps bounding-box coordinates
    repr_bbox_open_text: str = "<box>"
    repr_bbox_close_text: str = "</box>"
    # opening/closing text that wraps point coordinates
    repr_point_open_text: str = "<point>"
    repr_point_close_text: str = "</point>"


# Module-level instance handed to the SetConstants decorator below.
PaliGemmaConstants = PaliGemmaConstantsClass()
86
+
87
+
88
class PaliGemmaForConditionalGeneration(HFPaliGemmaForConditionalGeneration):
    """Project-local alias of the Hugging Face PaliGemma model; adds nothing."""
90
+
91
+
92
@SetConstants(PaliGemmaConstants)
class PaliGemmaProcessor(HFPaliGemmaProcessor, ProcessorMixin, TextProcessorMixin):
    """PaliGemma processor with point/box tag <-> location-token conversion.

    Before delegating to the Hugging Face processor, ``__call__`` rewrites
    point/box tags found in the prompt and suffix into ``<locNNNN>`` tokens
    scaled to the location grid. ``handle_token_loc_seg`` performs the reverse
    mapping on generated text.
    """

    constants: PaliGemmaConstantsClass

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Handle to the parent ``__call__``; retained for existing users of ``_call``.
        self._call = super().__call__

    def __call__(
        self,
        text=None,
        images=None,
        tokenize_newline_separately=True,
        padding=False,
        truncation=None,
        max_length=None,
        return_tensors=TensorType.PYTORCH,
        do_resize=None,
        do_normalize=None,
        image_mean=None,
        image_std=None,
        data_format="channels_first",
        input_data_format=None,
        resample: "PILImageResampling" = None,  # noqa: F821 # type: ignore
        do_convert_rgb: bool = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_rescale: bool = None,
        suffix=None,
        extra: dict | bool = False,
        **kwargs,
    ):
        """Convert tags in ``text``/``suffix`` to loc tokens, then run the parent processor.

        Args:
            text: Prompt string, or a list/tuple of prompt strings.
            images: Image or list of images; also used to scale coordinates.
            suffix: Target text; falls back to ``kwargs["label"]`` when falsy.
            extra: Forwarded to ``create_attachable`` together with the batch.
            **kwargs: Only ``label`` is read here.

        All remaining keyword arguments are forwarded unchanged to the parent
        processor's ``__call__``.

        Returns:
            The result of ``create_attachable(batch, extra)(...)`` applied to
            the parent processor's output.
        """
        suffix = suffix or kwargs.get("label", None)
        if text:
            text = self._preprocess_text_or_batch(text, images)

        if suffix:
            suffix = self._preprocess_text_or_batch(suffix, images)

        batch = super().__call__(
            text=text,
            images=images,
            tokenize_newline_separately=tokenize_newline_separately,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            do_resize=do_resize,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            data_format=data_format,
            input_data_format=input_data_format,
            resample=resample,
            do_convert_rgb=do_convert_rgb,
            do_thumbnail=do_thumbnail,
            do_align_long_axis=do_align_long_axis,
            do_rescale=do_rescale,
            suffix=suffix,
        )

        batch = self.create_attachable(batch, extra)(text=text, images=images, label=suffix)

        return batch

    def _preprocess_text_or_batch(self, text, images):
        """Apply ``preprocess_text`` to one string or to each string of a batch."""
        if not isinstance(text, (list, tuple)):
            return self.preprocess_text(text, images)

        # Pair every text with its own image when the batch sizes line up;
        # otherwise hand over ``images`` as given (preprocess_text then uses
        # the first image of a list).
        if isinstance(images, (list, tuple)) and len(images) == len(text):
            per_text_images = images
        else:
            per_text_images = [images] * len(text)
        return [self.preprocess_text(item, image) for item, image in zip(text, per_text_images)]

    def decode(self, outputs: torch.Tensor, do_post: bool = True, **kwargs) -> str:
        """Decode token ids to text with the tokenizer (specific to PaliGemma).

        ``do_post`` is accepted for interface compatibility but is not used;
        ``kwargs`` are forwarded to ``tokenizer.decode``.
        """
        return self.tokenizer.decode(outputs, **kwargs)

    @staticmethod
    def _image_width_height(image) -> tuple[int, int]:
        """Return ``(width, height)`` of a PIL image or a torch tensor."""
        if isinstance(image, torch.Tensor):
            # ``Tensor.size`` is a method, not a (width, height) tuple, so the
            # dimensions have to be read from the shape instead.
            if image.ndim < 2:
                raise ValueError(f"expected an image tensor with at least 2 dims, got shape {tuple(image.shape)}")
            # NOTE(review): assumes channels-first (..., H, W) unless only the
            # last dim looks like a channel count - confirm against callers.
            channels_last = (
                image.ndim >= 3 and image.shape[-1] in (1, 3, 4) and image.shape[-3] not in (1, 3, 4)
            )
            if channels_last:
                height, width = image.shape[-3], image.shape[-2]
            else:
                height, width = image.shape[-2], image.shape[-1]
            return int(width), int(height)

        # PIL images expose ``size`` as (width, height).
        image_width, image_height = image.size
        return image_width, image_height

    def preprocess_text(
        self,
        text: str,
        images: list[torch.Tensor | Image] | Image = None,
        max_size: int = PROCESSOR_IMAGE_MAX_SIZE,
    ) -> str:
        """Replace point/box tags in ``text`` with scaled location tokens.

        Args:
            text: Text that may contain point/box tags recognised by ``segment_str``.
            images: Image (or list of images, of which only the first is used)
                whose dimensions define the scale.
            max_size: Size of the location grid the coordinates are scaled to.

        Returns:
            ``text`` with every point rewritten as ``<locY><locX> point`` and
            every box as ``<locY1><locX1><locY2><locX2> box``.

        Raises:
            ValueError: If a tag is present but no image was supplied, or the
                tag type is neither a point nor a box.
        """
        # not sure what to do for multiple images if need to scale
        if isinstance(images, (list, tuple)):
            images = images[0] if images else None

        width_scale = height_scale = None
        if images is not None:
            image_width, image_height = self._image_width_height(images)
            height_scale = max_size / image_height
            width_scale = max_size / image_width

        pieces = []
        for seg, seg_type in segment_str(text, box_pattern=box_pattern, point_pattern=point_pattern):
            if not seg_type:
                # plain text between tags is passed through untouched
                pieces.append(seg)
                continue

            if width_scale is None:
                raise ValueError("an image is required to scale point/box coordinates found in text")

            if seg_type == TagType.POINT:
                x, y = map(int, seg)
                scaled_x = _make_seg_text(_scale_val(x, width_scale, max_size))
                scaled_y = _make_seg_text(_scale_val(y, height_scale, max_size))
                # model uses y, x in examples
                pieces.append(f"{scaled_y}{scaled_x} point")
            elif seg_type == TagType.BOX:
                x1, y1, x2, y2 = map(int, seg)
                scaled_x1 = _make_seg_text(_scale_val(x1, width_scale, max_size))
                scaled_y1 = _make_seg_text(_scale_val(y1, height_scale, max_size))
                scaled_x2 = _make_seg_text(_scale_val(x2, width_scale, max_size))
                scaled_y2 = _make_seg_text(_scale_val(y2, height_scale, max_size))
                # they do y1, x1, y2, x2 in examples
                pieces.append(f"{scaled_y1}{scaled_x1}{scaled_y2}{scaled_x2} box")
            else:
                raise ValueError(f"unsupported tag type: {seg_type!r}")

        return "".join(pieces)

    def handle_token_loc_seg(self, text: str, image_height: int, image_width: int):
        """Rewrite runs of ``<locNNNN>`` tokens in ``text`` as pixel-space tags.

        Four consecutive location tokens become ``<box>y1, x1, y2, x2</box>``;
        of what remains, two consecutive tokens become ``<point>y, x</point>``.
        Values are mapped back to pixels using ``image_height`` for the first,
        third, ... token of a run and ``image_width`` for the others.

        NOTE(review): the emitted order is y-first, whereas ``preprocess_text``
        reads tags as x-first - confirm this asymmetry is intended.
        """
        _scale_height = _make_scale_dim_func(image_height)
        _scale_width = _make_scale_dim_func(image_width)

        def _tagged(tag_open: str, tag_close: str, *vals) -> str:
            return f"{tag_open}{', '.join(map(str, vals))}{tag_close}"

        def _box_repl(loc_match) -> str:
            y1, x1, y2, x2 = loc_match.groups()
            (y1, y2), (x1, x2) = _scale_height(y1, y2), _scale_width(x1, x2)
            return _tagged("<box>", "</box>", y1, x1, y2, x2)

        def _point_repl(loc_match) -> str:
            y1, x1 = loc_match.groups()
            (y1,), (x1,) = _scale_height(y1), _scale_width(x1)
            return _tagged("<point>", "</point>", y1, x1)

        # ``sub`` scans the whole string; anchoring at position 0 (``match``)
        # would skip tokens that follow any leading text. Boxes go first so a
        # run of four tokens is not consumed as two points.
        text = re_loc_box.sub(_box_repl, text)
        text = re_loc_point.sub(_point_repl, text)

        return text
preprocessor_config.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_valid_processor_keys": [
3
+ "images",
4
+ "do_resize",
5
+ "size",
6
+ "resample",
7
+ "do_rescale",
8
+ "rescale_factor",
9
+ "do_normalize",
10
+ "image_mean",
11
+ "image_std",
12
+ "return_tensors",
13
+ "data_format",
14
+ "input_data_format",
15
+ "do_convert_rgb"
16
+ ],
17
+ "auto_map": {
18
+ "AutoProcessor": "paligemma_model.PaliGemmaProcessor"
19
+ },
20
+ "do_convert_rgb": null,
21
+ "do_normalize": true,
22
+ "do_rescale": true,
23
+ "do_resize": true,
24
+ "image_mean": [
25
+ 0.5,
26
+ 0.5,
27
+ 0.5
28
+ ],
29
+ "image_processor_type": "SiglipImageProcessor",
30
+ "image_seq_length": 4096,
31
+ "image_std": [
32
+ 0.5,
33
+ 0.5,
34
+ 0.5
35
+ ],
36
+ "processor_class": "PaliGemmaProcessor",
37
+ "resample": 3,
38
+ "rescale_factor": 0.00392156862745098,
39
+ "size": {
40
+ "height": 896,
41
+ "width": 896
42
+ }
43
+ }
processor_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "paligemma_model.PaliGemmaProcessor"
4
+ },
5
+ "processor_class": "PaliGemmaProcessor"
6
+ }
special_tokens_map.json ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ {
4
+ "content": "<image>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false
9
+ }
10
+ ],
11
+ "bos_token": {
12
+ "content": "<bos>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false
17
+ },
18
+ "eos_token": {
19
+ "content": "<eos>",
20
+ "lstrip": false,
21
+ "normalized": false,
22
+ "rstrip": false,
23
+ "single_word": false
24
+ },
25
+ "pad_token": {
26
+ "content": "<pad>",
27
+ "lstrip": false,
28
+ "normalized": false,
29
+ "rstrip": false,
30
+ "single_word": false
31
+ },
32
+ "unk_token": {
33
+ "content": "<unk>",
34
+ "lstrip": false,
35
+ "normalized": false,
36
+ "rstrip": false,
37
+ "single_word": false
38
+ }
39
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8104de0e0f9b8ab923ac66b31bee4ae132edf05863545fa4a3b69b4774117ae
3
+ size 17763304
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8986bb4f423f07f8c7f70d0dbe3526fb2316056c17bae71b1ea975e77a168fc6
3
+ size 4264023
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff