from transformers import AutoTokenizer, PreTrainedTokenizerFast
import numpy
import torch

class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer wrapper that drops token_type_ids and strips the
    trailing EOS token appended by the base tokenizer."""

    def _batch_encode_plus(self, *args, **kwargs):
        outputs = super()._batch_encode_plus(*args, **kwargs)
        # Decoder-style models have no use for token_type_ids; drop them
        # if the backend produced them.
        outputs.pop("token_type_ids", None)

        # Inspect input_ids to decide, per sequence, whether a trailing
        # EOS token has to be removed.
        input_ids = outputs["input_ids"]

        # True when a sequence ends with the EOS token.
        def ends_with_eos(sequence):
            if len(sequence) == 0:
                return False
            return bool(sequence[-1] == self.eos_token_id)

        # Compute the per-sequence EOS check once, from input_ids only.
        if isinstance(input_ids, torch.Tensor):
            last_token_is_eos = torch.tensor(
                [ends_with_eos(seq) for seq in input_ids], dtype=torch.bool
            )
        elif isinstance(input_ids, numpy.ndarray):
            last_token_is_eos = numpy.array(
                [ends_with_eos(seq) for seq in input_ids], dtype=bool
            )
        elif isinstance(input_ids, list):
            last_token_is_eos = [ends_with_eos(seq) for seq in input_ids]
        else:
            raise TypeError(f"Unsupported input_ids type: {type(input_ids)}")

        # Apply the same EOS decision to both input_ids and attention_mask.
        for key in ("input_ids", "attention_mask"):
            value = outputs[key]
            if isinstance(value, (torch.Tensor, numpy.ndarray)):
                # Batched arrays must stay rectangular, so rows cannot be
                # truncated individually (torch.where cannot mix a length-L
                # tensor with a length L-1 slice).
                if bool(last_token_is_eos.all()):
                    # Every sequence ends with EOS: drop the last column.
                    outputs[key] = value[..., :-1]
                else:
                    # Mixed batch: overwrite the trailing EOS with padding
                    # (pad id for input_ids, 0 for attention_mask), which is
                    # equivalent to removing it and re-padding on the right.
                    pad_value = (self.pad_token_id or 0) if key == "input_ids" else 0
                    cleaned = value.clone() if isinstance(value, torch.Tensor) else value.copy()
                    cleaned[last_token_is_eos, -1] = pad_value
                    outputs[key] = cleaned
            elif isinstance(value, list):
                # Python lists may be ragged, so the EOS is simply cut off.
                outputs[key] = [
                    sequence[:-1] if is_eos else sequence
                    for sequence, is_eos in zip(value, last_token_is_eos)
                ]

        return outputs

# Register the class with AutoTokenizer.
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)
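
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed as part of this module).
# Assuming this file ships inside a model repo and is referenced from the
# tokenizer's auto_map so that trust_remote_code loading resolves the class;
# the repo id below is a hypothetical placeholder, not a real checkpoint:
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained(
#         "your-org/your-decoder-model",  # hypothetical repo id
#         trust_remote_code=True,
#     )
#     batch = tokenizer(
#         ["a short example", "a somewhat longer second example"],
#         padding=True,
#         return_tensors="pt",
#     )
#     # batch contains input_ids and attention_mask only (no token_type_ids),
#     # with any trailing EOS token stripped or replaced by padding.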