# Source file: test-flex-gpt/tokenizer.py
# Last commit: 3cd88d6 ("unpad") by oweller2 — file size 2.5 kB
from transformers import PreTrainedTokenizerFast
import numpy
import torch
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that strips a trailing EOS token from its batch encodings.

    The ``_batch_encode_plus`` override removes ``token_type_ids`` from the
    parent's output. For every sequence whose last token is ``eos_token_id``,
    it also removes that token from the token-aligned fields listed in
    ``_TOKEN_ALIGNED_KEYS``.

    NOTE(review): the ``encodings`` held by the returned ``BatchEncoding`` are
    not modified. Encoding-based helpers (e.g. ``tokens()`` / ``word_ids()``)
    presumably still include the stripped EOS; confirm whether callers rely
    on them.
    """

    # Per-token outputs that must stay the same length as ``input_ids``; only
    # keys actually present in the output are touched.
    _TOKEN_ALIGNED_KEYS = (
        "input_ids",
        "attention_mask",
        "special_tokens_mask",
        "offset_mapping",
    )

    def _batch_encode_plus(self, *args, **kwargs):
        """Encode a batch via the parent class, then drop trailing EOS tokens.

        Args:
            *args, **kwargs: forwarded unchanged to
                ``PreTrainedTokenizerFast._batch_encode_plus``.

        Returns:
            The parent's output without ``token_type_ids``, and with a trailing
            ``eos_token_id`` removed from each sequence that ends with one.
            List outputs are trimmed per row. Tensor/ndarray outputs are
            handled by ``_strip_trailing_eos_arrays``.

        Raises:
            TypeError: if ``input_ids`` is not a list, ``torch.Tensor`` or
                ``numpy.ndarray``.
        """
        outputs = super()._batch_encode_plus(*args, **kwargs)
        # ``token_type_ids`` may be absent (e.g. ``return_token_type_ids=False``);
        # ``pop`` with a default avoids the KeyError a bare ``del`` would raise.
        outputs.pop("token_type_ids", None)

        eos_id = self.eos_token_id
        input_ids = outputs.get("input_ids")
        if eos_id is None or input_ids is None:
            # Nothing to strip: no EOS token is defined, or no ids were returned.
            return outputs

        if isinstance(input_ids, list):
            self._strip_trailing_eos_lists(outputs, eos_id)
        elif isinstance(input_ids, (torch.Tensor, numpy.ndarray)):
            self._strip_trailing_eos_arrays(outputs, eos_id)
        else:
            raise TypeError(
                "ModernDecoderBERTTokenizer cannot strip EOS from input_ids of type "
                f"{type(input_ids).__name__}; use return_tensors=None, 'pt' or 'np'."
            )
        return outputs

    def _strip_trailing_eos_lists(self, outputs, eos_id: int) -> None:
        """Drop the last element of every list row whose final id is ``eos_id``.

        The same rows are trimmed in every token-aligned field, so the fields
        stay aligned; rows may end up with different lengths.
        """
        flags = [len(seq) > 0 and seq[-1] == eos_id for seq in outputs["input_ids"]]
        for key in self._TOKEN_ALIGNED_KEYS:
            if key in outputs:
                outputs[key] = [
                    seq[:-1] if flag else seq
                    for seq, flag in zip(outputs[key], flags)
                ]

    def _strip_trailing_eos_arrays(self, outputs, eos_id: int) -> None:
        """Remove a trailing ``eos_id`` from rectangular tensor/ndarray outputs.

        Rows of an array must share one length, so per-row trimming is emulated:

        * If every row ends with ``eos_id``, the last column is dropped from all
          token-aligned fields. This is identical to trimming each row.
        * If only some rows do, the final position of those rows is
          overwritten with ``pad_token_id`` (when one is defined) and its
          attention mask is set to 0. This matches trimming the row and
          right-padding it back to the batch width. Other token-aligned fields
          are left unchanged in that case.
        """
        input_ids = outputs["input_ids"]
        # Expects (batch, seq_len); any other shape is left untouched.
        if input_ids.ndim != 2 or input_ids.shape[1] == 0:
            return
        flags = input_ids[:, -1] == eos_id  # one bool per row

        if bool(flags.all()):
            for key in self._TOKEN_ALIGNED_KEYS:
                if key in outputs:
                    # Slice the sequence axis (axis 1), not the last axis:
                    # offset_mapping carries an extra trailing (start, end) axis.
                    outputs[key] = outputs[key][:, :-1]
            return
        if not bool(flags.any()):
            return

        def _copy(array):
            # Work on copies so arrays the parent may still reference stay intact.
            return array.clone() if isinstance(array, torch.Tensor) else array.copy()

        if self.pad_token_id is not None:
            ids = _copy(input_ids)
            ids[flags, -1] = self.pad_token_id
            outputs["input_ids"] = ids
        if "attention_mask" in outputs:
            mask = _copy(outputs["attention_mask"])
            mask[flags, -1] = 0
            outputs["attention_mask"] = mask
# Register ModernDecoderBERTTokenizer with AutoTokenizer at import time, as the
# fast tokenizer class.
# NOTE(review): in transformers, the first parameter of `AutoTokenizer.register`
# is `config_class` (the model's PretrainedConfig subclass used as the lookup
# key). This call passes the tokenizer class itself there, so lookups driven by
# a model config presumably won't find this entry. Confirm the intended config
# class, or whether
# `ModernDecoderBERTTokenizer.register_for_auto_class("AutoTokenizer")`
# (for `trust_remote_code` loading) was what was meant.
from transformers import AutoTokenizer
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)