Upload csc_tokenizer.py

csc_tokenizer.py  +19 -46
CHANGED
@@ -108,7 +108,7 @@ class ChineseBertTokenizer(BertTokenizerFast):
             return_token_type_ids=return_token_type_ids,
             return_attention_mask=return_attention_mask,
             return_overflowing_tokens=return_overflowing_tokens,
-            return_offsets_mapping=return_offsets_mapping,
+            return_offsets_mapping=True,
             return_length=return_length,
             verbose=verbose,
         )
@@ -117,61 +117,34 @@ class ChineseBertTokenizer(BertTokenizerFast):
 
         pinyin_ids = None
         if type(text) == str:
-
+            offsets = encoding.offset_mapping[0].tolist()
+            tokens = self.sentence_to_tokens(text, offsets)
+            pinyin_ids = [self.convert_sentence_to_pinyin_ids(text, tokens, offsets)]
 
-        if type(text) == list:
+        if type(text) == list or type(text) == tuple:
             pinyin_ids = []
-            for
-
+            for i, sentence in enumerate(text):
+                offsets = encoding.offset_mapping[i].tolist()
+                tokens = self.sentence_to_tokens(sentence, offsets)
+                pinyin_ids.append(self.convert_sentence_to_pinyin_ids(sentence, tokens, offsets))
 
         if torch.is_tensor(encoding.input_ids):
             pinyin_ids = torch.LongTensor(pinyin_ids)
 
         encoding['pinyin_ids'] = pinyin_ids
 
-
-
-    def tokenize_sentence(self, sentence):
-        # convert sentence to ids
-        tokenizer_output = self.tokenizer.encode(sentence)
-        bert_tokens = tokenizer_output.ids
-        pinyin_tokens = self.convert_sentence_to_pinyin_ids(sentence, tokenizer_output)
-        # assert,token nums should be same as pinyin token nums
-        assert len(bert_tokens) <= self.max_length
-        assert len(bert_tokens) == len(pinyin_tokens)
-        # convert list to tensor
-        input_ids = torch.LongTensor(bert_tokens)
-        pinyin_ids = torch.LongTensor(pinyin_tokens).view(-1)
-        return input_ids, pinyin_ids
-
-    def convert_ids_to_pinyin_ids(self, ids: List[int]):
-        pinyin_ids = []
-        tokens = self.convert_ids_to_tokens(ids)
-        for token in tokens:
-            if len(token) > 1:
-                pinyin_ids.append([0] * 8)
-                continue
-
-            pinyin_string = pinyin(token, style=Style.TONE3, errors=lambda x: [['not chinese'] for _ in x])[0][0]
-
-            if pinyin_string == "not chinese":
-                pinyin_ids.append([0] * 8)
-                continue
+        if not return_offsets_mapping:
+            del encoding['offset_mapping']
 
-            if pinyin_string in self.pinyin2tensor:
-                pinyin_ids.append(self.pinyin2tensor[pinyin_string])
-            else:
-                ids = [0] * 8
-                for i, p in enumerate(pinyin_string):
-                    if p not in self.pinyin_dict["char2idx"]:
-                        ids = [0] * 8
-                        break
-                    ids[i] = self.pinyin_dict["char2idx"][p]
-                pinyin_ids.append(pinyin_ids)
+        return encoding
 
-
+    def sentence_to_tokens(self, sentence, offsets):
+        tokens = []
+        for start, end in offsets:
+            tokens.append(sentence[start:end])
+        return tokens
 
-    def convert_sentence_to_pinyin_ids(self, sentence: str,
+    def convert_sentence_to_pinyin_ids(self, sentence: str, tokens, offsets):
         # get pinyin of a sentence
         pinyin_list = pinyin(sentence, style=Style.TONE3, heteronym=True, errors=lambda x: [['not chinese'] for _ in x])
         pinyin_locs = {}
@@ -194,7 +167,7 @@ class ChineseBertTokenizer(BertTokenizerFast):
 
         # find chinese character location, and generate pinyin ids
         pinyin_ids = []
-        for idx, (token, offset) in enumerate(zip(
+        for idx, (token, offset) in enumerate(zip(tokens, offsets)):
             if offset[1] - offset[0] != 1:
                 pinyin_ids.append([0] * 8)
                 continue
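
A minimal usage sketch of the tokenizer after this change follows. It assumes csc_tokenizer.py is loaded as custom tokenizer code via trust_remote_code=True; the repository id is a placeholder, and the printed shapes are only what the diff implies (an 8-slot pinyin id vector per token), not output taken from this repo.

# Hypothetical usage of ChineseBertTokenizer after this change.
# "some-user/some-chinesebert-csc-model" is a placeholder repo id, not a real checkpoint.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "some-user/some-chinesebert-csc-model",  # placeholder
    trust_remote_code=True,  # needed so the custom csc_tokenizer.py is picked up
)

encoding = tokenizer(["我喜欢吃苹果", "今天天气很好"], padding=True, return_tensors="pt")

# __call__ now always requests offset mappings internally, rebuilds each token's
# surface form with sentence_to_tokens, and attaches per-token pinyin ids.
print(encoding["input_ids"].shape)   # e.g. (2, seq_len)
print(encoding["pinyin_ids"].shape)  # e.g. (2, seq_len, 8): 8 pinyin-character slots per token
print("offset_mapping" in encoding)  # False: deleted again because return_offsets_mapping was not requested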