mhenrichsen Mads Henrichsen winglian committed on
Commit
98b4762
·
unverified ·
1 Parent(s): ee0b5f6

Feat/chatml add system message (#1117)

Browse files

* add system message to template

* readme update

* added code to register new system message

* register chatml template for test

---------

Co-authored-by: Mads Henrichsen <mads@BrbartiendeMads.lan>
Co-authored-by: Wing Lian <wing.lian@gmail.com>

README.md CHANGED
@@ -613,6 +613,8 @@ rl:
613
  # Saves the desired chat template to the tokenizer_config.json for easier inferencing
614
  # Currently supports chatml and inst (mistral/mixtral)
615
  chat_template: chatml
 
 
616
  # Axolotl attempts to save the dataset as an arrow after packing the data together so
617
  # subsequent training attempts load faster, relative path
618
  dataset_prepared_path: data/last_run_prepared
 
613
  # Saves the desired chat template to the tokenizer_config.json for easier inferencing
614
  # Currently supports chatml and inst (mistral/mixtral)
615
  chat_template: chatml
616
+ # Changes the default system message
617
+ default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
618
  # Axolotl attempts to save the dataset as an arrow after packing the data together so
619
  # subsequent training attempts load faster, relative path
620
  dataset_prepared_path: data/last_run_prepared
src/axolotl/cli/preprocess.py CHANGED
@@ -18,6 +18,7 @@ from axolotl.cli import (
18
  )
19
  from axolotl.common.cli import PreprocessCliArgs
20
  from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 
21
 
22
  LOG = logging.getLogger("axolotl.cli.preprocess")
23
 
@@ -34,6 +35,12 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
34
  return_remaining_strings=True
35
  )
36
 
 
 
 
 
 
 
37
  if not parsed_cfg.dataset_prepared_path:
38
  msg = (
39
  Fore.RED
 
18
  )
19
  from axolotl.common.cli import PreprocessCliArgs
20
  from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
21
+ from axolotl.prompt_strategies.sharegpt import register_chatml_template
22
 
23
  LOG = logging.getLogger("axolotl.cli.preprocess")
24
 
 
35
  return_remaining_strings=True
36
  )
37
 
38
+ if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
39
+ LOG.info(
40
+ f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
41
+ )
42
+ register_chatml_template(parsed_cfg.default_system_message)
43
+
44
  if not parsed_cfg.dataset_prepared_path:
45
  msg = (
46
  Fore.RED
src/axolotl/cli/train.py CHANGED
@@ -18,6 +18,7 @@ from axolotl.cli import (
18
  print_axolotl_text_art,
19
  )
20
  from axolotl.common.cli import TrainerCliArgs
 
21
  from axolotl.train import train
22
 
23
  LOG = logging.getLogger("axolotl.cli.train")
@@ -37,7 +38,12 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
37
  print_axolotl_text_art()
38
  check_accelerate_default_config()
39
  check_user_token()
40
- if cfg.rl:
 
 
 
 
 
41
  dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
42
  else:
43
  dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
18
  print_axolotl_text_art,
19
  )
20
  from axolotl.common.cli import TrainerCliArgs
21
+ from axolotl.prompt_strategies.sharegpt import register_chatml_template
22
  from axolotl.train import train
23
 
24
  LOG = logging.getLogger("axolotl.cli.train")
 
38
  print_axolotl_text_art()
39
  check_accelerate_default_config()
40
  check_user_token()
41
+ if cfg.chat_template == "chatml" and cfg.default_system_message:
42
+ LOG.info(
43
+ f"ChatML set. Adding default system message: {cfg.default_system_message}"
44
+ )
45
+ register_chatml_template(cfg.default_system_message)
46
+
47
  dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
48
  else:
49
  dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
src/axolotl/prompt_strategies/sharegpt.py CHANGED
@@ -6,16 +6,19 @@ from fastchat.conversation import Conversation, SeparatorStyle, register_conv_te
6
  from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
7
  from axolotl.prompters import ShareGPTPrompterV2
8
 
9
- register_conv_template(
10
- Conversation(
11
- name="chatml",
12
- system_template="<|im_start|>system\n{system_message}",
13
- system_message="You are a helpful assistant.",
14
- roles=["<|im_start|>user", "<|im_start|>assistant"],
15
- sep_style=SeparatorStyle.CHATML,
16
- sep="<|im_end|>",
 
 
 
 
17
  )
18
- )
19
 
20
 
21
  def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
 
6
  from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
7
  from axolotl.prompters import ShareGPTPrompterV2
8
 
9
+
10
+ def register_chatml_template(system_message=None):
11
+ system_message = system_message or "You are a helpful assistant."
12
+ register_conv_template(
13
+ Conversation(
14
+ name="chatml",
15
+ system_template="<|im_start|>system\n{system_message}",
16
+ system_message=system_message,
17
+ roles=["<|im_start|>user", "<|im_start|>assistant"],
18
+ sep_style=SeparatorStyle.CHATML,
19
+ sep="<|im_end|>",
20
+ )
21
  )
 
22
 
23
 
24
  def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
src/axolotl/utils/chat_templates.py CHANGED
@@ -20,7 +20,7 @@ def chat_templates(user_choice: str):
20
 
21
  templates = {
22
  "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
23
- "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
24
  }
25
 
26
  if user_choice in templates:
 
20
 
21
  templates = {
22
  "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
23
+ "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
24
  }
25
 
26
  if user_choice in templates:
src/axolotl/utils/models.py CHANGED
@@ -219,7 +219,13 @@ def load_tokenizer(cfg):
219
  LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
220
 
221
  if cfg.chat_template:
222
- tokenizer.chat_template = chat_templates(cfg.chat_template)
 
 
 
 
 
 
223
  else:
224
  LOG.info(
225
  "No Chat template selected. Consider adding a chat template for easier inference."
 
219
  LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
220
 
221
  if cfg.chat_template:
222
+ chat_template_string = chat_templates(cfg.chat_template)
223
+ if cfg.default_system_message and cfg.chat_template == "chatml":
224
+ chat_template_string = chat_template_string.replace(
225
+ "You are a helpful assistant.", cfg.default_system_message
226
+ )
227
+
228
+ tokenizer.chat_template = chat_template_string
229
  else:
230
  LOG.info(
231
  "No Chat template selected. Consider adding a chat template for easier inference."
tests/prompt_strategies/test_sharegpt.py CHANGED
@@ -7,9 +7,14 @@ from tokenizers import AddedToken
7
  from transformers import AutoTokenizer
8
 
9
  from axolotl.datasets import TokenizedPromptDataset
10
- from axolotl.prompt_strategies.sharegpt import SimpleShareGPTPromptTokenizingStrategy
 
 
 
11
  from axolotl.prompters import ShareGPTPrompterV2
12
 
 
 
13
 
14
  @pytest.fixture(name="sharegpt_dataset")
15
  def fixture_sharegpt_dataset():
 
7
  from transformers import AutoTokenizer
8
 
9
  from axolotl.datasets import TokenizedPromptDataset
10
+ from axolotl.prompt_strategies.sharegpt import (
11
+ SimpleShareGPTPromptTokenizingStrategy,
12
+ register_chatml_template,
13
+ )
14
  from axolotl.prompters import ShareGPTPrompterV2
15
 
16
+ register_chatml_template()
17
+
18
 
19
  @pytest.fixture(name="sharegpt_dataset")
20
  def fixture_sharegpt_dataset():