# test-flex-gpt / tokenizer.py
# oweller2
# update
# 8a083e2
# raw
# history blame
# 538 Bytes
from transformers import PreTrainedTokenizerFast
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that removes EOS token ids from encoded batches.

    Wraps ``PreTrainedTokenizerFast._batch_encode_plus`` and strips every
    occurrence of ``self.eos_token_id`` from the ``input_ids`` sequences.
    """

    def _batch_encode_plus(self, *args, **kwargs):
        """Encode a batch, then drop all EOS token ids from ``input_ids``.

        Returns the encoding produced by the parent class, with each
        ``input_ids`` sequence filtered in place.

        NOTE(review): only ``input_ids`` is filtered; parallel fields such as
        ``attention_mask`` or ``token_type_ids`` (if present) keep their
        original lengths and may no longer align with the shortened
        ``input_ids`` — confirm downstream code tolerates this.
        """
        # Fixed: removed a leftover `breakpoint()` debug call that halted
        # execution on every encode.
        outputs = super()._batch_encode_plus(*args, **kwargs)
        # Hoist the attribute lookup out of the loops; `token_id` avoids
        # shadowing the builtin `id`. If eos_token_id is None, nothing is
        # removed (token ids are ints), matching the original behavior.
        eos_id = self.eos_token_id
        outputs['input_ids'] = [
            [token_id for token_id in ids if token_id != eos_id]
            for ids in outputs['input_ids']
        ]
        return outputs
# Register the class with the transformers Auto* factory so it can be
# resolved automatically at load time.
from transformers import AutoTokenizer
# NOTE(review): AutoTokenizer.register conventionally takes a *config* class
# as its first positional argument; here the tokenizer class itself is used
# as the registry key — confirm this matches how the tokenizer is loaded.
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)