|
from itertools import chain

from transformers import BatchEncoding, GitProcessor
|
|
|
class GIAProcessor(GitProcessor):
    """GIT processor that packs text-only inputs into fixed-size token blocks.

    Text-only calls are tokenized, concatenated across the whole batch and
    re-split into chunks of ``self._block_size`` tokens; any trailing remainder
    shorter than one block is dropped. Calls that include images (or provide
    neither text nor images) are delegated to ``GitProcessor.__call__``.
    """

    def __init__(self, image_processor, tokenizer):
        super().__init__(image_processor, tokenizer)
        # Number of tokens in each packed block produced by ``_group_texts``.
        self._block_size = 1024

    def _group_texts(self, examples):
        """Concatenate all sequences per key and split them into equal blocks.

        Args:
            examples: mapping from feature name (e.g. ``"input_ids"``,
                ``"attention_mask"``) to a list of token sequences, one per
                input text. All keys are assumed to hold sequences of matching
                total length, as a tokenizer produces.

        Returns:
            dict mapping the same keys to lists of ``self._block_size``-long
            chunks. The tail shorter than one block is discarded, so input with
            fewer than ``self._block_size`` tokens in total yields empty lists.
            An empty mapping yields an empty dict.
        """
        concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
        if not concatenated:
            return {}

        # Every key carries the same number of tokens; measure the first one.
        total_length = len(next(iter(concatenated.values())))
        # Drop the remainder so every block has exactly ``self._block_size`` tokens.
        total_length = (total_length // self._block_size) * self._block_size

        return {
            k: [t[i : i + self._block_size] for i in range(0, total_length, self._block_size)]
            for k, t in concatenated.items()
        }

    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
        """Encode text and/or images.

        Args:
            text: a string or a list of strings.
            images: images accepted by the wrapped image processor.
            return_tensors: tensor type for the output (e.g. ``"pt"``, ``"np"``)
                or ``None`` for Python lists.
            **kwargs: forwarded to ``GitProcessor.__call__`` when the call is
                delegated; ignored for text-only input.

        Returns:
            For text-only input, a ``BatchEncoding`` holding the packed blocks
            from ``_group_texts``; otherwise whatever ``GitProcessor.__call__``
            returns.
        """
        if text is not None and images is None:
            if isinstance(text, str):
                # A lone string tokenizes to a flat id list; wrap it so
                # ``_group_texts`` always receives a batch of sequences.
                text = [text]
            # Tokenize to plain lists: sequences of different lengths cannot be
            # stacked into tensors before packing, so convert only afterwards.
            encoded_text = self.tokenizer(text)
            grouped = self._group_texts(encoded_text)
            return BatchEncoding(data=grouped, tensor_type=return_tensors)

        # Images with or without text (and the neither-given case) are handled
        # by GitProcessor. Keywords keep this robust to parameter reordering in
        # newer transformers signatures.
        return super().__call__(text=text, images=images, return_tensors=return_tensors, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """Decode a batch of token-id sequences; forwards to ``tokenizer.batch_decode``."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Decode a single token-id sequence; forwards to ``tokenizer.decode``."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Names of the model inputs this processor can produce."""
        return ["input_ids", "attention_mask", "pixel_values"]
|
|
|
|
|
# Register this custom processor so AutoProcessor can resolve it from saved configs.
GIAProcessor.register_for_auto_class(auto_class="AutoProcessor")