pranjalchitale committed on
Commit 5088840 · 1 Parent(s): 3958b6a

Update tokenization_indictrans.py

Files changed (1)
  tokenization_indictrans.py  +42 −20
tokenization_indictrans.py CHANGED
@@ -11,7 +11,10 @@ from transformers.tokenization_utils import PreTrainedTokenizer
 logger = logging.get_logger(__name__)
 
 SPIECE_UNDERLINE = "▁"
-SUPPORTED_LANGUAGES = [
+
+SPECIAL_TAGS = {
+    "_bt_",
+    "_ft_",
     "asm_Beng",
     "awa_Deva",
     "ben_Beng",
@@ -46,7 +49,7 @@ SUPPORTED_LANGUAGES = [
     "tel_Telu",
     "urd_Arab",
     "unr_Deva",
-]
+}
 
 VOCAB_FILES_NAMES = {
     "src_vocab_fp": "dict.SRC.json",
@@ -74,7 +77,7 @@ class IndicTransTokenizer(PreTrainedTokenizer):
         eos_token="</s>",
         pad_token="<pad>",
         do_lower_case=False,
-        **kwargs
+        **kwargs,
     ):
 
         self.src = True
@@ -124,7 +127,10 @@ class IndicTransTokenizer(PreTrainedTokenizer):
             pad_token=pad_token,
             **kwargs,
         )
-
+
+    def add_new_special_tags(self, new_tags: List[str]):
+        SPECIAL_TAGS.update(new_tags)
+
     def _switch_to_input_mode(self):
         self.src = True
         self.padding_side = "left"
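The new add_new_special_tags method mutates the module-level SPECIAL_TAGS set in place, so a tag registered through it is recognized by every tokenizer instance in the process. A minimal standalone sketch of that mechanic (toy subset of the set, not the tokenizer class itself):

from typing import List

SPECIAL_TAGS = {"_bt_", "_ft_", "asm_Beng", "ben_Beng"}  # toy subset of the real set

def add_new_special_tags(new_tags: List[str]) -> None:
    # mirrors IndicTransTokenizer.add_new_special_tags: update the set in place
    SPECIAL_TAGS.update(new_tags)

add_new_special_tags(["_toy_tag_"])
print("_toy_tag_" in SPECIAL_TAGS)  # True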
@@ -150,6 +156,16 @@ class IndicTransTokenizer(PreTrainedTokenizer):
         with open(path, "r", encoding="utf-8") as f:
             return json.load(f)
 
+    def _split_tags(self, tokens: List[str]) -> Tuple[List[str], List[str]]:
+        tags = [token for token in tokens if token in SPECIAL_TAGS]
+        tokens = [token for token in tokens if token not in SPECIAL_TAGS]
+        return tags, tokens
+
+    def _split_pads(self, tokens: List[str]) -> Tuple[List[str], List[str]]:
+        pads = [token for token in tokens if token == self.pad_token]
+        tokens = [token for token in tokens if token != self.pad_token]
+        return pads, tokens
+
     @property
     def src_vocab_size(self) -> int:
         return len(self.encoder)
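The two new helpers simply partition a token list by membership: _split_tags pulls out anything in SPECIAL_TAGS, _split_pads pulls out pad tokens. A standalone sketch with toy data (function names mirror the diff; the real methods live on the tokenizer and use self.pad_token):

from typing import List, Tuple

SPECIAL_TAGS = {"_bt_", "_ft_", "asm_Beng", "ben_Beng"}  # toy subset
PAD_TOKEN = "<pad>"

def split_tags(tokens: List[str]) -> Tuple[List[str], List[str]]:
    tags = [t for t in tokens if t in SPECIAL_TAGS]
    rest = [t for t in tokens if t not in SPECIAL_TAGS]
    return tags, rest

def split_pads(tokens: List[str]) -> Tuple[List[str], List[str]]:
    pads = [t for t in tokens if t == PAD_TOKEN]
    rest = [t for t in tokens if t != PAD_TOKEN]
    return pads, rest

tokens = ["<pad>", "<pad>", "asm_Beng", "ben_Beng", "▁Hello", "▁world"]
pads, rest = split_pads(tokens)
tags, pieces = split_tags(rest)
print(pads)    # ['<pad>', '<pad>']
print(tags)    # ['asm_Beng', 'ben_Beng']
print(pieces)  # ['▁Hello', '▁world']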
@@ -183,27 +199,31 @@ class IndicTransTokenizer(PreTrainedTokenizer):
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
         """Uses sentencepiece model for detokenization"""
-        pad_tokens = [token for token in tokens if token == self.pad_token]
-        tokens = [token for token in tokens if token != self.pad_token]
+        pads, tokens = self._split_pads(tokens)
+
         if self.src:
+
+            tags, non_tags = self._split_tags(tokens)
+
             return (
-                " ".join(pad_tokens)
+                " ".join(pads)
                 + " "
-                + " ".join(tokens[:2])
+                + " ".join(tags)
                 + " "
-                + "".join(tokens[2:]).replace(SPIECE_UNDERLINE, " ").strip()
+                + "".join(non_tags).replace(SPIECE_UNDERLINE, " ").strip()
             )
+
         return (
             "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
             + " "
-            + " ".join(pad_tokens)
+            + " ".join(pads)
         )
 
     def _tokenize(self, text) -> List[str]:
         if self.src:
             tokens = text.split(" ")
-            tags = tokens[:2]
-            text = " ".join(tokens[2:])
+            tags, non_tags = self._split_tags(tokens)
+            text = " ".join(non_tags)
             tokens = self.current_spm.EncodeAsPieces(text)
             return tags + tokens
         else:
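On the source side, _tokenize previously assumed exactly two leading tags (tokens[:2]); it now filters by SPECIAL_TAGS membership, so an input carrying a different number of tags (for example an extra "_bt_" tag) is still handled. A sketch of the new flow, with a toy stand-in for the sentencepiece call:

from typing import List

SPIECE_UNDERLINE = "▁"
SPECIAL_TAGS = {"_bt_", "_ft_", "asm_Beng", "ben_Beng"}  # toy subset

def fake_encode_as_pieces(text: str) -> List[str]:
    # stand-in for self.current_spm.EncodeAsPieces(text)
    return [SPIECE_UNDERLINE + w for w in text.split()]

def tokenize_src(text: str) -> List[str]:
    tokens = text.split(" ")
    tags = [t for t in tokens if t in SPECIAL_TAGS]
    rest = " ".join(t for t in tokens if t not in SPECIAL_TAGS)
    return tags + fake_encode_as_pieces(rest)

# Three leading tags instead of the two the old tokens[:2] slice assumed:
print(tokenize_src("_bt_ asm_Beng ben_Beng hello world"))
# ['_bt_', 'asm_Beng', 'ben_Beng', '▁hello', '▁world']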
@@ -217,23 +237,25 @@ class IndicTransTokenizer(PreTrainedTokenizer):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
 
-    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+    def save_vocabulary(
+        self, save_directory: str, filename_prefix: Optional[str] = None
+    ) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
             return
-
+
         src_spm_fp = os.path.join(save_directory, "model.SRC")
         tgt_spm_fp = os.path.join(save_directory, "model.TGT")
         src_vocab_fp = os.path.join(save_directory, "dict.SRC.json")
         tgt_vocab_fp = os.path.join(save_directory, "dict.TGT.json")
-
+
         self._save_json(self.encoder, src_vocab_fp)
         self._save_json(self.decoder, tgt_vocab_fp)
-
-        with open(src_spm_fp, 'wb') as f:
+
+        with open(src_spm_fp, "wb") as f:
             f.write(self.src_spm.serialized_model_proto())
-
-        with open(tgt_spm_fp, 'wb') as f:
+
+        with open(tgt_spm_fp, "wb") as f:
             f.write(self.tgt_spm.serialized_model_proto())
 
-        return src_vocab_fp, tgt_vocab_fp, src_spm_fp, tgt_spm_fp
+        return src_vocab_fp, tgt_vocab_fp, src_spm_fp, tgt_spm_fp
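For context, a hypothetical end-to-end use of the updated tokenizer. The checkpoint id and the trust_remote_code loading path are assumptions, not part of this diff; only add_new_special_tags and save_vocabulary come from the code above.

import os
from transformers import AutoTokenizer

# Assumed checkpoint id; any repo shipping this tokenization_indictrans.py should work.
tok = AutoTokenizer.from_pretrained(
    "ai4bharat/indictrans2-en-indic-1B", trust_remote_code=True
)

# Register an extra control tag so _split_tags / _tokenize recognize it.
tok.add_new_special_tags(["_new_tag_"])

# save_vocabulary writes dict.SRC.json, dict.TGT.json, model.SRC and model.TGT
# into an existing directory and returns the four paths.
out_dir = "./indictrans_tokenizer"
os.makedirs(out_dir, exist_ok=True)
print(tok.save_vocabulary(out_dir))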