Feat/chatml add system message (#1117)
Browse files* add system message to template
* readme update
* added code to register new system message
* register chatml template for test
---------
Co-authored-by: Mads Henrichsen <mads@BrbartiendeMads.lan>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
README.md
CHANGED
@@ -613,6 +613,8 @@ rl:
|
|
613 |
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
614 |
# Currently supports chatml and inst (mistral/mixtral)
|
615 |
chat_template: chatml
|
|
|
|
|
616 |
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
617 |
# subsequent training attempts load faster, relative path
|
618 |
dataset_prepared_path: data/last_run_prepared
|
|
|
613 |
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
614 |
# Currently supports chatml and inst (mistral/mixtral)
|
615 |
chat_template: chatml
|
616 |
+
# Changes the default system message
|
617 |
+
default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
|
618 |
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
619 |
# subsequent training attempts load faster, relative path
|
620 |
dataset_prepared_path: data/last_run_prepared
|
src/axolotl/cli/preprocess.py
CHANGED
@@ -18,6 +18,7 @@ from axolotl.cli import (
|
|
18 |
)
|
19 |
from axolotl.common.cli import PreprocessCliArgs
|
20 |
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
|
|
21 |
|
22 |
LOG = logging.getLogger("axolotl.cli.preprocess")
|
23 |
|
@@ -34,6 +35,12 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|
34 |
return_remaining_strings=True
|
35 |
)
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
if not parsed_cfg.dataset_prepared_path:
|
38 |
msg = (
|
39 |
Fore.RED
|
|
|
18 |
)
|
19 |
from axolotl.common.cli import PreprocessCliArgs
|
20 |
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
21 |
+
from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
22 |
|
23 |
LOG = logging.getLogger("axolotl.cli.preprocess")
|
24 |
|
|
|
35 |
return_remaining_strings=True
|
36 |
)
|
37 |
|
38 |
+
if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
|
39 |
+
LOG.info(
|
40 |
+
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
41 |
+
)
|
42 |
+
register_chatml_template(parsed_cfg.default_system_message)
|
43 |
+
|
44 |
if not parsed_cfg.dataset_prepared_path:
|
45 |
msg = (
|
46 |
Fore.RED
|
src/axolotl/cli/train.py
CHANGED
@@ -18,6 +18,7 @@ from axolotl.cli import (
|
|
18 |
print_axolotl_text_art,
|
19 |
)
|
20 |
from axolotl.common.cli import TrainerCliArgs
|
|
|
21 |
from axolotl.train import train
|
22 |
|
23 |
LOG = logging.getLogger("axolotl.cli.train")
|
@@ -37,7 +38,12 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|
37 |
print_axolotl_text_art()
|
38 |
check_accelerate_default_config()
|
39 |
check_user_token()
|
40 |
-
if cfg.
|
|
|
|
|
|
|
|
|
|
|
41 |
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
42 |
else:
|
43 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
|
18 |
print_axolotl_text_art,
|
19 |
)
|
20 |
from axolotl.common.cli import TrainerCliArgs
|
21 |
+
from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
22 |
from axolotl.train import train
|
23 |
|
24 |
LOG = logging.getLogger("axolotl.cli.train")
|
|
|
38 |
print_axolotl_text_art()
|
39 |
check_accelerate_default_config()
|
40 |
check_user_token()
|
41 |
+
if cfg.chat_template == "chatml" and cfg.default_system_message:
|
42 |
+
LOG.info(
|
43 |
+
f"ChatML set. Adding default system message: {cfg.default_system_message}"
|
44 |
+
)
|
45 |
+
register_chatml_template(cfg.default_system_message)
|
46 |
+
|
47 |
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
48 |
else:
|
49 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
src/axolotl/prompt_strategies/sharegpt.py
CHANGED
@@ -6,16 +6,19 @@ from fastchat.conversation import Conversation, SeparatorStyle, register_conv_te
|
|
6 |
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
7 |
from axolotl.prompters import ShareGPTPrompterV2
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
17 |
)
|
18 |
-
)
|
19 |
|
20 |
|
21 |
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|
|
6 |
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
7 |
from axolotl.prompters import ShareGPTPrompterV2
|
8 |
|
9 |
+
|
10 |
+
def register_chatml_template(system_message=None):
|
11 |
+
system_message = system_message or "You are a helpful assistant."
|
12 |
+
register_conv_template(
|
13 |
+
Conversation(
|
14 |
+
name="chatml",
|
15 |
+
system_template="<|im_start|>system\n{system_message}",
|
16 |
+
system_message=system_message,
|
17 |
+
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
18 |
+
sep_style=SeparatorStyle.CHATML,
|
19 |
+
sep="<|im_end|>",
|
20 |
+
)
|
21 |
)
|
|
|
22 |
|
23 |
|
24 |
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
src/axolotl/utils/chat_templates.py
CHANGED
@@ -20,7 +20,7 @@ def chat_templates(user_choice: str):
|
|
20 |
|
21 |
templates = {
|
22 |
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
23 |
-
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in
|
24 |
}
|
25 |
|
26 |
if user_choice in templates:
|
|
|
20 |
|
21 |
templates = {
|
22 |
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
23 |
+
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
24 |
}
|
25 |
|
26 |
if user_choice in templates:
|
src/axolotl/utils/models.py
CHANGED
@@ -219,7 +219,13 @@ def load_tokenizer(cfg):
|
|
219 |
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
220 |
|
221 |
if cfg.chat_template:
|
222 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
else:
|
224 |
LOG.info(
|
225 |
"No Chat template selected. Consider adding a chat template for easier inference."
|
|
|
219 |
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
220 |
|
221 |
if cfg.chat_template:
|
222 |
+
chat_template_string = chat_templates(cfg.chat_template)
|
223 |
+
if cfg.default_system_message and cfg.chat_template == "chatml":
|
224 |
+
chat_template_string = chat_template_string.replace(
|
225 |
+
"You are a helpful assistant.", cfg.default_system_message
|
226 |
+
)
|
227 |
+
|
228 |
+
tokenizer.chat_template = chat_template_string
|
229 |
else:
|
230 |
LOG.info(
|
231 |
"No Chat template selected. Consider adding a chat template for easier inference."
|
tests/prompt_strategies/test_sharegpt.py
CHANGED
@@ -7,9 +7,14 @@ from tokenizers import AddedToken
|
|
7 |
from transformers import AutoTokenizer
|
8 |
|
9 |
from axolotl.datasets import TokenizedPromptDataset
|
10 |
-
from axolotl.prompt_strategies.sharegpt import
|
|
|
|
|
|
|
11 |
from axolotl.prompters import ShareGPTPrompterV2
|
12 |
|
|
|
|
|
13 |
|
14 |
@pytest.fixture(name="sharegpt_dataset")
|
15 |
def fixture_sharegpt_dataset():
|
|
|
7 |
from transformers import AutoTokenizer
|
8 |
|
9 |
from axolotl.datasets import TokenizedPromptDataset
|
10 |
+
from axolotl.prompt_strategies.sharegpt import (
|
11 |
+
SimpleShareGPTPromptTokenizingStrategy,
|
12 |
+
register_chatml_template,
|
13 |
+
)
|
14 |
from axolotl.prompters import ShareGPTPrompterV2
|
15 |
|
16 |
+
register_chatml_template()
|
17 |
+
|
18 |
|
19 |
@pytest.fixture(name="sharegpt_dataset")
|
20 |
def fixture_sharegpt_dataset():
|