test-flex-gpt / tokenizer.py
oweller2
update
f64965c
raw
history blame
488 Bytes
from transformers import PreTrainedTokenizerFast
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that removes EOS tokens from input ids before encoding.

    Overrides ``prepare_for_model`` so that any occurrence of
    ``self.eos_token_id`` is filtered out of ``ids`` before the standard
    preparation (special-token insertion, truncation, padding, ...) runs
    in the parent class.
    """

    def prepare_for_model(self, ids, *args, **kwargs):
        """Filter EOS token ids out of *ids*, then defer to the parent.

        Args:
            ids: sequence of token ids to prepare.
            *args, **kwargs: forwarded unchanged to
                ``PreTrainedTokenizerFast.prepare_for_model``.

        Returns:
            Whatever the parent ``prepare_for_model`` returns for the
            filtered id list.
        """
        # Fix: removed a stray breakpoint() left over from debugging — it
        # dropped every tokenization call into an interactive debugger
        # (and hangs non-interactive runs).
        # Renamed the loop variable so it no longer shadows builtin `id`.
        filtered_ids = [token_id for token_id in ids if token_id != self.eos_token_id]
        return super().prepare_for_model(filtered_ids, *args, **kwargs)
# Register the custom tokenizer with transformers' Auto* machinery so
# AutoTokenizer.from_pretrained(...) can resolve it.
from transformers import AutoTokenizer
# NOTE(review): AutoTokenizer.register documents its first positional
# argument as a *config* class; passing the tokenizer class itself is
# unusual — confirm this is intentional and actually resolves at load time.
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)