from typing import Dict, List, Optional

from transformers import AutoTokenizer, PreTrainedTokenizer


class ModernDecoderBERTTokenizer(PreTrainedTokenizer):
    """Tokenizer that strips EOS token ids when assembling model inputs.

    No extra special tokens are inserted. Each input sequence has every
    occurrence of ``self.eos_token_id`` removed, and for a pair the two
    filtered sequences are concatenated.
    """

    def _strip_eos(self, token_ids: List[int]) -> List[int]:
        """Return a new list with every ``self.eos_token_id`` removed."""
        eos_id = self.eos_token_id
        return [tok for tok in token_ids if tok != eos_id]

    def build_inputs_with_special_tokens(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
    ) -> List[int]:
        """Build model inputs from one or two sequences of token ids.

        Args:
            token_ids_0: First sequence of token ids.
            token_ids_1: Optional second sequence of token ids.

        Returns:
            ``token_ids_0`` with EOS ids removed. If ``token_ids_1`` is given,
            its EOS-filtered ids are appended.
        """
        ids = self._strip_eos(token_ids_0)
        if token_ids_1 is not None:
            ids += self._strip_eos(token_ids_1)
        return ids

    def get_vocab(self) -> Dict[str, int]:
        """Return a shallow copy of ``self.vocab`` as a plain dict."""
        # NOTE(review): ``self.vocab`` is not defined in this class; presumably
        # a subclass or loading code sets it. Confirm before relying on it.
        return dict(self.vocab.items())


# NOTE(review): transformers documents the first argument of
# ``AutoTokenizer.register`` as a config class, not a string. A string key
# probably never matches a config-based lookup. Verify that this registration
# actually takes effect.
AutoTokenizer.register("ModernDecoderBERTTokenizer", ModernDecoderBERTTokenizer)