from transformers import AutoTokenizer, PreTrainedTokenizerFast
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that removes EOS token ids from batch-encoded input ids."""

    def _batch_encode_plus(self, *args, **kwargs):
        """Batch-encode with the parent implementation, then strip EOS ids.

        All positional and keyword arguments are forwarded unchanged to
        ``PreTrainedTokenizerFast._batch_encode_plus``.

        Returns:
            The parent's output object, with ``outputs["input_ids"]`` replaced
            by a list of lists in which every element equal to
            ``self.eos_token_id`` has been dropped.
        """
        outputs = super()._batch_encode_plus(*args, **kwargs)

        # Read the EOS id once instead of once per token.
        eos_token_id = self.eos_token_id

        # NOTE(review): only "input_ids" is filtered. Any other per-token
        # fields the parent returns (e.g. an attention mask) keep their
        # original length -- confirm callers do not need them aligned.
        outputs["input_ids"] = [
            [token_id for token_id in ids if token_id != eos_token_id]
            for ids in outputs["input_ids"]
        ]
        return outputs
# Register the tokenizer with AutoTokenizer, using it as its own fast
# tokenizer class.
# NOTE(review): AutoTokenizer.register documents its first argument as a
# config class; here the tokenizer class itself is passed -- confirm that
# keying the registration on the tokenizer class is intended.
AutoTokenizer.register(
    ModernDecoderBERTTokenizer,
    fast_tokenizer_class=ModernDecoderBERTTokenizer,
)