from transformers import AutoTokenizer, PreTrainedTokenizerFast
import numpy
import torch


class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):

    def _batch_encode_plus(self, *args, **kwargs):
        outputs = super()._batch_encode_plus(*args, **kwargs)
        # Decoder-style training does not use segment embeddings, so drop
        # token_type_ids if the underlying tokenizer produced them.
        outputs.pop("token_type_ids", None)

        # Get the input_ids to check for EOS tokens
        input_ids = outputs['input_ids']

        # Check whether a single sequence ends with the EOS token
        def ends_with_eos(sequence):
            if len(sequence) == 0:
                return False
            return bool(sequence[-1] == self.eos_token_id)

        # Build a per-sequence EOS mask from input_ids only
        if isinstance(input_ids, torch.Tensor):
            last_token_is_eos = torch.tensor(
                [ends_with_eos(seq) for seq in input_ids], dtype=torch.bool
            )
        elif isinstance(input_ids, numpy.ndarray):
            last_token_is_eos = numpy.array(
                [ends_with_eos(seq) for seq in input_ids], dtype=bool
            )
        elif isinstance(input_ids, list):
            last_token_is_eos = [ends_with_eos(seq) for seq in input_ids]

        # Strip the trailing EOS token, applying the same last_token_is_eos
        # check to both input_ids and attention_mask. Tensor and ndarray
        # batches must stay rectangular, so the last column is dropped only
        # when every sequence in the batch ends with EOS; plain Python lists
        # can be trimmed per sequence.
        for key in ['input_ids', 'attention_mask']:
            if isinstance(outputs[key], torch.Tensor):
                if bool(last_token_is_eos.all()):
                    outputs[key] = outputs[key][..., :-1]
            elif isinstance(outputs[key], numpy.ndarray):
                if bool(last_token_is_eos.all()):
                    outputs[key] = outputs[key][..., :-1]
            elif isinstance(outputs[key], list):
                outputs[key] = [
                    sequence[:-1] if is_eos else sequence
                    for sequence, is_eos in zip(outputs[key], last_token_is_eos)
                ]

        return outputs


# Register the class
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)
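

# Minimal usage sketch (illustrative, not part of the snippet above). It
# assumes a checkpoint whose tokenizer defines an eos_token; the
# "answerdotai/ModernBERT-base" name is an assumption used only for
# illustration. If eos_token_id is None, nothing is stripped.
tokenizer = ModernDecoderBERTTokenizer.from_pretrained("answerdotai/ModernBERT-base")

batch = tokenizer(["A short example sentence.", "Another one."])
print(list(batch.keys()))     # no 'token_type_ids' key
print(batch["input_ids"][0])  # trailing eos_token_id removed, if one was appended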