from transformers import AutoTokenizer, PreTrainedTokenizerFast


class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that removes EOS token ids from batch-encoded ``input_ids``."""

    def _batch_encode_plus(self, *args, **kwargs):
        """Batch-encode via the parent class, then drop every EOS id.

        All positional and keyword arguments are forwarded unchanged to
        ``PreTrainedTokenizerFast._batch_encode_plus``.

        Returns:
            The parent's encoding, with ``outputs['input_ids']`` replaced by
            a list of lists from which every id equal to
            ``self.eos_token_id`` has been removed.
        """
        outputs = super()._batch_encode_plus(*args, **kwargs)

        # NOTE(review): only 'input_ids' is filtered. Any other per-token
        # fields the parent returns (e.g. an attention mask) keep their
        # original length -- confirm callers do not need them to line up.
        eos_id = self.eos_token_id  # looked up once instead of once per token
        outputs['input_ids'] = [
            [token_id for token_id in ids if token_id != eos_id]
            for ids in outputs['input_ids']
        ]
        return outputs


# Register the class
# NOTE(review): the tokenizer class is passed as the first positional argument
# of AutoTokenizer.register, which the transformers docs describe as the
# config class -- presumably intentional here; confirm before changing.
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)