File size: 488 Bytes
6d20d8a
3608e05
6d20d8a
f64965c
 
d77b85b
f64965c
 
970954b
 
 
6a605a0
1
2
3
4
5
6
7
8
9
10
11
12
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizerFast

class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that strips EOS tokens before building model inputs.

    Identical to ``PreTrainedTokenizerFast`` except that
    :meth:`prepare_for_model` removes every occurrence of
    ``self.eos_token_id`` from the primary id sequence before delegating
    to the parent implementation.
    """

    def prepare_for_model(self, ids, *args, **kwargs):
        """Drop EOS ids from ``ids`` and defer to the parent implementation.

        Args:
            ids: Token ids of the first (or only) sequence.
            *args: Forwarded unchanged to the parent ``prepare_for_model``.
            **kwargs: Forwarded unchanged to the parent ``prepare_for_model``.

        Returns:
            Whatever the parent ``prepare_for_model`` returns for the
            EOS-filtered ids.
        """
        # NOTE(review): only ``ids`` is filtered; a second sequence passed
        # through *args/**kwargs (e.g. ``pair_ids``) keeps its EOS tokens.
        # NOTE(review): in some transformers versions the fast-tokenizer
        # ``encode``/``__call__`` path does not route through
        # ``prepare_for_model``; confirm this override is actually reached.

        # Read the EOS id once: on HF tokenizers it is a property that
        # converts the EOS token string on every access. If it is None,
        # no element compares equal and nothing is removed.
        eos_id = self.eos_token_id
        filtered_ids = [token_id for token_id in ids if token_id != eos_id]
        return super().prepare_for_model(filtered_ids, *args, **kwargs)

# Register the tokenizer with ``AutoTokenizer`` as an import-time side effect.
# NOTE(review): per the transformers docs, the first positional parameter of
# ``AutoTokenizer.register`` is ``config_class`` (the model config class used
# as the lookup key), not a tokenizer class. With the tokenizer class passed
# here, lookups keyed on the model's config presumably will not find this
# entry. Confirm, and pass the matching ``PretrainedConfig`` subclass instead.
AutoTokenizer.register(
    ModernDecoderBERTTokenizer,
    fast_tokenizer_class=ModernDecoderBERTTokenizer,
)