Nanobit commited on
Commit
ff939d8
·
unverified ·
1 Parent(s): 324d59e

fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path (#1298)

Browse files

* fix(dataset): normalize tokenizer config and change hash from tokenizer class to tokenizer path

* fix: normalize config

src/axolotl/utils/config/__init__.py CHANGED
@@ -119,6 +119,10 @@ def normalize_config(cfg):
119
  model_config = load_model_config(cfg)
120
  cfg.model_config_type = model_config.model_type
121
 
 
 
 
 
122
  # figure out if the model is llama
123
  cfg.is_llama_derived_model = (
124
  (hasattr(model_config, "model_type") and model_config.model_type == "llama")
 
119
  model_config = load_model_config(cfg)
120
  cfg.model_config_type = model_config.model_type
121
 
122
+ cfg.tokenizer_config = (
123
+ cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
124
+ )
125
+
126
  # figure out if the model is llama
127
  cfg.is_llama_derived_model = (
128
  (hasattr(model_config, "model_type") and model_config.model_type == "llama")
src/axolotl/utils/data.py CHANGED
@@ -134,7 +134,7 @@ def load_tokenized_prepared_datasets(
134
  split="train",
135
  ) -> Tuple[DatasetDict, List[Prompter]]:
136
  cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
137
- tokenizer_name = tokenizer.__class__.__name__
138
  ds_hash = str(
139
  md5(
140
  (
 
134
  split="train",
135
  ) -> Tuple[DatasetDict, List[Prompter]]:
136
  cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
137
+ tokenizer_name = cfg.tokenizer_config
138
  ds_hash = str(
139
  md5(
140
  (
src/axolotl/utils/models.py CHANGED
@@ -134,9 +134,8 @@ def load_tokenizer(cfg):
134
  if cfg.tokenizer_type:
135
  tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
136
 
137
- tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
138
  tokenizer = tokenizer_cls.from_pretrained(
139
- tokenizer_config,
140
  trust_remote_code=cfg.trust_remote_code or False,
141
  use_fast=use_fast,
142
  **tokenizer_kwargs,
 
134
  if cfg.tokenizer_type:
135
  tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
136
 
 
137
  tokenizer = tokenizer_cls.from_pretrained(
138
+ cfg.tokenizer_config,
139
  trust_remote_code=cfg.trust_remote_code or False,
140
  use_fast=use_fast,
141
  **tokenizer_kwargs,
tests/core/test_trainer_builder.py CHANGED
@@ -1,16 +1,18 @@
1
  """
2
  unit tests for axolotl.core.trainer_builder
3
  """
 
4
  import pytest
5
 
6
  from axolotl.core.trainer_builder import HFDPOTrainerBuilder
 
7
  from axolotl.utils.dict import DictDefault
8
  from axolotl.utils.models import load_model, load_tokenizer
9
 
10
 
11
  @pytest.fixture(name="cfg")
12
  def fixture_cfg():
13
- return DictDefault(
14
  {
15
  "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
16
  "model_type": "AutoModelForCausalLM",
@@ -34,6 +36,10 @@ def fixture_cfg():
34
  }
35
  )
36
 
 
 
 
 
37
 
38
  @pytest.fixture(name="tokenizer")
39
  def fixture_tokenizer(cfg):
 
1
  """
2
  unit tests for axolotl.core.trainer_builder
3
  """
4
+
5
  import pytest
6
 
7
  from axolotl.core.trainer_builder import HFDPOTrainerBuilder
8
+ from axolotl.utils.config import normalize_config
9
  from axolotl.utils.dict import DictDefault
10
  from axolotl.utils.models import load_model, load_tokenizer
11
 
12
 
13
  @pytest.fixture(name="cfg")
14
  def fixture_cfg():
15
+ cfg = DictDefault(
16
  {
17
  "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
18
  "model_type": "AutoModelForCausalLM",
 
36
  }
37
  )
38
 
39
+ normalize_config(cfg)
40
+
41
+ return cfg
42
+
43
 
44
  @pytest.fixture(name="tokenizer")
45
  def fixture_tokenizer(cfg):