BrightXiaoHan committed
Commit 46904af
1 parent: 49bdfe5

upload tokenizer

Files changed (4)
  1. data_utils.py +319 -0
  2. special_tokens_map.json +1 -11
  3. tokenizers_pegasus.py +598 -0
  4. vocab.txt +5 -5
data_utils.py ADDED
@@ -0,0 +1,319 @@
+ # -*- coding: utf-8 -*-
+
+ import re
+ import six
+ import unicodedata
+ import torch
+ import rouge
+ import numpy as np
+ import random
+ # from fengshen.examples.pegasus.pegasus_utils import text_segmentate
+ import sys
+
+ sys.path.append('../../../')
+
+ rouge = rouge.Rouge()
+
+ is_py2 = six.PY2
+
+ if not is_py2:
+     basestring = str
+
+
+ def _is_chinese_char(cp):
+     """Checks whether CP is the codepoint of a CJK character."""
+     # This defines a "chinese character" as anything in the CJK Unicode block:
+     #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+     #
+     # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+     # despite its name. The modern Korean Hangul alphabet is a different block,
+     # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+     # space-separated words, so they are not treated specially and handled
+     # like all of the other languages.
+     if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF)
+             or (cp >= 0x20000 and cp <= 0x2A6DF)
+             or (cp >= 0x2A700 and cp <= 0x2B73F)
+             or (cp >= 0x2B740 and cp <= 0x2B81F)
+             or (cp >= 0x2B820 and cp <= 0x2CEAF)
+             or (cp >= 0xF900 and cp <= 0xFAFF)
+             or (cp >= 0x2F800 and cp <= 0x2FA1F)):
+         return True
+
+     return False
+
+
+ def _is_whitespace(char):
+     """Checks whether `char` is a whitespace character."""
+     # \t, \n, and \r are technically control characters but we treat them
+     # as whitespace since they are generally considered as such.
+     if char == " " or char == "\t" or char == "\n" or char == "\r":
+         return True
+     cat = unicodedata.category(char)
+     if cat == "Zs":
+         return True
+     return False
+
+
+ def _is_control(char):
+     """Checks whether `char` is a control character."""
+     # These are technically control characters but we count them as whitespace
+     # characters.
+     if char == "\t" or char == "\n" or char == "\r":
+         return False
+     cat = unicodedata.category(char)
+     if cat.startswith("C"):
+         return True
+     return False
+
+
+ def _is_punctuation(char):
+     """Checks whether `char` is a punctuation character."""
+     cp = ord(char)
+     # We treat all non-letter/number ASCII as punctuation.
+     # Characters such as "^", "$", and "`" are not in the Unicode
+     # Punctuation class but we treat them as punctuation anyways, for
+     # consistency.
+     if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (
+             cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126):
+         return True
+     cat = unicodedata.category(char)
+     if cat.startswith("P"):
+         return True
+     return False
+
+
+ def is_string(s):
+     """Check whether `s` is a string."""
+     return isinstance(s, basestring)
+
+
+ def is_stopwords(word, stopwords):
+     if word in stopwords:
+         return True
+     else:
+         return False
+
+
+ def text_segmentate(text):
+     """Split `text` into sentences on English and Chinese sentence delimiters."""
+     en_seg_pattern = '((?:\\!|\\?|\\.|\\n)+(?:\\s)+)'
+     ch_seg_pattern = '((?:？|！|。|\\n)+)'
+     try:
+         text = re.sub(en_seg_pattern, r'\1[SEP]', text)
+         # print("sub text: ", text)
+     except Exception as e:
+         print("input: ", text)
+         raise e
+     text = re.sub(ch_seg_pattern, r'\1[SEP]', text)
+     # print("sub ch text: ", text)
+     text_list = text.split("[SEP]")
+     text_list = list(filter(lambda x: len(x) != 0, text_list))
+     return text_list
+
+
+ def load_stopwords(stopwords_path):
+     stopwords_dict = {}
+     with open(stopwords_path, "r") as rf:
+         for line in rf:
+             line = line.strip()
+             if line not in stopwords_dict:
+                 stopwords_dict[line] = 0
+             else:
+                 pass
+     return stopwords_dict
+
+
+ def text_process(text, max_length):
+     """Split text into chunks of at least three sentences."""
+     texts = text_segmentate(text)
+
+     result, length = [], 0
+     for text in texts:
+         if length + len(text) > max_length * 1.3 and len(result) >= 3:
+             yield result
+             result, length = [], 0
+         result.append(text)
+         length += len(text)
+     if result and len(result) >= 3:
+         yield result
+
+
+ def text_process_split_long_content(text, max_length):
+     """Split long text into chunks, dropping single sentences longer than 500 characters."""
+     texts = text_segmentate(text)
+
+     result, sentence_num = "", 0
+     for text in texts:
+         if len(text) > 500:
+             if len(result) > 300 and sentence_num >= 3:
+                 yield result
+                 result, sentence_num = "", 0
+             else:
+                 result, sentence_num = "", 0
+                 continue
+         else:
+             if len(result) + len(text) > max_length * 1.1 and sentence_num >= 3:
+                 yield result
+                 result, sentence_num = "", 0
+             result += text
+             sentence_num += 1
+
+     if result and sentence_num >= 3:
+         yield result
+
+
+ def gather_join(texts, idxs):
+     """Gather the texts at the given indices and join them into one string."""
+     return ''.join([texts[i] for i in idxs])
+
+
+ def gather_join_f1(texts_token, idsx):
+     join_texts = []
+     for id in idsx:
+         join_texts.extend(texts_token[id])
+     return join_texts
+
+
+ def compute_rouge(source, target):
+     """Compute rouge-1, rouge-2 and rouge-l F1 scores."""
+     source, target = ' '.join(source), ' '.join(target)
+     try:
+         scores = rouge.get_scores(hyps=source, refs=target)
+         return {
+             'rouge-1': scores[0]['rouge-1']['f'],
+             'rouge-2': scores[0]['rouge-2']['f'],
+             'rouge-l': scores[0]['rouge-l']['f'],
+         }
+     except ValueError:
+         return {
+             'rouge-1': 0.0,
+             'rouge-2': 0.0,
+             'rouge-l': 0.0,
+         }
+
+
+ def remove_stopwords(texts, stopwords_dict):
+     for i, text in enumerate(texts):
+         texts[i] = list(filter(lambda x: x not in stopwords_dict, text))
+     return texts
+
+
+ def pseudo_summary_f1(texts,
+                       stopwords,
+                       tokenizer,
+                       max_length,
+                       rouge_strategy="rouge-l"):
+     """Build a pseudo-summary dataset by greedily selecting the sentences that maximize ROUGE."""
+     summary_rate = 0.25
+     max_length = max_length - 1
+     texts_tokens = []
+     sentece_idxs_vec = []
+     for text in texts:
+         if len(text) == 0:
+             continue
+         try:
+             ids = tokenizer.encode(text.strip())[:-1]
+         except ValueError:
+             print("error, input : ", text)
+             raise ValueError
+         sentece_idxs_vec.append(ids)
+         tokens = [tokenizer._convert_id_to_token(token) for token in ids]
+         texts_tokens.append(tokens)
+
+     texts_tokens_rm = remove_stopwords(texts_tokens, stopwords)
+     source_idxs, target_idxs = list(range(len(texts))), []
+
+     assert len(texts_tokens) == len(texts)
+     # truncate_index = 0
+     while True:
+         sims = []
+         for i in source_idxs:
+             new_source_idxs = [j for j in source_idxs if j != i]
+             new_target_idxs = sorted(target_idxs + [i])
+             new_source = gather_join_f1(texts_tokens_rm, new_source_idxs)
+             new_target = gather_join_f1(texts_tokens_rm, new_target_idxs)
+             sim = compute_rouge(new_source, new_target)[rouge_strategy]
+             sims.append(sim)
+         new_idx = source_idxs[np.argmax(sims)]
+         del sims
+         source_idxs.remove(new_idx)
+         target_idxs = sorted(target_idxs + [new_idx])
+         source = gather_join(texts, source_idxs)
+         target = gather_join(texts, target_idxs)
+         try:
+             if (len(source_idxs) == 1
+                     or 1.0 * len(target) / len(source) > summary_rate):
+                 break
+         except ZeroDivisionError as e:
+             print(e)
+             print(texts)
+             print("source: ", source)
+             print("target: ", target)
+
+     if len(source) < len(target):
+         source, target = target, source
+         source_idxs, target_idxs = target_idxs, source_idxs
+
+     return sentece_idxs_vec, source, target, source_idxs, target_idxs
+
+
+ def get_input_mask(sentence_id_vec, indexs):
+     target_idxs = []
+     input_idxs = []
+     kMaskSentenceTokenId = 2
+     kEosTokenId = 1
+     mask_sentence_options_cumulative_prob = [0.9, 0.9, 1, 1]
+     for index in indexs:
+         target_idxs.extend(sentence_id_vec[index])
+         choice = random.uniform(0, 1)
+         if choice < mask_sentence_options_cumulative_prob[0]:
+             # mask the whole sentence with the sentence-mask token
+             # print("mask index: ", index)
+             sentence_id_vec[index] = [kMaskSentenceTokenId]
+         elif choice < mask_sentence_options_cumulative_prob[1]:
+             # replace the sentence with another, randomly chosen sentence
+             # (randint is inclusive on both ends, hence len - 1)
+             # print("replace index: ", index)
+             replace_id = random.randint(0, len(sentence_id_vec) - 1)
+             sentence_id_vec[index] = sentence_id_vec[replace_id]
+         elif choice < mask_sentence_options_cumulative_prob[2]:
+             pass
+         else:
+             sentence_id_vec[index] = []
+
+     target_idxs.append(kEosTokenId)
+     # print(sentence_id_vec)
+     for index, sentence_id in enumerate(sentence_id_vec):
+         # print(index, sentence_id)
+         if len(sentence_id) == 0:
+             continue
+         input_idxs.extend(sentence_id_vec[index])
+
+     input_idxs.append(kEosTokenId)
+     return input_idxs, target_idxs
+
+
+ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int,
+                        decoder_start_token_id: int):
+     """
+     Shift input ids one token to the right.
+     """
+     shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+     shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+     shifted_input_ids[:, 0] = decoder_start_token_id
+
+     if pad_token_id is None:
+         raise ValueError("self.model.config.pad_token_id has to be defined.")
+     # replace possible -100 values in labels by `pad_token_id`
+     shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+     return shifted_input_ids
+
+
+ def padding_to_maxlength(ids, max_length, pad_id):
+     cur_len = len(ids)
+     len_diff = max_length - cur_len
+     return ids + [pad_id] * len_diff, [1] * cur_len + [0] * len_diff
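
A minimal usage sketch for the two helpers at the end of data_utils.py. The id values, tensor shape and pad/start ids below are illustrative assumptions, not part of this commit:

    import torch
    from data_utils import padding_to_maxlength, shift_tokens_right

    # pad a token-id list to a fixed length and build the matching attention mask
    ids, attention_mask = padding_to_maxlength([5, 6, 7], max_length=6, pad_id=0)
    # ids == [5, 6, 7, 0, 0, 0]; attention_mask == [1, 1, 1, 0, 0, 0]

    # build decoder inputs by shifting the labels one position to the right;
    # -100 label placeholders are replaced by the pad id after the shift
    labels = torch.tensor([[5, 6, 7, 1, -100, -100]])
    decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=0)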
special_tokens_map.json CHANGED
@@ -1,11 +1 @@
- {
-   "additional_special_tokens": [
-     "<mask_1>"
-   ],
-   "cls_token": "[CLS]",
-   "eos_token": "</s>",
-   "mask_token": "<mask_2>",
-   "pad_token": "<pad>",
-   "sep_token": "[SEP]",
-   "unk_token": "<unk>"
- }
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}
tokenizers_pegasus.py ADDED
@@ -0,0 +1,598 @@
+ from data_utils import (
+     _is_control,
+     _is_punctuation,
+     _is_whitespace,
+     _is_chinese_char)
+ from transformers import PreTrainedTokenizer
+ from transformers import logging
+ from typing import List, Optional, Tuple, Union
+ import collections
+ import os
+ import unicodedata
+ import re
+ import jieba
+ import sys
+
+ sys.path.append("../../../../")
+
+ jieba.dt.tmp_dir = os.path.expanduser("~/.cache/")
+ # jieba.enable_parallel(8)
+ jieba.initialize()
+
+ logger = logging.get_logger(__name__)
+
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
+
+
+ def load_vocab(vocab_file):
+     """Loads a vocabulary file into a dictionary."""
+     vocab = collections.OrderedDict()
+     with open(vocab_file, "r", encoding="utf-8") as reader:
+         tokens = reader.readlines()
+     for index, token in enumerate(tokens):
+         token = token.rstrip("\n")
+         vocab[token] = index
+     return vocab
+
+
+ def whitespace_tokenize(text):
+     """Runs basic whitespace cleaning and splitting on a piece of text."""
+     text = text.strip()
+     if not text:
+         return []
+     tokens = text.split()
+     return tokens
+
+
+ class PegasusTokenizer(PreTrainedTokenizer):
+     # copied from BertTokenizer
+     r"""
+     Construct a Pegasus tokenizer. Based on WordPiece.
+     This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
+     this superclass for more information regarding those methods.
+     Args:
+         vocab_file (`str`):
+             File containing the vocabulary.
+         do_lower_case (`bool`, *optional*, defaults to `True`):
+             Whether or not to lowercase the input when tokenizing.
+         do_basic_tokenize (`bool`, *optional*, defaults to `True`):
+             Whether or not to do basic tokenization before WordPiece.
+         never_split (`Iterable`, *optional*):
+             Collection of tokens which will never be split during tokenization. Only has an effect when
+             `do_basic_tokenize=True`
+         unk_token (`str`, *optional*, defaults to `"[UNK]"`):
+             The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+             token instead.
+         sep_token (`str`, *optional*, defaults to `"[SEP]"`):
+             The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+             sequence classification or for a text and a question for question answering. It is also used as the last
+             token of a sequence built with special tokens.
+         pad_token (`str`, *optional*, defaults to `"[PAD]"`):
+             The token used for padding, for example when batching sequences of different lengths.
+         cls_token (`str`, *optional*, defaults to `"[CLS]"`):
+             The classifier token which is used when doing sequence classification (classification of the whole sequence
+             instead of per-token classification). It is the first token of the sequence when built with special tokens.
+         mask_token (`str`, *optional*, defaults to `"[MASK]"`):
+             The token used for masking values. This is the token used when training this model with masked language
+             modeling. This is the token which the model will try to predict.
+         tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+             Whether or not to tokenize Chinese characters.
+             This should likely be deactivated for Japanese (see this
+             [issue](https://github.com/huggingface/transformers/issues/328)).
+         strip_accents (`bool`, *optional*):
+             Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+             value for `lowercase` (as in the original BERT).
+     """
+
+     vocab_files_names = VOCAB_FILES_NAMES
+     model_input_names = ["input_ids", "attention_mask"]
+
+     # pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+     # pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+     # max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
+
+     def __init__(self,
+                  vocab_file,
+                  do_lower_case=True,
+                  do_basic_tokenize=True,
+                  never_split=None,
+                  pad_token="<pad>",
+                  eos_token="</s>",
+                  unk_token="<unk>",
+                  mask_token="<mask_2>",
+                  mask_token_sent="<mask_1>",
+                  additional_special_tokens=None,
+                  sep_token="[SEP]",
+                  cls_token="[CLS]",
+                  tokenize_chinese_chars=True,
+                  strip_accents=None,
+                  offset=100,
+                  pre_tokenizer=lambda x: jieba.cut(x, HMM=False),
+                  **kwargs):
+         self.offset = offset
+
+         if additional_special_tokens is not None:
+             if not isinstance(additional_special_tokens, list):
+                 raise TypeError(
+                     f"additional_special_tokens should be of type {type(list)}, "
+                     f"but is {type(additional_special_tokens)}")
+
+             additional_special_tokens_extended = (
+                 ([mask_token_sent] + additional_special_tokens)
+                 if mask_token_sent not in additional_special_tokens
+                 and mask_token_sent is not None else additional_special_tokens)
+
+             # fill additional tokens with ..., <unk_token_102> in case not all additional tokens are already taken
+             additional_special_tokens_extended += [
+                 f"<unk_{i}>" for i in range(
+                     len(additional_special_tokens_extended), self.offset - 1)
+             ]
+
+             if len(set(additional_special_tokens_extended)) != len(
+                     additional_special_tokens_extended):
+                 raise ValueError(
+                     "Please make sure that the provided additional_special_tokens "
+                     "do not contain an incorrectly shifted list of <unk_x> tokens. "
+                     f"Found {additional_special_tokens_extended}.")
+             additional_special_tokens = additional_special_tokens_extended
+         else:
+             additional_special_tokens = [
+                 mask_token_sent
+             ] if mask_token_sent is not None else []
+             # additional_special_tokens += [f"<unk_{i}>" for i in range(3, self.offset)]
+
+         # print("additional_special_tokens: ", additional_special_tokens)
+
+         if not os.path.isfile(vocab_file):
+             raise ValueError(
+                 f"Can't find a vocabulary file at path '{vocab_file}'. "
+                 "To load the vocabulary from a Google pretrained model use "
+                 "`tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`")
+
+         super().__init__(
+             do_lower_case=do_lower_case,
+             do_basic_tokenize=do_basic_tokenize,
+             never_split=never_split,
+             unk_token=unk_token,
+             sep_token=sep_token,
+             pad_token=pad_token,
+             cls_token=cls_token,
+             mask_token=mask_token,
+             eos_token=eos_token,
+             tokenize_chinese_chars=tokenize_chinese_chars,
+             additional_special_tokens=additional_special_tokens,
+             strip_accents=strip_accents,
+             **kwargs,
+         )
+
+         self.pre_tokenizer = pre_tokenizer
+         self.mask_token_sent = mask_token_sent
+         self.vocab = load_vocab(vocab_file)
+
+         # remap the vocab's placeholder tokens to the Pegasus special tokens
+         self.vocab[self.eos_token] = self.vocab.pop("[unused1]")
+         # self.vocab[self.eos_token] = self.vocab.pop("[unused2]")
+         self.vocab[self.pad_token] = self.vocab.pop("[PAD]")
+         self.vocab[self.unk_token] = self.vocab.pop("[UNK]")
+
+         if self.mask_token_sent is not None:
+             self.vocab[self.mask_token] = self.vocab.pop("[unused3]")
+             self.vocab[self.mask_token_sent] = self.vocab.pop("[unused2]")
+
+         self.ids_to_tokens = collections.OrderedDict([
+             (ids, tok) for tok, ids in self.vocab.items()
+         ])
+         self.do_basic_tokenize = do_basic_tokenize
+         if do_basic_tokenize:
+             self.basic_tokenizer = BasicTokenizer(
+                 do_lower_case=do_lower_case,
+                 never_split=never_split,
+                 tokenize_chinese_chars=tokenize_chinese_chars,
+                 strip_accents=strip_accents,
+             )
+         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab,
+                                                       unk_token=self.unk_token)
+
+     @property
+     def do_lower_case(self):
+         return self.basic_tokenizer.do_lower_case
+
+     @property
+     def vocab_size(self):
+         return len(self.vocab)
+
+     def get_vocab(self):
+         return dict(self.vocab, **self.added_tokens_encoder)
+
+     def _tokenize(self, text):
+         split_tokens = []
+         # print("pegasus_tokenizer: ", text)
+         for text in self.pre_tokenizer(text):
+             if text in self.vocab:
+                 split_tokens.append(text)
+             else:
+                 if self.do_basic_tokenize:
+                     for token in self.basic_tokenizer.tokenize(
+                             text, never_split=self.all_special_tokens):
+
+                         # If the token is part of the never_split set
+                         if token in self.basic_tokenizer.never_split:
+                             split_tokens.append(token)
+                         else:
+                             split_tokens += self.wordpiece_tokenizer.tokenize(
+                                 token)
+                 else:
+                     split_tokens += self.wordpiece_tokenizer.tokenize(text)
+         return split_tokens
+
+     def _convert_token_to_id(self, token):
+         """Converts a token (str) to an id using the vocab."""
+         return self.vocab.get(token, self.vocab.get(self.unk_token))
+
+     def _convert_id_to_token(self, index):
+         """Converts an index (integer) to a token (str) using the vocab."""
+         return self.ids_to_tokens.get(index, self.unk_token)
+
+     @staticmethod
+     def _cjk_punctuation():
+         return (u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d'
+                 u'\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62'
+                 u'\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014'
+                 u'\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014'
+                 u'\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002')
+
+     def convert_ids_to_tokens(
+             self,
+             ids: Union[int, List[int]],
+             skip_special_tokens: bool = False) -> Union[str, List[str]]:
+         """
+         Converts a single index or a sequence of indices into a token or a sequence of tokens, using the vocabulary
+         and added tokens.
+         Args:
+             ids (`int` or `List[int]`):
+                 The token id (or token ids) to convert to tokens.
+             skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not to remove special tokens in the decoding.
+         Returns:
+             `str` or `List[str]`: The decoded token(s).
+         """
+         if isinstance(ids, int):
+             if ids in self.added_tokens_decoder:
+                 return self.added_tokens_decoder[ids]
+             else:
+                 return self._convert_id_to_token(ids)
+         tokens = []
+         for index in ids:
+             index = int(index)
+             # id 2 (the sentence-mask token after remapping) is kept even when skipping special tokens
+             if skip_special_tokens and index in self.all_special_ids and index != 2:
+                 continue
+             if index in self.added_tokens_decoder:
+                 tokens.append(self.added_tokens_decoder[index])
+             else:
+                 tokens.append(self._convert_id_to_token(index))
+         return tokens
+
+     def convert_tokens_to_string(self, tokens):
+         """Converts a sequence of tokens (string) into a single string."""
+         # tokens = tokens or self.ids_to_tokens(ids)
+         # tokens = [token for token in tokens if not self._is_special(token)]
+
+         text = ''
+         for i, token in enumerate(tokens):
+             if token[:2] == '##':
+                 text += token[2:]
+             elif len(token) == 1 and _is_chinese_char(ord(token)):
+                 text += token
+             elif len(token) == 1 and _is_punctuation(token):
+                 text += token
+                 text += ' '
+             elif i > 0 and _is_chinese_char(ord(text[-1])):
+                 text += token
+             elif token == "</s>":
+                 continue
+             else:
+                 text += ' '
+                 text += token
+
+         text = re.sub(' +', ' ', text)
+         text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text)
+         punctuation = re.sub(' +', '', self._cjk_punctuation()).strip() + '+-/={(<['
+         punctuation_regex = '|'.join([re.escape(p) for p in punctuation])
+         punctuation_regex = '(%s) ' % punctuation_regex
+         text = re.sub(punctuation_regex, '\\1', text)
+         text = re.sub(r'(\d\.) (\d)', '\\1\\2', text)
+
+         return text.strip()
+         # out_string = " ".join(tokens).replace(" ##", "").strip()
+
+     def build_inputs_with_special_tokens(
+             self,
+             token_ids_0: List[int],
+             token_ids_1: Optional[List[int]] = None) -> List[int]:
+         """
+         Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
+         and adding special tokens. A PEGASUS sequence has the following format, where `X` represents the sequence:
+         - single sequence: `X </s>`
+         - pair of sequences: `A B </s>` (not intended use)
+         BOS is never used. Pairs of sequences are not the expected use case, but they will be handled without a
+         separator.
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs to which the special tokens will be added.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+         Returns:
+             `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+         """
+         if token_ids_1 is None:
+             return token_ids_0 + [self.eos_token_id]
+         return token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+     def _special_token_mask(self, seq):
+         all_special_ids = set(
+             self.all_special_ids)  # call it once instead of inside list comp
+         # all_special_ids.remove(self.unk_token_id)  # <unk> is only sometimes special
+
+         return [1 if x in all_special_ids else 0 for x in seq]
+
+     def get_special_tokens_mask(
+             self,
+             token_ids_0: List[int],
+             token_ids_1: Optional[List[int]] = None,
+             already_has_special_tokens: bool = False) -> List[int]:
+         """
+         Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+         special tokens using the tokenizer `prepare_for_model` method.
+         Args:
+             token_ids_0 (`List[int]`):
+                 List of IDs.
+             token_ids_1 (`List[int]`, *optional*):
+                 Optional second list of IDs for sequence pairs.
+             already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+                 Whether or not the token list is already formatted with special tokens for the model.
+         Returns:
+             `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+         """
+
+         if already_has_special_tokens:
+             return self._special_token_mask(token_ids_0)
+         elif token_ids_1 is None:
+             # the trailing 1 accounts for the EOS token appended by build_inputs_with_special_tokens
+             return self._special_token_mask(token_ids_0) + [1]
+         else:
+             return self._special_token_mask(token_ids_0 + token_ids_1) + [1]
+
+     def num_special_tokens_to_add(self, pair=False):
+         """Just EOS"""
+         return 1
+
+     def save_vocabulary(self,
+                         save_directory: str,
+                         filename_prefix: Optional[str] = None) -> Tuple[str]:
+         index = 0
+         if os.path.isdir(save_directory):
+             vocab_file = os.path.join(
+                 save_directory,
+                 (filename_prefix + "-" if filename_prefix else "") +
+                 VOCAB_FILES_NAMES["vocab_file"])
+         else:
+             vocab_file = (filename_prefix +
+                           "-" if filename_prefix else "") + save_directory
+         with open(vocab_file, "w", encoding="utf-8") as writer:
+             for token, token_index in sorted(self.vocab.items(),
+                                              key=lambda kv: kv[1]):
+                 if index != token_index:
+                     logger.warning(
+                         f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
+                         " Please check that the vocabulary is not corrupted!")
+                     index = token_index
+                 writer.write(token + "\n")
+                 index += 1
+         return (vocab_file, )
+
+
+ class BasicTokenizer(object):
+     """
+     Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
+     Args:
+         do_lower_case (`bool`, *optional*, defaults to `True`):
+             Whether or not to lowercase the input when tokenizing.
+         never_split (`Iterable`, *optional*):
+             Collection of tokens which will never be split during tokenization. Only has an effect when
+             `do_basic_tokenize=True`
+         tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
+             Whether or not to tokenize Chinese characters.
+             This should likely be deactivated for Japanese (see this
+             [issue](https://github.com/huggingface/transformers/issues/328)).
+         strip_accents (`bool`, *optional*):
+             Whether or not to strip all accents. If this option is not specified, then it will be determined by the
+             value for `lowercase` (as in the original BERT).
+     """
+
+     def __init__(self,
+                  do_lower_case=True,
+                  never_split=None,
+                  tokenize_chinese_chars=True,
+                  strip_accents=None):
+         if never_split is None:
+             never_split = []
+         self.do_lower_case = do_lower_case
+         self.never_split = set(never_split)
+         self.tokenize_chinese_chars = tokenize_chinese_chars
+         self.strip_accents = strip_accents
+
+     def tokenize(self, text, never_split=None):
+         """
+         Basic Tokenization of a piece of text. Split on "white spaces" only; for sub-word tokenization, see
+         WordPieceTokenizer.
+         Args:
+             never_split (`List[str]`, *optional*):
+                 Kept for backward compatibility purposes. Now implemented directly at the base class level (see
+                 [`PreTrainedTokenizer.tokenize`]). List of tokens not to split.
+         """
+         # union() returns a new set by concatenating the two sets.
+         never_split = self.never_split.union(
+             set(never_split)) if never_split else self.never_split
+         text = self._clean_text(text)
+
+         # This was added on November 1st, 2018 for the multilingual and Chinese
+         # models. This is also applied to the English models now, but it doesn't
+         # matter since the English models were not trained on any Chinese data
+         # and generally don't have any Chinese data in them (there are Chinese
+         # characters in the vocabulary because Wikipedia does have some Chinese
+         # words in the English Wikipedia.).
+         if self.tokenize_chinese_chars:
+             text = self._tokenize_chinese_chars(text)
+         orig_tokens = whitespace_tokenize(text)
+         split_tokens = []
+         for token in orig_tokens:
+             if token not in never_split:
+                 if self.do_lower_case:
+                     token = token.lower()
+                     if self.strip_accents is not False:
+                         token = self._run_strip_accents(token)
+                 elif self.strip_accents:
+                     token = self._run_strip_accents(token)
+             split_tokens.extend(self._run_split_on_punc(token, never_split))
+
+         output_tokens = whitespace_tokenize(" ".join(split_tokens))
+         return output_tokens
+
+     def _run_strip_accents(self, text):
+         """Strips accents from a piece of text."""
+         text = unicodedata.normalize("NFD", text)
+         output = []
+         for char in text:
+             cat = unicodedata.category(char)
+             if cat == "Mn":
+                 continue
+             output.append(char)
+         return "".join(output)
+
+     def _run_split_on_punc(self, text, never_split=None):
+         """Splits punctuation on a piece of text."""
+         if never_split is not None and text in never_split:
+             return [text]
+         chars = list(text)
+         i = 0
+         start_new_word = True
+         output = []
+         while i < len(chars):
+             char = chars[i]
+             if _is_punctuation(char):
+                 output.append([char])
+                 start_new_word = True
+             else:
+                 if start_new_word:
+                     output.append([])
+                 start_new_word = False
+                 output[-1].append(char)
+             i += 1
+
+         return ["".join(x) for x in output]
+
+     def _tokenize_chinese_chars(self, text):
+         """Adds whitespace around any CJK character."""
+         output = []
+         for char in text:
+             cp = ord(char)
+             if self._is_chinese_char(cp):
+                 output.append(" ")
+                 output.append(char)
+                 output.append(" ")
+             else:
+                 output.append(char)
+         return "".join(output)
+
+     def _is_chinese_char(self, cp):
+         """Checks whether CP is the codepoint of a CJK character."""
+         # This defines a "chinese character" as anything in the CJK Unicode block:
+         #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+         #
+         # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+         # despite its name. The modern Korean Hangul alphabet is a different block,
+         # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+         # space-separated words, so they are not treated specially and handled
+         # like all of the other languages.
+         if ((cp >= 0x4E00 and cp <= 0x9FFF)
+                 or (cp >= 0x3400 and cp <= 0x4DBF)
+                 or (cp >= 0x20000 and cp <= 0x2A6DF)
+                 or (cp >= 0x2A700 and cp <= 0x2B73F)
+                 or (cp >= 0x2B740 and cp <= 0x2B81F)
+                 or (cp >= 0x2B820 and cp <= 0x2CEAF)
+                 or (cp >= 0xF900 and cp <= 0xFAFF)
+                 or (cp >= 0x2F800 and cp <= 0x2FA1F)):
+             return True
+
+         return False
+
+     def _clean_text(self, text):
+         """Performs invalid character removal and whitespace cleanup on text."""
+         output = []
+         for char in text:
+             cp = ord(char)
+             if cp == 0 or cp == 0xFFFD or _is_control(char):
+                 continue
+             if _is_whitespace(char):
+                 output.append(" ")
+             else:
+                 output.append(char)
+         return "".join(output)
+
+
+ class WordpieceTokenizer(object):
+     """Runs WordPiece tokenization."""
+
+     def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
+         self.vocab = vocab
+         self.unk_token = unk_token
+         self.max_input_chars_per_word = max_input_chars_per_word
+
+     def tokenize(self, text):
+         """
+         Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
+         tokenization using the given vocabulary.
+         For example, `input = "unaffable"` will return as output `["un", "##aff", "##able"]`.
+         Args:
+             text: A single token or whitespace separated tokens. This should have
+                 already been passed through *BasicTokenizer*.
+         Returns:
+             A list of wordpiece tokens.
+         """
+
+         output_tokens = []
+         for token in whitespace_tokenize(text):
+             chars = list(token)
+             if len(chars) > self.max_input_chars_per_word:
+                 output_tokens.append(self.unk_token)
+                 continue
+
+             is_bad = False
+             start = 0
+             sub_tokens = []
+             while start < len(chars):
+                 end = len(chars)
+                 cur_substr = None
+                 while start < end:
+                     substr = "".join(chars[start:end])
+                     if start > 0:
+                         substr = "##" + substr
+                     if substr in self.vocab:
+                         cur_substr = substr
+                         break
+                     end -= 1
+                 if cur_substr is None:
+                     is_bad = True
+                     break
+                 sub_tokens.append(cur_substr)
+                 start = end
+
+             if is_bad:
+                 output_tokens.append(self.unk_token)
+             else:
+                 output_tokens.extend(sub_tokens)
+         return output_tokens
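
A minimal loading sketch for the tokenizer defined above, assuming the vocab.txt from this commit is in the working directory; the sample sentence is only an illustration:

    from tokenizers_pegasus import PegasusTokenizer

    tokenizer = PegasusTokenizer(vocab_file="vocab.txt")
    ids = tokenizer.encode("今天天气不错")  # jieba pre-segmentation + WordPiece; ends with </s> (id 1 after the vocab remapping)
    text = tokenizer.decode(ids, skip_special_tokens=True)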
vocab.txt CHANGED
@@ -1,7 +1,7 @@
- <pad>
- </s>
- <mask_1>
- <mask_2>
+ [PAD]
+ [unused1]
+ [unused2]
+ [unused3]
  [unused4]
  [unused5]
  [unused6]
@@ -98,7 +98,7 @@
  [unused97]
  [unused98]
  [unused99]
- <unk>
+ [UNK]
  [CLS]
  [SEP]
  [MASK]