# test-flex-gpt / tokenizer.py
# Scraped file-view header: oweller2 · dpone · revision 6e82f17 · 1.08 kB
from transformers import PreTrainedTokenizerFast
import numpy
import torch
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that drops the final position of every encoded sequence.

    After the parent fast tokenizer encodes a batch, this subclass:

    * removes ``token_type_ids`` from the output, if present, and
    * drops the last position of every per-token field listed in
      ``_PER_TOKEN_KEYS`` that is present, so those fields stay aligned.

    Presumably the dropped position is the trailing special token appended by
    the tokenizer's post-processor (e.g. SEP/EOS), so a decoder-style model
    never receives it as input — TODO confirm against the tokenizer config.
    """

    # Output fields holding one entry per token position. They must all be
    # shortened together; fields absent from a given output are skipped.
    _PER_TOKEN_KEYS = (
        "input_ids",
        "attention_mask",
        "special_tokens_mask",
        "offset_mapping",
    )

    @staticmethod
    def _drop_last_position(value):
        """Return *value* with the last token position removed from each row.

        A list is treated as a batch of per-sequence lists, and each inner
        list loses its last element. NumPy arrays and torch tensors are sliced
        along axis 1 (the token axis of a batched output) into fresh
        contiguous storage. Any other type is returned unchanged.
        """
        if isinstance(value, list):
            return [sequence[:-1] for sequence in value]
        if isinstance(value, numpy.ndarray):
            # .copy() yields a new C-contiguous array with the same dtype.
            return value[:, :-1].copy()
        if isinstance(value, torch.Tensor):
            # Slicing keeps dtype and device. Rebuilding the tensor with
            # torch.tensor() from a list of row tensors raises ValueError.
            return value[:, :-1].contiguous()
        return value

    def _batch_encode_plus(self, *args, **kwargs):
        """Encode a batch through the parent class, then post-process it.

        All arguments are forwarded unchanged to
        ``PreTrainedTokenizerFast._batch_encode_plus``.

        Returns:
            The parent's output with ``token_type_ids`` removed and the last
            position dropped from every present per-token field.
        """
        outputs = super()._batch_encode_plus(*args, **kwargs)
        # pop() with a default instead of del: the key is missing when the
        # caller passes return_token_type_ids=False.
        outputs.pop("token_type_ids", None)
        for key in self._PER_TOKEN_KEYS:
            # Skip fields that were not requested (e.g. when
            # return_attention_mask=False) instead of raising KeyError.
            if key in outputs:
                outputs[key] = self._drop_last_position(outputs[key])
        # NOTE(review): every row loses its last position regardless of its
        # content. If the batch is right-padded, that position is padding for
        # shorter rows, so their trailing special token is kept — TODO confirm
        # this is intended.
        return outputs
# Add ModernDecoderBERTTokenizer to AutoTokenizer's tokenizer mapping as a
# fast tokenizer (no slow counterpart).
# NOTE(review): the first argument of AutoTokenizer.register is the *config*
# class used as the lookup key; passing the tokenizer class itself means the
# entry is presumably never matched by a model config — TODO confirm which
# config class was intended.
from transformers import AutoTokenizer

AutoTokenizer.register(
    ModernDecoderBERTTokenizer,
    slow_tokenizer_class=None,
    fast_tokenizer_class=ModernDecoderBERTTokenizer,
)