DreamGenX winglian commited on
Commit
25e037f
·
unverified ·
1 Parent(s): 52c83d3

Support for additional_special_tokens (#1221) [skip ci]

Browse files

* Support for additional_special_tokens

* Support for additional_special_tokens. Adjust whitespace.

* Support for additional_special_tokens. Use correct quotes.

* Support for additional_special_tokens. Safe pop.

* Support for additional_special_tokens. nt.

* Support for additional_special_tokens. cfg.special_tokens may be None.

* add token if not in vocabulary when adding additional_special_tokens

* fix logic for copy/pasta

* bugfix for popping from config and tokenizer reload

* no need to add tokens manually now with previous bugfix

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

src/axolotl/utils/models.py CHANGED
@@ -161,15 +161,20 @@ def load_tokenizer(cfg):
161
  if getattr(tokenizer, attr_name) is None:
162
  setattr(tokenizer, attr_name, "<|endoftext|>")
163
 
 
164
  if cfg.special_tokens:
 
 
 
 
165
  lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
166
- for k, val in cfg.special_tokens.items():
167
  # check if new special token is not already in tokenizer and
168
  # is adapter training to make sure lora_modules_to_save is set
169
  # pylint: disable=too-many-boolean-expressions
170
  if (
171
  (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
172
- and (len(tokenizer.encode(val)) > 1)
173
  and cfg.adapter
174
  and (
175
  not cfg.lora_modules_to_save
@@ -213,6 +218,21 @@ def load_tokenizer(cfg):
213
  ]
214
  )
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
217
  LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
218
  LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
 
161
  if getattr(tokenizer, attr_name) is None:
162
  setattr(tokenizer, attr_name, "<|endoftext|>")
163
 
164
+ additional_special_tokens = None
165
  if cfg.special_tokens:
166
+ special_tokens = cfg.special_tokens.to_dict()
167
+ additional_special_tokens = special_tokens.pop(
168
+ "additional_special_tokens", None
169
+ )
170
  lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
171
+ for k, val in special_tokens.items():
172
  # check if new special token is not already in tokenizer and
173
  # is adapter training to make sure lora_modules_to_save is set
174
  # pylint: disable=too-many-boolean-expressions
175
  if (
176
  (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
177
+ and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
178
  and cfg.adapter
179
  and (
180
  not cfg.lora_modules_to_save
 
218
  ]
219
  )
220
 
221
+ # Additional special tokens are a List, and need to be treated differently than regular special
222
+ # tokens. We add them after we have called `add_tokens` in case these additional special tokens
223
+ # are new tokens.
224
+ #
225
+ # Usage:
226
+ #
227
+ # ```py
228
+ # special_tokens:
229
+ # additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
230
+ # ```
231
+ if additional_special_tokens is not None:
232
+ tokenizer.add_special_tokens(
233
+ {"additional_special_tokens": additional_special_tokens}
234
+ )
235
+
236
  LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
237
  LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
238
  LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
tests/test_tokenizers.py CHANGED
@@ -67,6 +67,21 @@ class TestTokenizers(unittest.TestCase):
67
  )
68
  load_tokenizer(cfg)
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  if __name__ == "__main__":
72
  unittest.main()
 
67
  )
68
  load_tokenizer(cfg)
69
 
70
+ def test_add_additional_special_tokens(self):
71
+ cfg = DictDefault(
72
+ {
73
+ "tokenizer_config": "huggyllama/llama-7b",
74
+ "special_tokens": {"additional_special_tokens": ["<|im_start|>"]},
75
+ }
76
+ )
77
+ tokenizer = load_tokenizer(cfg)
78
+ self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
79
+ self.assertEqual(len(tokenizer), 32001)
80
+
81
+ # ensure reloading the tokenizer again from cfg results in same vocab length
82
+ tokenizer = load_tokenizer(cfg)
83
+ self.assertEqual(len(tokenizer), 32001)
84
+
85
 
86
  if __name__ == "__main__":
87
  unittest.main()