File size: 707 Bytes
d776f8b
 
 
8250eed
3608e05
 
 
bfe22ad
efef38a
bfe22ad
 
efef38a
 
5a1ce8b
 
 
 
efef38a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from transformers import PreTrainedTokenizer, AutoTokenizer

class ModernDecoderBERTTokenizer(PreTrainedTokenizer):
    """Tokenizer that removes EOS ids instead of adding special tokens.

    Construction is delegated unchanged to ``PreTrainedTokenizer``. The only
    customised behaviour is in ``build_inputs_with_special_tokens``, which
    filters out ``self.eos_token_id``, and ``get_vocab``, which returns a copy
    of ``self.vocab``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all positional and keyword arguments to the base class."""
        super().__init__(*args, **kwargs)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Return the input ids with every EOS id removed.

        No special tokens are inserted. For a pair of sequences the two
        filtered sequences are simply concatenated.

        Args:
            token_ids_0: Iterable of token ids for the first sequence.
            token_ids_1: Optional iterable of token ids for the second
                sequence.

        Returns:
            A new list holding the ids of ``token_ids_0`` (followed by those
            of ``token_ids_1`` when given) that differ from
            ``self.eos_token_id``.
        """
        # Read the EOS id once rather than once per element.
        eos_id = self.eos_token_id
        kept = [tok for tok in token_ids_0 if tok != eos_id]
        if token_ids_1 is not None:
            kept += [tok for tok in token_ids_1 if tok != eos_id]
        return kept

    def get_vocab(self):
        """Return a new dict built from the items of ``self.vocab``.

        NOTE(review): ``self.vocab`` is not assigned anywhere in this class;
        presumably it is supplied by the loading code or a subclass - confirm.
        """
        return dict(self.vocab.items())
        
# Import-time side effect: register the tokenizer class with AutoTokenizer
# under the key "ModernDecoderBERTTokenizer".
# NOTE(review): the first argument here is a string; the transformers docs
# describe this parameter as a config class - confirm the string key resolves
# as intended for the installed transformers version.
AutoTokenizer.register("ModernDecoderBERTTokenizer", ModernDecoderBERTTokenizer)