# Custom tokenizer for ModernDecoderBERT: strips EOS token ids in
# prepare_for_model and registers the class with AutoTokenizer.
from transformers import PreTrainedTokenizerFast
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that removes EOS token ids before model preparation.

    Behaves like ``PreTrainedTokenizerFast`` except that ``prepare_for_model``
    drops every occurrence of ``self.eos_token_id`` from the first id
    sequence before delegating to the parent implementation.
    """

    def prepare_for_model(self, ids, *args, **kwargs):
        """Strip EOS ids from ``ids``, then defer to the parent implementation.

        Args:
            ids: Token ids of the first sequence. Every element equal to
                ``self.eos_token_id`` is removed; the order of the remaining
                ids is preserved. If ``eos_token_id`` is ``None`` nothing is
                removed.
            *args: Forwarded unchanged to the parent ``prepare_for_model``.
            **kwargs: Forwarded unchanged to the parent ``prepare_for_model``.

        Returns:
            Whatever the parent ``prepare_for_model`` returns.
        """
        # NOTE(review): only ``ids`` is filtered; a second sequence passed via
        # *args/**kwargs (e.g. ``pair_ids``) keeps its EOS ids — confirm intended.
        # NOTE(review): for fast tokenizers the regular ``__call__``/``encode``
        # path presumably encodes via the Rust backend without calling this
        # method — verify EOS stripping happens where it is actually needed.
        eos_id = self.eos_token_id
        filtered_ids = [token_id for token_id in ids if token_id != eos_id]
        return super().prepare_for_model(filtered_ids, *args, **kwargs)
# Register the tokenizer with AutoTokenizer as a module import side effect.
from transformers import AutoTokenizer

# NOTE(review): AutoTokenizer.register's first parameter is the *config* class
# that should map to this tokenizer. Passing the tokenizer class itself means
# the mapping is keyed on a non-config type, so it is presumably never hit by
# AutoTokenizer.from_pretrained. Replace the first argument with the model's
# config class (not visible in this file) — TODO confirm.
AutoTokenizer.register(
    ModernDecoderBERTTokenizer,
    fast_tokenizer_class=ModernDecoderBERTTokenizer,
)