dqnguyen committed on
Commit
4c4cb27
1 Parent(s): b70c240

Upload tokenization_bertweet_fast.py

Files changed (1)
  1. tokenization_bertweet_fast.py +324 -0
tokenization_bertweet_fast.py ADDED
@@ -0,0 +1,324 @@
# coding=utf-8
# Copyright (c) 2020, VinAI Research and the HuggingFace Inc. team.
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for BERTweet."""

import os
from collections import defaultdict
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union

from transformers import BertweetTokenizer
from transformers.tokenization_utils_base import EncodingFast
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.txt",
    "merges_file": "bpe.codes",
    "tokenizer_file": "tokenizer.json",
}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "vinai/bertweet-base": "https://huggingface.co/vinai/bertweet-base/resolve/main/vocab.txt",
    },
    "merges_file": {
        "vinai/bertweet-base": "https://huggingface.co/vinai/bertweet-base/resolve/main/bpe.codes",
    },
    "tokenizer_file": {
        "vinai/bertweet-base": "https://huggingface.co/vinai/bertweet-base/resolve/main/tokenizer.json",
    },
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "vinai/bertweet-base": 128,
}


class BertweetTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "Fast" BPE tokenizer for BERTweet (backed by HuggingFace's *tokenizers* library).

    Peculiarities:

    - uses BERT's pre-tokenizer: BertPreTokenizer splits tokens on spaces and also on punctuation, so each
      occurrence of a punctuation character is treated as a separate token (sketched in the comment below
      `__init__`).

    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
    refer to that superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        merges_file (`str`):
            Path to the merges file.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]
    slow_tokenizer_class = BertweetTokenizer

    def __init__(
        self,
        vocab_file=None,
        merges_file=None,
        tokenizer_file=None,
        bos_token="<s>",
        eos_token="</s>",
        sep_token="</s>",
        cls_token="<s>",
        unk_token="<unk>",
        pad_token="<pad>",
        mask_token="<mask>",
        **kwargs
    ):
        super().__init__(
            vocab_file,
            merges_file,
            tokenizer_file=tokenizer_file,
            bos_token=bos_token,
            eos_token=eos_token,
            sep_token=sep_token,
            cls_token=cls_token,
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            **kwargs,
        )

        self.vocab_file = vocab_file
        self.merges_file = merges_file
        self.can_save_slow_tokenizer = False if not self.vocab_file else True
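
    # Illustrative sketch of the pre-tokenization behaviour described in the class docstring (assuming the
    # backing tokenizer.json does use BertPreTokenizer, as stated there): input is split on whitespace and every
    # punctuation character becomes a token of its own before BPE is applied, e.g.
    #
    #     "DHEC confirms, via Twitter:" -> ["DHEC", "confirms", ",", "via", "Twitter", ":"]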

    def get_added_vocab_hacking(self):
        """
        Returns the added tokens in the vocabulary as a dictionary of token to index.

        Returns:
            `Dict[str, int], Dict[int, int]`: The added tokens with their re-numbered ids, and a mapping from
            those new ids back to the backend tokenizer's original ids.
        """
        base_vocab_size = self._tokenizer.get_vocab_size(with_added_tokens=False)
        full_vocab_size = self._tokenizer.get_vocab_size(with_added_tokens=True)
        if full_vocab_size == base_vocab_size:
            return {}, {}

        # Tokens in added_vocab should have ids that are equal to or larger than the size of base_vocab
        added_vocab = dict(
            (self._tokenizer.id_to_token(index), index + 1 - base_vocab_size + self.mask_token_id)
            for index in range(base_vocab_size, full_vocab_size)
        )

        id_mapping = dict((index, self._tokenizer.token_to_id(tok)) for tok, index in added_vocab.items())

        return added_vocab, id_mapping
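
    # A minimal sketch of the renumbering above, with made-up sizes (the real ones come from the loaded
    # tokenizer.json): assume `mask_token_id == 64_000` and a backend base vocabulary of 64_010 entries. A token
    # added via `add_tokens` lands at backend index 64_010 and is re-exposed as 64_010 + 1 - 64_010 + 64_000
    # = 64_001, so added tokens are renumbered to start directly after the mask token:
    #
    #     added_vocab == {"new_token": 64_001}   # token -> public id
    #     id_mapping  == {64_001: 64_010}        # public id -> backend id, used by `_decode` below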

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = True,
        **kwargs
    ) -> str:
        self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)

        if isinstance(token_ids, int):
            token_ids = [token_ids]

        # Mapping ids into their original values
        _, id_mapping = self.get_added_vocab_hacking()
        if len(id_mapping) > 0:
            token_ids = [id_mapping[id] if id in id_mapping else id for id in token_ids]

        text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

        if clean_up_tokenization_spaces:
            clean_text = self.clean_up_tokenization(text)
            return clean_text
        else:
            return text
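
    # Decoding sketch, reusing the made-up ids from the `get_added_vocab_hacking` comment above:
    #
    #     >>> tokenizer.decode([tokenizer.cls_token_id, 64_001, tokenizer.sep_token_id])
    #
    # The public id 64_001 is first mapped back to its backend id 64_010 through `id_mapping`, and only then
    # handed to the backend tokenizer's `decode`.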

    def _convert_encoding(
        self,
        encoding: EncodingFast,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
    ) -> Tuple[Dict[str, Any], List[EncodingFast]]:
        """
        Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict and a
        list of encodings, taking care of building a batch from overflowing tokens.

        Overflowing tokens are converted to additional examples (like batches) so the output values of the dict
        are lists (overflows) of lists (tokens).

        Output shape: (overflows, sequence length)
        """
        if return_token_type_ids is None:
            return_token_type_ids = "token_type_ids" in self.model_input_names
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        if return_overflowing_tokens and encoding.overflowing is not None:
            encodings = [encoding] + encoding.overflowing
        else:
            encodings = [encoding]

        encoding_dict = defaultdict(list)
        added_vocab, _ = self.get_added_vocab_hacking()
        for e in encodings:
            # encoding_dict["input_ids"].append(e.ids)
            # Reassign ids of tokens due to the hacking strategy
            ids = []
            for id, token in zip(e.ids, e.tokens):
                if id <= self.mask_token_id:
                    ids.append(id)
                else:
                    if token.strip() in added_vocab:
                        ids.append(added_vocab[token.strip()])
                    else:
                        ids.append(self.unk_token_id)

            encoding_dict["input_ids"].append(ids)

            if return_token_type_ids:
                encoding_dict["token_type_ids"].append(e.type_ids)
            if return_attention_mask:
                encoding_dict["attention_mask"].append(e.attention_mask)
            if return_special_tokens_mask:
                encoding_dict["special_tokens_mask"].append(e.special_tokens_mask)
            if return_offsets_mapping:
                encoding_dict["offset_mapping"].append(e.offsets)
            if return_length:
                # encoding_dict["length"].append(len(e.ids))
                encoding_dict["length"].append(len(ids))

        return encoding_dict, encodings
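
    # Encoding-side sketch with the same made-up ids as in `get_added_vocab_hacking`: a backend id of 64_010 is
    # larger than `mask_token_id` (64_000), so it is replaced by its public id from `added_vocab` (64_001);
    # backend ids above the mask id whose tokens were never added fall back to `unk_token_id`; anything at or
    # below `mask_token_id` passes through unchanged.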

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by
        concatenating and adding special tokens. A BERTweet sequence has the following format:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s></s> B </s>`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """

        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
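
    # Sketch of the layout built above, shown with two placeholder ids per sequence:
    #
    #     >>> tokenizer.build_inputs_with_special_tokens([10, 11])
    #     # -> [cls_token_id, 10, 11, sep_token_id]                                       i.e. <s> X </s>
    #     >>> tokenizer.build_inputs_with_special_tokens([10, 11], [12, 13])
    #     # -> [cls_token_id, 10, 11, sep_token_id, sep_token_id, 12, 13, sep_token_id]   i.e. <s> A </s></s> B </s>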

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is None:
            return [1] + ([0] * len(token_ids_0)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
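
    # Mask sketch for a pair of two-token sequences without pre-existing special tokens, following the layout
    # above (1 marks a special token, 0 a sequence token):
    #
    #     >>> tokenizer.get_special_tokens_mask([10, 11], [12, 13])
    #     [1, 0, 0, 1, 1, 0, 0, 1]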

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. BERTweet
        does not make use of token type ids, therefore a list of zeros is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of zeros.
        """

        sep = [self.sep_token_id]
        cls = [self.cls_token_id]

        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
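
    # Token-type sketch: BERTweet ignores token type ids, so for a pair of two-token sequences the result is
    # simply eight zeros, one per position in `<s> A </s> </s> B </s>`:
    #
    #     >>> tokenizer.create_token_type_ids_from_sequences([10, 11], [12, 13])
    #     [0, 0, 0, 0, 0, 0, 0, 0]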

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory.")
            return

        out_vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        out_merges_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        if os.path.abspath(self.merges_file) != os.path.abspath(out_merges_file):
            copyfile(self.merges_file, out_merges_file)

        return (out_vocab_file, out_merges_file)
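

# A minimal usage sketch, under the assumption that the `vinai/bertweet-base` files referenced above are
# reachable (loading them needs network access, and the exact ids printed depend on that checkpoint):
if __name__ == "__main__":
    tokenizer = BertweetTokenizerFast.from_pretrained("vinai/bertweet-base")

    # Encode a single tweet; <s> ... </s> are added by `build_inputs_with_special_tokens`.
    encoded = tokenizer("SC has first two presumptive cases of coronavirus , DHEC confirms")
    print(encoded["input_ids"])

    # Round-trip back to text through `_decode`, which also undoes the added-token id renumbering, if any.
    print(tokenizer.decode(encoded["input_ids"]))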