File size: 538 Bytes
6d20d8a 3608e05 6d20d8a f64965c 8a083e2 d77b85b 8a083e2 970954b 6a605a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
from transformers import PreTrainedTokenizerFast
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that removes EOS ids from batch-encoded ``input_ids``."""

    def _batch_encode_plus(self, *args, **kwargs):
        """Batch-encode through the parent class, then drop EOS token ids.

        All positional and keyword arguments are forwarded unchanged to
        ``PreTrainedTokenizerFast._batch_encode_plus``.

        Returns:
            The parent's encoding, with ``outputs['input_ids']`` replaced by
            a list of lists in which every id equal to ``self.eos_token_id``
            has been removed.
        """
        outputs = super()._batch_encode_plus(*args, **kwargs)

        # Read the EOS id once instead of on every token comparison.
        eos_id = self.eos_token_id

        # NOTE(review): only 'input_ids' is filtered. Any other per-token
        # field the parent returns (e.g. an attention mask) keeps its
        # original length -- confirm callers do not rely on them lining up.
        outputs['input_ids'] = [
            [token_id for token_id in ids if token_id != eos_id]
            for ids in outputs['input_ids']
        ]
        return outputs
# Register the class so AutoTokenizer can resolve it.
# NOTE(review): the first positional argument of AutoTokenizer.register is
# normally a config class; the tokenizer class itself is passed here --
# confirm this is the intended registration key.
from transformers import AutoTokenizer

AutoTokenizer.register(
    ModernDecoderBERTTokenizer,
    fast_tokenizer_class=ModernDecoderBERTTokenizer,
)