from io import BytesIO
from typing import Any, Dict, Optional

import PIL.Image
import torch
from transformers import SiglipImageProcessor

from sentence_transformers.models import Transformer as BaseTransformer


class MultiModalTransformer(BaseTransformer):
    """Transformer module that encodes text and, optionally, images.

    Extends the sentence-transformers ``Transformer`` module with a SigLIP
    image processor so that images in the input are converted to
    ``pixel_values`` and passed to the underlying model together with the
    tokenized text.
    """

    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        # The base class sets up the text tokenizer and model; here we only add
        # the SigLIP image processor used to turn PIL images into pixel_values.
        self.processor = SiglipImageProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        # Text features are always present; pixel_values are forwarded (cast to
        # the model dtype) only when the batch actually contains images.
        trans_features = {
            "input_ids": features["input_ids"],
            "attention_mask": features["attention_mask"],
        }
        if "pixel_values" in features:
            trans_features["pixel_values"] = features["pixel_values"].to(
                self.auto_model.dtype
            )

        sentence_embedding = self.auto_model(**trans_features, **kwargs)[
            "sentence_embedding"
        ]
        features.update({"sentence_embedding": sentence_embedding})
        return features

    def tokenize(self, texts: list[Dict] | list[str]) -> dict[str, torch.Tensor]:
        # Each image is represented in the text stream by a fixed-length run of
        # placeholder tokens bracketed by start/end markers.
        img_start_token = "<|jasper_img_start|>"
        img_token = "<|jasper_img_token|>"
        img_end_token = "<|jasper_img_end|>"
        num_img_tokens = 300

        def process_text_item(item):
            # Plain strings are text-only items with no images attached.
            if isinstance(item, str):
                return item, []
            text, images = "", []
            for sub_item in item:
                if sub_item["type"] == "text":
                    text += sub_item["content"]
                elif sub_item["type"] == "image_bytes":
                    text += img_start_token + img_token * num_img_tokens + img_end_token
                    images.append(
                        PIL.Image.open(BytesIO(sub_item["content"])).convert("RGB")
                    )
                elif sub_item["type"] == "image_path":
                    text += img_start_token + img_token * num_img_tokens + img_end_token
                    images.append(PIL.Image.open(sub_item["content"]).convert("RGB"))
                else:
                    raise ValueError(f"unknown data type {sub_item['type']}")
            return text, images

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.extend(images)
        ipt = self.tokenizer(
            all_texts,
            padding="longest",
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )
        if all_images:
            ipt["pixel_values"] = self.processor(
                images=all_images, return_tensors="pt"
            )["pixel_values"]
        return ipt
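

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal way this module could be
# exercised on its own, assuming a local checkpoint directory
# ("path/to/multimodal-checkpoint") whose underlying model returns a
# "sentence_embedding" entry, as forward() above expects, and whose repo also
# ships the SigLIP image-processor config. The checkpoint path and image file
# below are hypothetical placeholders, not part of the original module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = MultiModalTransformer("path/to/multimodal-checkpoint")
    batch = model.tokenize(
        [
            "a plain text query",
            [
                {"type": "text", "content": "describe this picture: "},
                {"type": "image_path", "content": "example.jpg"},
            ],
        ]
    )
    with torch.no_grad():
        features = model(batch)
    print(features["sentence_embedding"].shape)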