oweller2
commited on
Commit
·
6e82f17
1
Parent(s):
7b38d2c
dpone
Browse files- tokenizer.py +13 -5
tokenizer.py
CHANGED
@@ -1,13 +1,21 @@
|
|
1 |
from transformers import PreTrainedTokenizerFast
|
|
|
|
|
2 |
|
3 |
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
|
4 |
|
5 |
def _batch_encode_plus(self, *args, **kwargs):
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
# Register the class
|
13 |
from transformers import AutoTokenizer
|
|
|
1 |
from transformers import PreTrainedTokenizerFast
|
2 |
+
import numpy
|
3 |
+
import torch
|
4 |
|
5 |
class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
|
6 |
|
7 |
def _batch_encode_plus(self, *args, **kwargs):
|
8 |
+
outputs = super()._batch_encode_plus(*args, **kwargs)
|
9 |
+
del outputs["token_type_ids"]
|
10 |
+
for key in ['input_ids', 'attention_mask']:
|
11 |
+
if isinstance(outputs[key], (list, numpy.ndarray, torch.Tensor)):
|
12 |
+
if isinstance(outputs[key], list):
|
13 |
+
outputs[key] = [sequence[:-1] for sequence in outputs[key]]
|
14 |
+
elif isinstance(outputs[key], numpy.ndarray):
|
15 |
+
outputs[key] = numpy.array([sequence[:-1] for sequence in outputs[key]], dtype=outputs[key].dtype)
|
16 |
+
elif isinstance(outputs[key], torch.Tensor):
|
17 |
+
outputs[key] = torch.tensor([sequence[:-1] for sequence in outputs[key]], dtype=outputs[key].dtype, device=outputs[key].device)
|
18 |
+
return outputs
|
19 |
|
20 |
# Register the class
|
21 |
from transformers import AutoTokenizer
|