|
import re
from functools import cache

import torch
from PIL.Image import Image
from transformers import PaliGemmaConfig as HFPaliGemmaConfig
from transformers import PaliGemmaForConditionalGeneration as HFPaliGemmaForConditionalGeneration
from transformers import PaliGemmaProcessor as HFPaliGemmaProcessor
from transformers.image_utils import PILImageResampling
from transformers.utils import TensorType

from pretrain_mm.processor.processor import ProcessorMixin, TextProcessorMixin
from pretrain_mm.processor.tokenizer_constants import SetConstants, TokenizerConstants
from pretrain_mm.utils.token_tag_utils import TagType, box_pattern, point_pattern, segment_str
|
|
|
|
|
""" |
|
note, for patching this model atm need to fix casting in PaliGemmaProcessor._merge_input_ids_with_image_features |
|
|
|
I am not sure exactly which devices/where the issue arises from so ended up just casting multiple things as otherwise |
|
debugging/fixing on borah is a pain |
|
|
|
if edited in place on borah, location @ |
|
`/bsuhome/gannett/mambaforge/envs/pt/lib/python3.11/site-packages/transformers/models/paligemma/modeling_paligemma.py` |
|
then remove the cast |
|
|
|
spaces with actual working implementation of seg/detect/etc |
|
|
|
|
|
- https://huggingface.co/spaces/big-vision/paligemma-hf |
|
- this one has the VAE for decoding to mask |
|
- https://huggingface.co/spaces/big-vision/paligemma |
|
- this one uses the big-vision stuff which is not ideal |
|
|
|
|
|
models: |
|
https://huggingface.co/google/paligemma-3b-ft-docvqa-896 |
|
https://huggingface.co/google/paligemma-3b-ft-ocrvqa-896 |
|
""" |
|
|
|
|
|
MODEL_ID: str = "google/paligemma-3b-ft-docvqa-896"
PROCESSOR_IMAGE_MAX_SIZE: int = 1024

_r_loc = r"<loc(\d{4})>"
re_loc = re.compile(_r_loc)
re_loc_point = re.compile(_r_loc * 2)
re_loc_box = re.compile(_r_loc * 4)

re_seg = re.compile(r"<seg(\d{3})>")
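# PaliGemma's extended vocabulary includes 1024 location tokens (<loc0000>..<loc1023>),
# which bin coordinates onto a 1024-wide grid, and 128 segmentation tokens
# (<seg000>..<seg127>) that index the mask codebook.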
|
|
|
|
|
def _scale_val(val: int, dim_scale_factor: float, max_size: int = PROCESSOR_IMAGE_MAX_SIZE) -> int:
    """Scale a pixel coordinate into token space, clamped so it stays a valid loc bin."""
    # clamp to max_size - 1: loc tokens only run <loc0000>..<loc1023> on a 1024-bin grid
    return min(int(val * dim_scale_factor), max_size - 1)
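# e.g. for a 448 px wide image, x = 100 maps to bin int(100 * (1024 / 448)) == 228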
|
|
|
|
|
def _make_seg_text(val: int, tag: str = "loc", digits: int = 4) -> str:
    """Format a value as a special token, e.g. _make_seg_text(37) -> "<loc0037>"."""
    return f"<{tag}{val:0>{digits}}>"
|
|
|
|
|
@cache
def _make_scale_dim_func(image_dim: int, max_size: int = PROCESSOR_IMAGE_MAX_SIZE):
    """Make a (cached) function that maps loc-token values back to pixel coordinates along one image dimension."""

    def func(*vals: int):
        return [round((int(val) / max_size) * image_dim) for val in vals]

    return func
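# e.g. _make_scale_dim_func(448)("0512") -> [224], the inverse of the scaling done in preprocess_text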
|
|
|
|
|
class PaliGemmaConfig(HFPaliGemmaConfig):
    pass
|
|
|
|
|
class PaliGemmaConstantsClass(TokenizerConstants):
    bos_token: str = "<bos>"
    eos_token: str = "<eos>"
    image_placeholder_token: str = "<image>"

    repr_bbox_open_text: str = "<box>"
    repr_bbox_close_text: str = "</box>"
    repr_point_open_text: str = "<point>"
    repr_point_close_text: str = "</point>"


PaliGemmaConstants = PaliGemmaConstantsClass()
|
|
|
|
|
class PaliGemmaForConditionalGeneration(HFPaliGemmaForConditionalGeneration):
    pass
|
|
|
|
|
@SetConstants(PaliGemmaConstants)
class PaliGemmaProcessor(HFPaliGemmaProcessor, ProcessorMixin, TextProcessorMixin):
    constants: PaliGemmaConstantsClass

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # keep a reference to the parent class's __call__ so it can be invoked without the tag preprocessing below
        self._call = super().__call__
|
|
|
    def __call__(
        self,
        text=None,
        images=None,
        tokenize_newline_separately=True,
        padding=False,
        truncation=None,
        max_length=None,
        return_tensors=TensorType.PYTORCH,
        do_resize=None,
        do_normalize=None,
        image_mean=None,
        image_std=None,
        data_format="channels_first",
        input_data_format=None,
        resample: PILImageResampling = None,
        do_convert_rgb: bool = None,
        do_thumbnail: bool = None,
        do_align_long_axis: bool = None,
        do_rescale: bool = None,
        suffix=None,
        extra: dict | bool = False,
        **kwargs,
    ):
        suffix = suffix or kwargs.get("label", None)

        if text:
            text = self.preprocess_text(text, images)

        if suffix:
            suffix = self.preprocess_text(suffix, images)

        batch = super().__call__(
            text=text,
            images=images,
            tokenize_newline_separately=tokenize_newline_separately,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            do_resize=do_resize,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            data_format=data_format,
            input_data_format=input_data_format,
            resample=resample,
            do_convert_rgb=do_convert_rgb,
            do_thumbnail=do_thumbnail,
            do_align_long_axis=do_align_long_axis,
            do_rescale=do_rescale,
            suffix=suffix,
        )

        batch = self.create_attachable(batch, extra)(text=text, images=images, label=suffix)

        return batch
|
|
|
    def decode(self, outputs: torch.Tensor, do_post: bool = True, **kwargs) -> str:
        """Decode output token ids to text (PaliGemma-specific). Note: `do_post` is currently unused."""
        return self.tokenizer.decode(outputs, **kwargs)
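
    # NOTE: for generations containing <locXXXX> tokens, chain `decode` with
    # `handle_token_loc_seg` below to recover pixel-space <point>/<box> reprs.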
|
|
|
    def preprocess_text(
        self,
        text: str,
        images: list[torch.Tensor | Image] | Image = None,
        max_size: int = PROCESSOR_IMAGE_MAX_SIZE,
    ) -> str:
        """Convert <point>/<box> tags in `text` into <locXXXX> tokens, scaling pixel
        coordinates into the tokenizer's coordinate space (sizing expects a PIL image)."""
        if isinstance(images, list):
            images = images[0]

        if images is not None:
            image_width, image_height = images.size
            height_scale = max_size / image_height
            width_scale = max_size / image_width
        else:
            # without an image we cannot rescale coordinates, so fall back to identity scaling
            height_scale = width_scale = 1.0

        segments = segment_str(text, box_pattern=box_pattern, point_pattern=point_pattern)

        out_text = ""
        for seg, seg_type in segments:
            if seg_type == TagType.POINT:
                x, y = map(int, seg)

                scaled_x = _make_seg_text(_scale_val(x, width_scale, max_size))
                scaled_y = _make_seg_text(_scale_val(y, height_scale, max_size))

                # PaliGemma emits coordinates y-first
                out_text += f"{scaled_y}{scaled_x} point"
            elif seg_type == TagType.BOX:
                x1, y1, x2, y2 = map(int, seg)

                scaled_x1 = _make_seg_text(_scale_val(x1, width_scale, max_size))
                scaled_y1 = _make_seg_text(_scale_val(y1, height_scale, max_size))
                scaled_x2 = _make_seg_text(_scale_val(x2, width_scale, max_size))
                scaled_y2 = _make_seg_text(_scale_val(y2, height_scale, max_size))

                out_text += f"{scaled_y1}{scaled_x1}{scaled_y2}{scaled_x2} box"
            else:
                out_text += seg
        return out_text
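
    # e.g. with a 448x448 image, a tag holding point (100, 200) scales by 1024/448
    # to int(228.57), int(457.14) and becomes "<loc0457><loc0228> point"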
|
|
|
    def handle_token_loc_seg(self, text: str, image_height: int, image_width: int) -> str:
        """Convert runs of <locXXXX> tokens in decoded text back into <point>/<box>
        reprs with coordinates scaled to the image's pixel dimensions."""
        _scale_height = _make_scale_dim_func(image_height)
        _scale_width = _make_scale_dim_func(image_width)
        box_tags = (self.constants.repr_bbox_open_text, self.constants.repr_bbox_close_text)
        point_tags = (self.constants.repr_point_open_text, self.constants.repr_point_close_text)

        def _make_text(tag_open: tuple[str, int], tag_close: tuple[str, int], *vals):
            # each tag is a (tag_text, text_index) pair; splice the repr over the loc tokens
            return (
                text[: tag_open[1]] + f"{tag_open[0]}{', '.join(map(str, vals))}{tag_close[0]}" + text[tag_close[1] :]
            )

        def _make_yx(points: list[int]):
            # loc tokens are emitted y-first, so even indices are y values and odd indices are x values
            return _scale_height(*points[0::2]), _scale_width(*points[1::2])

        # use search rather than match: the loc tokens can appear anywhere in the text
        while loc_match := re_loc_box.search(text):
            tag_open, tag_close = zip(box_tags, loc_match.span())
            (y1, y2), (x1, x2) = _make_yx(list(loc_match.groups()))
            text = _make_text(tag_open, tag_close, y1, x1, y2, x2)

        while loc_match := re_loc_point.search(text):
            tag_open, tag_close = zip(point_tags, loc_match.span())
            (y1,), (x1,) = _make_yx(list(loc_match.groups()))
            text = _make_text(tag_open, tag_close, y1, x1)

        return text
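

# Minimal usage sketch (illustrative only, not part of the library): assumes network
# access to download the checkpoint, and that `box_pattern` matches the
# "<box>...</box>" tag syntax implied by the repr constants above.
if __name__ == "__main__":
    from PIL import Image as PILImage

    processor = PaliGemmaProcessor.from_pretrained(MODEL_ID)
    image = PILImage.new("RGB", (896, 896))

    # the pixel-space box tag should be rewritten into <locXXXX> tokens by preprocess_text
    batch = processor(text="detect button <box>10, 20, 110, 220</box>", images=image)
    print(batch["input_ids"].shape)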
|
|