File size: 538 Bytes
6d20d8a
3608e05
6d20d8a
f64965c
8a083e2
d77b85b
8a083e2
 
 
970954b
 
 
6a605a0
1
2
3
4
5
6
7
8
9
10
11
12
13
from transformers import PreTrainedTokenizerFast

class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that strips the EOS token id from its encoded output.

    Every batch encoding produced by the parent class is post-processed so
    that no occurrence of ``eos_token_id`` remains in ``input_ids``. Every
    other field with one entry per token (``attention_mask``,
    ``token_type_ids``, ``offset_mapping``, ...) is filtered with the same
    mask, so all fields stay aligned with ``input_ids``.
    """

    def _batch_encode_plus(self, *args, **kwargs):
        """Encode a batch with the parent tokenizer, then remove EOS tokens.

        All arguments are forwarded unchanged to
        ``PreTrainedTokenizerFast._batch_encode_plus``, except that tensor
        conversion (``return_tensors``) is deferred until after filtering.

        Returns:
            The parent's ``BatchEncoding`` with every ``eos_token_id`` removed
            from ``input_ids``. The same positions are removed from every other
            per-token field, and ``length`` (if present) is recomputed.

        NOTE(review): the underlying ``tokenizers`` encodings attached to the
        result are not updated. Alignment helpers on the returned object
        (e.g. word/char mapping) presumably still describe the unstripped
        sequences.
        NOTE(review): if the pad token shares the EOS id, padding is stripped
        too. Rows then become ragged and tensor conversion will fail.
        """
        # Tensor rows cannot be filtered element-wise, so request plain Python
        # lists from the parent and convert once filtering is done.
        # Assumes ``return_tensors`` arrives by keyword, as transformers' own
        # callers pass it; confirm if this method is called directly.
        return_tensors = kwargs.pop("return_tensors", None)
        outputs = super()._batch_encode_plus(*args, return_tensors=None, **kwargs)

        eos_token_id = self.eos_token_id
        batch_ids = outputs["input_ids"]
        keep_masks = [
            [token_id != eos_token_id for token_id in ids] for ids in batch_ids
        ]

        # Filter input_ids and every field aligned token-for-token with it,
        # so that e.g. attention_mask keeps the same length as input_ids.
        for key in list(outputs.keys()):
            rows = outputs[key]
            if self._is_per_token(rows, batch_ids):
                outputs[key] = [
                    [value for value, keep in zip(row, mask) if keep]
                    for row, mask in zip(rows, keep_masks)
                ]

        # ``length`` holds one int per sequence; refresh it after stripping.
        if "length" in outputs:
            outputs["length"] = [len(ids) for ids in outputs["input_ids"]]

        if return_tensors is not None:
            outputs.convert_to_tensors(tensor_type=return_tensors)
        return outputs

    @staticmethod
    def _is_per_token(rows, batch_ids):
        """Return True if *rows* has exactly one entry per token of *batch_ids*."""
        return (
            isinstance(rows, list)
            and len(rows) == len(batch_ids)
            and all(
                isinstance(row, list) and len(row) == len(ids)
                for row, ids in zip(rows, batch_ids)
            )
        )

# Register ModernDecoderBERTTokenizer with AutoTokenizer as a fast tokenizer class.
# NOTE(review): in the transformers API reference, the first argument of
# ``AutoTokenizer.register`` appears to be a *config* class (the key used to
# resolve a model's tokenizer). Here the tokenizer class itself is passed as
# that key, so this entry is presumably never matched by
# ``AutoTokenizer.from_pretrained``. Confirm which config class was intended.
from transformers import AutoTokenizer
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)