winglian commited on
Commit
47d601f
1 Parent(s): 756dfba

optionally define whether to use_fast tokenizer

Browse files
README.md CHANGED
@@ -302,6 +302,8 @@ model_type: AutoModelForCausalLM
302
  tokenizer_type: AutoTokenizer
303
  # Trust remote code for untrusted source
304
  trust_remote_code:
 
 
305
 
306
  # whether you are training a 4-bit GPTQ quantized model
307
  gptq: true
 
302
  tokenizer_type: AutoTokenizer
303
  # Trust remote code for untrusted source
304
  trust_remote_code:
305
+ # use_fast option for tokenizer loading from_pretrained, default to True
306
+ tokenizer_use_fast:
307
 
308
  # whether you are training a 4-bit GPTQ quantized model
309
  gptq: true
src/axolotl/utils/models.py CHANGED
@@ -34,15 +34,20 @@ def load_tokenizer(
34
  tokenizer_type,
35
  cfg,
36
  ):
 
 
 
37
  if tokenizer_type:
38
  tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
39
  tokenizer_config,
40
  trust_remote_code=cfg.trust_remote_code or False,
 
41
  )
42
  else:
43
  tokenizer = AutoTokenizer.from_pretrained(
44
  tokenizer_config,
45
  trust_remote_code=cfg.trust_remote_code or False,
 
46
  )
47
 
48
  logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
 
34
  tokenizer_type,
35
  cfg,
36
  ):
37
+ use_fast = True # this is the default
38
+ if cfg.tokenizer_use_fast is not None:
39
+ use_fast = cfg.tokenizer_use_fast
40
  if tokenizer_type:
41
  tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
42
  tokenizer_config,
43
  trust_remote_code=cfg.trust_remote_code or False,
44
+ use_fast=use_fast,
45
  )
46
  else:
47
  tokenizer = AutoTokenizer.from_pretrained(
48
  tokenizer_config,
49
  trust_remote_code=cfg.trust_remote_code or False,
50
+ use_fast=use_fast,
51
  )
52
 
53
  logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
tests/test_tokenizers.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test cases for the tokenizer loading
3
+ """
4
+ import unittest
5
+
6
+ from axolotl.utils.dict import DictDefault
7
+ from axolotl.utils.models import load_tokenizer
8
+
9
+
10
+ class TestTokenizers(unittest.TestCase):
11
+ """
12
+ test class for the load_tokenizer fn
13
+ """
14
+
15
+ def test_default_use_fast(self):
16
+ cfg = DictDefault({})
17
+ tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
18
+ assert "Fast" in tokenizer.__class__.__name__
19
+
20
+ def test_dont_use_fast(self):
21
+ cfg = DictDefault(
22
+ {
23
+ "tokenizer_use_fast": False,
24
+ }
25
+ )
26
+ tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
27
+ assert "Fast" not in tokenizer.__class__.__name__
28
+
29
+
30
+ if __name__ == "__main__":
31
+ unittest.main()