from transformers import PreTrainedTokenizerFast
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that drops EOS token ids before model preparation."""

    def prepare_for_model(self, ids, *args, **kwargs):
        """Remove EOS ids from ``ids`` and delegate to the parent method.

        Args:
            ids: Iterable of token ids to prepare.
            *args: Forwarded unchanged to the parent ``prepare_for_model``.
            **kwargs: Forwarded unchanged to the parent ``prepare_for_model``.

        Returns:
            Whatever the parent ``prepare_for_model`` returns for the
            filtered ids.
        """
        # NOTE(review): only ``ids`` is filtered; anything passed through
        # ``args``/``kwargs`` (e.g. a second sequence) is forwarded as-is --
        # confirm that is intended.
        # Read the attribute once instead of on every loop iteration.
        eos_token_id = self.eos_token_id
        filtered_ids = [
            token_id for token_id in ids if token_id != eos_token_id
        ]
        return super().prepare_for_model(filtered_ids, *args, **kwargs)
# Register the tokenizer class with AutoTokenizer.
# NOTE(review): the first positional argument of AutoTokenizer.register is
# documented as the config class; the tokenizer class itself is passed here --
# confirm this is intentional.
from transformers import AutoTokenizer

AutoTokenizer.register(
    ModernDecoderBERTTokenizer,
    fast_tokenizer_class=ModernDecoderBERTTokenizer,
)