File size: 1,822 Bytes
2e8ba46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from itertools import chain
from transformers import GitProcessor

class GIAProcessor(GitProcessor):
    """GIT processor that packs text-only batches into fixed-length token blocks.

    Calls with text only are tokenized, concatenated across the whole batch and
    re-split into chunks of ``block_size`` tokens. Every other combination of
    ``text`` / ``images`` is delegated to ``GitProcessor.__call__``.

    Args:
        image_processor: Image processor, forwarded to ``GitProcessor``.
        tokenizer: Tokenizer, forwarded to ``GitProcessor``.
        block_size: Length of each packed chunk for text-only inputs.
            Defaults to 1024 (the value previously hard-coded).

    Raises:
        ValueError: If ``block_size`` is not a positive integer.
    """

    def __init__(self, image_processor, tokenizer, block_size=1024):
        # A non-positive size would make the chunking step divide by zero or
        # call ``range`` with a zero/negative step.
        if not isinstance(block_size, int) or block_size <= 0:
            raise ValueError(f"block_size must be a positive integer, got {block_size!r}.")
        super().__init__(image_processor, tokenizer)
        self._block_size = block_size

    def _group_texts(self, examples):
        """Concatenate every field of ``examples`` and split it into ``block_size`` chunks.

        Args:
            examples: Mapping from field name (e.g. ``"input_ids"``) to a batch of
                sequences, as produced by the tokenizer without tensor conversion.

        Returns:
            dict: Same keys as ``examples``; each value is a list of chunks of
            exactly ``block_size`` elements. Elements left over after the last full
            chunk are dropped, so a batch shorter than ``block_size`` yields an
            empty list for every key. An empty mapping yields an empty dict.
        """
        block_size = self._block_size
        # Flatten each field's batch of sequences into one continuous stream.
        concatenated = {key: list(chain.from_iterable(seqs)) for key, seqs in examples.items()}
        if not concatenated:
            return {}
        # Length is read from the first field; this assumes all fields are
        # parallel (same length), as tokenizer outputs like input_ids and
        # attention_mask are.
        total_length = len(next(iter(concatenated.values())))
        # Drop the remainder that does not fill a whole block (no padding is used).
        total_length = (total_length // block_size) * block_size
        return {
            key: [stream[i : i + block_size] for i in range(0, total_length, block_size)]
            for key, stream in concatenated.items()
        }

    def __call__(self, text=None, images=None, return_tensors=None, **kwargs):
        """Prepare model inputs from text and/or images.

        Args:
            text: A string or a list of strings.
            images: Images accepted by the underlying image processor.
            return_tensors: Tensor framework for the output (e.g. ``"pt"``,
                ``"np"``); ``None`` keeps Python lists.
            **kwargs: Forwarded to ``GitProcessor.__call__`` when ``images`` is
                given or ``text`` is ``None``; ignored for text-only inputs.

        Returns:
            For text-only input, the packed fields from ``_group_texts``: a plain
            ``dict`` of lists when ``return_tensors`` is ``None``, otherwise a
            ``BatchEncoding`` converted to ``return_tensors`` (one row per packed
            block). In every other case, whatever ``GitProcessor.__call__`` returns.
        """
        if text is not None and images is None:
            # Treat a single string as a batch of one so every field is a list of
            # sequences, which ``_group_texts`` flattens.
            if isinstance(text, str):
                text = [text]
            # Tokenize without tensor conversion: rows are ragged until packed, so
            # converting here would fail or produce unusable nested output.
            encoded_text = self.tokenizer(text)
            grouped = self._group_texts(encoded_text)
            if return_tensors is None:
                return grouped
            # Every packed row now has length block_size, so conversion is safe.
            from transformers import BatchEncoding

            return BatchEncoding(grouped, tensor_type=return_tensors)
        # Images-only, text + images, or neither: GitProcessor handles these
        # (presumably raising when both are None). Keywords are used because newer
        # transformers releases list ``images`` before ``text`` positionally.
        return super().__call__(text=text, images=images, return_tensors=return_tensors, **kwargs)

    def batch_decode(self, *args, **kwargs):
        """Decode a batch of token-id sequences; delegates to the tokenizer."""
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """Decode a single token-id sequence; delegates to the tokenizer."""
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """Names of the inputs this processor can produce for the model."""
        return ["input_ids", "attention_mask", "pixel_values"]


# Make ``AutoProcessor.from_pretrained(..., trust_remote_code=True)`` resolve to GIAProcessor.
GIAProcessor.register_for_auto_class(auto_class="AutoProcessor")