# NOTE: removed non-code extraction artifacts that preceded this file
# (file-size line, commit hashes, and a line-number gutter from the hosting viewer).
from transformers import PreTrainedTokenizerFast
import numpy
import torch
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that strips a trailing EOS token from each encoded sequence.

    Wraps ``PreTrainedTokenizerFast._batch_encode_plus`` and post-processes its
    output: ``token_type_ids`` is dropped, and any sequence whose last token is
    ``eos_token_id`` loses that final token. In a mixed batch (tensor/ndarray
    outputs, where all rows must keep equal length) the shortened rows are
    re-aligned by shifting them right and filling position 0 with ``0``.
    """

    def _batch_encode_plus(self, *args, **kwargs):
        """Encode a batch, then remove trailing EOS tokens.

        Handles the three container types the parent may return for
        ``input_ids`` — ``torch.Tensor``, ``numpy.ndarray``, or a plain list of
        lists — and adjusts ``input_ids`` and ``attention_mask`` in step.

        Returns:
            The parent's ``BatchEncoding`` with ``token_type_ids`` removed and
            EOS-trailing sequences truncated/re-aligned.
        """
        outputs = super()._batch_encode_plus(*args, **kwargs)
        # BUG FIX: the original `del outputs["token_type_ids"]` raised KeyError
        # whenever the parent was called with return_token_type_ids=False (or
        # the backend emitted none). pop with a default is a safe no-op then.
        outputs.pop("token_type_ids", None)

        input_ids = outputs['input_ids']

        def ends_with_eos(sequence):
            # An empty sequence cannot end with EOS.
            if len(sequence) == 0:
                return False
            return sequence[-1] == self.eos_token_id

        # NOTE(review): only input_ids and attention_mask are re-aligned below.
        # If the parent also returned length-coupled keys (special_tokens_mask,
        # offset_mapping, ...), they keep the old length — confirm callers never
        # request those alongside this truncation.
        if isinstance(input_ids, torch.Tensor):
            last_token_is_eos = torch.tensor(
                [ends_with_eos(seq) for seq in input_ids], dtype=torch.bool
            )
            if last_token_is_eos.all():
                # Every row ends with EOS: drop the final column wholesale.
                for key in ['input_ids', 'attention_mask']:
                    outputs[key] = outputs[key][..., :-1]
            elif last_token_is_eos.any():
                # Mixed batch: shift each EOS-ending row right by one, filling
                # position 0 with 0 so all rows stay the same length.
                # NOTE(review): this assumes 0 is the pad id for input_ids (and
                # the "masked" value for attention_mask) — verify pad_token_id.
                for i in range(input_ids.shape[0]):
                    if last_token_is_eos[i]:
                        for key in ['input_ids', 'attention_mask']:
                            truncated = outputs[key][i, :-1]
                            outputs[key][i] = torch.cat([
                                torch.zeros_like(truncated[:1]),
                                truncated,
                            ])
        elif isinstance(input_ids, numpy.ndarray):
            last_token_is_eos = numpy.array(
                [ends_with_eos(seq) for seq in input_ids], dtype=bool
            )
            if last_token_is_eos.all():
                # Every row ends with EOS: drop the final column wholesale.
                for key in ['input_ids', 'attention_mask']:
                    outputs[key] = outputs[key][..., :-1]
            elif last_token_is_eos.any():
                # Mixed batch: same right-shift-and-zero-fill as the torch path.
                for i in range(input_ids.shape[0]):
                    if last_token_is_eos[i]:
                        for key in ['input_ids', 'attention_mask']:
                            truncated = outputs[key][i, :-1]
                            outputs[key][i] = numpy.concatenate([
                                numpy.zeros_like(truncated[:1]),
                                truncated,
                            ])
        elif isinstance(input_ids, list):
            last_token_is_eos = [ends_with_eos(seq) for seq in input_ids]
            if all(last_token_is_eos):
                # Lists need no re-alignment: just drop the last element.
                for key in ['input_ids', 'attention_mask']:
                    outputs[key] = [sequence[:-1] for sequence in outputs[key]]
            elif any(last_token_is_eos):
                # Keep list lengths equal by left-padding the shortened rows
                # with 0, mirroring the tensor branches.
                for key in ['input_ids', 'attention_mask']:
                    outputs[key] = [
                        [0] + sequence[:-1] if is_eos else sequence
                        for sequence, is_eos in zip(outputs[key], last_token_is_eos)
                    ]
        return outputs
# Register the class with the Auto* factory so AutoTokenizer.from_pretrained
# can resolve it.
# NOTE(review): AutoTokenizer.register conventionally takes a *config* class
# as its first positional argument; here the tokenizer class itself is used as
# the registry key — confirm this matches how the checkpoint's config resolves.
from transformers import AutoTokenizer
AutoTokenizer.register(ModernDecoderBERTTokenizer, fast_tokenizer_class=ModernDecoderBERTTokenizer)