oweller2 commited on
Commit
6e82f17
·
1 Parent(s): 7b38d2c
Files changed (1) hide show
  1. tokenizer.py +13 -5
tokenizer.py CHANGED
@@ -1,13 +1,21 @@
1
  from transformers import PreTrainedTokenizerFast
 
 
2
 
3
  class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
4
 
5
  def _batch_encode_plus(self, *args, **kwargs):
6
- outputs = super()._batch_encode_plus(*args, **kwargs)
7
- del outputs["token_type_ids"]
8
- for key in ['input_ids', 'attention_mask']:
9
- outputs[key] = [sequence[:-1] for sequence in outputs[key]]
10
- return outputs
 
 
 
 
 
 
11
 
12
  # Register the class
13
  from transformers import AutoTokenizer
 
1
  from transformers import PreTrainedTokenizerFast
2
+ import numpy
3
+ import torch
4
 
5
  class ModernDecoderBERTTokenizer(PreTrainedTokenizerFast):
6
 
7
  def _batch_encode_plus(self, *args, **kwargs):
8
+ outputs = super()._batch_encode_plus(*args, **kwargs)
9
+ del outputs["token_type_ids"]
10
+ for key in ['input_ids', 'attention_mask']:
11
+ if isinstance(outputs[key], (list, numpy.ndarray, torch.Tensor)):
12
+ if isinstance(outputs[key], list):
13
+ outputs[key] = [sequence[:-1] for sequence in outputs[key]]
14
+ elif isinstance(outputs[key], numpy.ndarray):
15
+ outputs[key] = numpy.array([sequence[:-1] for sequence in outputs[key]], dtype=outputs[key].dtype)
16
+ elif isinstance(outputs[key], torch.Tensor):
17
+ outputs[key] = torch.tensor([sequence[:-1] for sequence in outputs[key]], dtype=outputs[key].dtype, device=outputs[key].device)
18
+ return outputs
19
 
20
  # Register the class
21
  from transformers import AutoTokenizer