winglian commited on
Commit
0ce1a65
1 Parent(s): 043c386

update sharegpt conversations when chatml chat template is set (#1075) [skip ci]

Browse files

* update sharegpt conversations when chatml chat template is set

* add info log when updating sharegpt/chatml conversation

src/axolotl/cli/__init__.py CHANGED
@@ -25,7 +25,11 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
25
  from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
26
  from axolotl.logging_config import configure_logging
27
  from axolotl.train import TrainDatasetMeta
28
- from axolotl.utils.config import normalize_config, validate_config
 
 
 
 
29
  from axolotl.utils.data import prepare_dataset
30
  from axolotl.utils.dict import DictDefault
31
  from axolotl.utils.distributed import is_main_process
@@ -289,6 +293,8 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
289
 
290
  normalize_config(cfg)
291
 
 
 
292
  setup_wandb_env_vars(cfg)
293
 
294
  setup_mlflow_env_vars(cfg)
 
25
  from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
26
  from axolotl.logging_config import configure_logging
27
  from axolotl.train import TrainDatasetMeta
28
+ from axolotl.utils.config import (
29
+ normalize_cfg_datasets,
30
+ normalize_config,
31
+ validate_config,
32
+ )
33
  from axolotl.utils.data import prepare_dataset
34
  from axolotl.utils.dict import DictDefault
35
  from axolotl.utils.distributed import is_main_process
 
293
 
294
  normalize_config(cfg)
295
 
296
+ normalize_cfg_datasets(cfg)
297
+
298
  setup_wandb_env_vars(cfg)
299
 
300
  setup_mlflow_env_vars(cfg)
src/axolotl/utils/config.py CHANGED
@@ -150,6 +150,21 @@ def normalize_config(cfg):
150
  log_gpu_memory_usage(LOG, "baseline", cfg.device)
151
 
152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  def validate_config(cfg):
154
  """
155
  This is a "pre-validation" step that handles the yaml configuration before we have any
 
150
  log_gpu_memory_usage(LOG, "baseline", cfg.device)
151
 
152
 
153
+ def normalize_cfg_datasets(cfg):
154
+ """
155
+ helpers for mapping chat_template to various dataset configurations as necessary
156
+ """
157
+
158
+ if cfg.chat_template and cfg.chat_template == "chatml":
159
+ if cfg.datasets:
160
+ for idx, ds_cfg in enumerate(cfg.datasets):
161
+ if ds_cfg.type == "sharegpt" and not ds_cfg.conversation:
162
+ LOG.info(
163
+ f"updating dataset {ds_cfg.path} with `conversation: chatml` to match your chat_template"
164
+ )
165
+ cfg.datasets[idx].conversation = "chatml"
166
+
167
+
168
  def validate_config(cfg):
169
  """
170
  This is a "pre-validation" step that handles the yaml configuration before we have any
tests/test_normalize_config.py CHANGED
@@ -3,7 +3,7 @@ Test classes for checking functionality of the cfg normalization
3
  """
4
  import unittest
5
 
6
- from axolotl.utils.config import normalize_config
7
  from axolotl.utils.dict import DictDefault
8
 
9
 
@@ -44,3 +44,26 @@ class NormalizeConfigTestCase(unittest.TestCase):
44
  normalize_config(cfg)
45
 
46
  assert cfg.base_model_config == cfg.base_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  """
4
  import unittest
5
 
6
+ from axolotl.utils.config import normalize_cfg_datasets, normalize_config
7
  from axolotl.utils.dict import DictDefault
8
 
9
 
 
44
  normalize_config(cfg)
45
 
46
  assert cfg.base_model_config == cfg.base_model
47
+
48
+ def test_chat_template_chatml(self):
49
+ cfg = DictDefault(
50
+ {
51
+ "chat_template": "chatml",
52
+ "datasets": [
53
+ {
54
+ "path": "lorem/ipsum",
55
+ "type": "sharegpt",
56
+ "conversation": "vicuna_v1.1",
57
+ },
58
+ {
59
+ "path": "sit/amet",
60
+ "type": "sharegpt",
61
+ },
62
+ ],
63
+ }
64
+ )
65
+
66
+ normalize_cfg_datasets(cfg)
67
+
68
+ assert cfg.datasets[0].conversation == "vicuna_v1.1"
69
+ assert cfg.datasets[1].conversation == "chatml"