File size: 3,679 Bytes
5448b02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e5d4e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from transformers import PreTrainedTokenizerFast
import numpy
import torch

class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer whose batch encodings omit ``token_type_ids`` and a trailing EOS.

    After the parent ``_batch_encode_plus`` runs, ``token_type_ids`` is removed
    (when present). Any sequence whose final position holds ``eos_token_id`` has
    that token removed from ``input_ids`` and from the aligned per-token keys
    ``attention_mask`` / ``special_tokens_mask`` (when present).

    NOTE(review): only the final position of each row is inspected. That fits
    left-padded batches; with ``padding_side="right"`` shorter rows end in
    padding and keep their EOS, and the filler slot inserted at position 0 in
    the mixed case assumes left padding. Confirm the intended padding side.
    """

    def _batch_encode_plus(self, *args, **kwargs):
        """Encode a batch via the parent, then drop token type ids and trailing EOS.

        Behaviour for sequences whose last position is ``eos_token_id``:

        - If every sequence ends with EOS, the last position is sliced off all
          of them (the encoding becomes one token shorter).
        - If only some do, rectangular (padded) outputs keep their width: the
          affected rows are shifted right by one and a filler is written at
          position 0 (``pad_token_id`` or 0 for ``input_ids``, 0 for
          ``attention_mask``, 1 for ``special_tokens_mask``). Ragged list
          outputs (no padding) simply lose the EOS token.

        Outputs that are neither ``torch.Tensor``, ``numpy.ndarray`` nor
        ``list`` are returned without EOS stripping.

        Returns:
            The ``BatchEncoding`` produced by the parent, with entries replaced
            or modified in place.
        """
        outputs = super()._batch_encode_plus(*args, **kwargs)
        # The parent only emits token_type_ids when they were requested /
        # listed in model_input_names; a plain ``del`` would raise KeyError.
        outputs.pop("token_type_ids", None)

        input_ids = outputs["input_ids"]
        fills = self._eos_strip_fill_values(outputs)
        if isinstance(input_ids, (torch.Tensor, numpy.ndarray)):
            self._strip_trailing_eos_array(outputs, input_ids, fills)
        elif isinstance(input_ids, list):
            self._strip_trailing_eos_list(outputs, input_ids, fills)
        return outputs

    def _eos_strip_fill_values(self, outputs):
        """Return ``(key, filler)`` pairs for the per-token keys present in *outputs*.

        The filler is the value written into a freed slot when a row is shifted
        right to keep a padded batch rectangular.
        """
        pad_id = self.pad_token_id if self.pad_token_id is not None else 0
        candidates = (
            ("input_ids", pad_id),
            ("attention_mask", 0),
            # Padding positions are flagged as special tokens in this mask.
            ("special_tokens_mask", 1),
        )
        return [(key, fill) for key, fill in candidates if key in outputs]

    def _strip_trailing_eos_array(self, outputs, input_ids, fills):
        """Strip trailing EOS from 2-D ``torch.Tensor`` / ``numpy.ndarray`` outputs."""
        eos_id = self.eos_token_id
        if eos_id is None or input_ids.shape[-1] == 0:
            return
        ends_with_eos = input_ids[:, -1] == eos_id

        if bool(ends_with_eos.all()):
            # Every row ends with EOS: drop the final column everywhere.
            for key, _ in fills:
                outputs[key] = outputs[key][..., :-1]
        elif bool(ends_with_eos.any()):
            for key, fill in fills:
                values = outputs[key]
                # Boolean-mask indexing yields a copy, so this overlapping
                # right shift reads the pre-shift values. Works for length 1.
                values[ends_with_eos, 1:] = values[ends_with_eos, :-1]
                values[ends_with_eos, 0] = fill

    def _strip_trailing_eos_list(self, outputs, input_ids, fills):
        """Strip trailing EOS from list-of-lists outputs."""
        eos_id = self.eos_token_id
        ends_with_eos = [len(seq) > 0 and seq[-1] == eos_id for seq in input_ids]

        if all(ends_with_eos):
            for key, _ in fills:
                outputs[key] = [seq[:-1] for seq in outputs[key]]
        elif any(ends_with_eos):
            # Only a rectangular (padded) batch needs a filler to keep its
            # width; prepending one to ragged, unpadded sequences would insert
            # a spurious token.
            is_rectangular = len({len(seq) for seq in input_ids}) <= 1
            for key, fill in fills:
                outputs[key] = [
                    ([fill] + seq[:-1] if is_rectangular else seq[:-1]) if strip else seq
                    for seq, strip in zip(outputs[key], ends_with_eos)
                ]

        
# Register the class with AutoTokenizer when this module is imported
# (module-level side effect).
# NOTE(review): AutoTokenizer.register's first parameter is documented as the
# model *config* class used as the lookup key; here the tokenizer class itself
# is passed, so AutoTokenizer lookups driven by a model config presumably will
# not hit this entry. Confirm which config class should be registered (loading
# via an ``auto_map`` entry in tokenizer_config.json may not need this at all).
from transformers import AutoTokenizer
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)