feat: Add LLaMA-3 instruct prompt strategies for fine-tuning (#1553)
Browse files* Add prompt strategies
* Update modified URL
* Update modified URL
* Update fastchat_conversation_turns.py
* Update register function
* Remove extra /n for system prompt
* Fix return
* Fix BOS
* Update requirements, pylint
* Linting
* Linting
* fix tuples, make sure to set system message in template
* tests for llama3 tokenization
* fix conditionals for loading chat template
---------
Co-authored-by: Ram <ram@Rams-MacBook-Pro.local>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
- src/axolotl/cli/preprocess.py +20 -8
- src/axolotl/cli/train.py +12 -1
- src/axolotl/prompt_strategies/sharegpt.py +18 -2
- src/axolotl/prompters.py +1 -0
- src/axolotl/utils/chat_templates.py +1 -0
- src/axolotl/utils/config/models/input/v0_4_1/__init__.py +1 -0
- tests/prompt_strategies/test_sharegpt.py +49 -1
src/axolotl/cli/preprocess.py
CHANGED
@@ -19,7 +19,10 @@ from axolotl.cli import (
|
|
19 |
)
|
20 |
from axolotl.common.cli import PreprocessCliArgs
|
21 |
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
22 |
-
from axolotl.prompt_strategies.sharegpt import
|
|
|
|
|
|
|
23 |
|
24 |
LOG = logging.getLogger("axolotl.cli.preprocess")
|
25 |
|
@@ -36,13 +39,22 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
|
36 |
return_remaining_strings=True
|
37 |
)
|
38 |
|
39 |
-
if parsed_cfg.chat_template == "chatml"
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
|
47 |
if not parsed_cfg.dataset_prepared_path:
|
48 |
msg = (
|
|
|
19 |
)
|
20 |
from axolotl.common.cli import PreprocessCliArgs
|
21 |
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
22 |
+
from axolotl.prompt_strategies.sharegpt import (
|
23 |
+
register_chatml_template,
|
24 |
+
register_llama3_template,
|
25 |
+
)
|
26 |
|
27 |
LOG = logging.getLogger("axolotl.cli.preprocess")
|
28 |
|
|
|
39 |
return_remaining_strings=True
|
40 |
)
|
41 |
|
42 |
+
if parsed_cfg.chat_template == "chatml":
|
43 |
+
if parsed_cfg.default_system_message:
|
44 |
+
LOG.info(
|
45 |
+
f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
|
46 |
+
)
|
47 |
+
register_chatml_template(parsed_cfg.default_system_message)
|
48 |
+
else:
|
49 |
+
register_chatml_template()
|
50 |
+
elif parsed_cfg.chat_template == "llama3":
|
51 |
+
if parsed_cfg.default_system_message:
|
52 |
+
LOG.info(
|
53 |
+
f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
|
54 |
+
)
|
55 |
+
register_llama3_template(parsed_cfg.default_system_message)
|
56 |
+
else:
|
57 |
+
register_llama3_template()
|
58 |
|
59 |
if not parsed_cfg.dataset_prepared_path:
|
60 |
msg = (
|
src/axolotl/cli/train.py
CHANGED
@@ -19,7 +19,10 @@ from axolotl.cli import (
|
|
19 |
print_axolotl_text_art,
|
20 |
)
|
21 |
from axolotl.common.cli import TrainerCliArgs
|
22 |
-
from axolotl.prompt_strategies.sharegpt import
|
|
|
|
|
|
|
23 |
from axolotl.train import train
|
24 |
|
25 |
LOG = logging.getLogger("axolotl.cli.train")
|
@@ -47,6 +50,14 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
|
47 |
else:
|
48 |
register_chatml_template()
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
if cfg.rl: # and cfg.rl != "orpo":
|
51 |
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
52 |
else:
|
|
|
19 |
print_axolotl_text_art,
|
20 |
)
|
21 |
from axolotl.common.cli import TrainerCliArgs
|
22 |
+
from axolotl.prompt_strategies.sharegpt import (
|
23 |
+
register_chatml_template,
|
24 |
+
register_llama3_template,
|
25 |
+
)
|
26 |
from axolotl.train import train
|
27 |
|
28 |
LOG = logging.getLogger("axolotl.cli.train")
|
|
|
50 |
else:
|
51 |
register_chatml_template()
|
52 |
|
53 |
+
if cfg.chat_template == "llama3" and cfg.default_system_message:
|
54 |
+
LOG.info(
|
55 |
+
f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
|
56 |
+
)
|
57 |
+
register_llama3_template(cfg.default_system_message)
|
58 |
+
else:
|
59 |
+
register_llama3_template()
|
60 |
+
|
61 |
if cfg.rl: # and cfg.rl != "orpo":
|
62 |
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
63 |
else:
|
src/axolotl/prompt_strategies/sharegpt.py
CHANGED
@@ -22,7 +22,7 @@ def register_chatml_template(system_message=None):
|
|
22 |
name="chatml",
|
23 |
system_template="<|im_start|>system\n{system_message}",
|
24 |
system_message=system_message,
|
25 |
-
roles=
|
26 |
sep_style=SeparatorStyle.CHATML,
|
27 |
sep="<|im_end|>",
|
28 |
)
|
@@ -32,13 +32,29 @@ def register_chatml_template(system_message=None):
|
|
32 |
name="chatml_glaive",
|
33 |
system_template="<|im_start|>system\n{system_message}",
|
34 |
system_message=system_message,
|
35 |
-
roles=
|
36 |
sep_style=SeparatorStyle.CHATML,
|
37 |
sep="<|im_end|>",
|
38 |
)
|
39 |
)
|
40 |
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
def build_loader(
|
43 |
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
|
44 |
prompter_cls: Type["ShareGPTPrompterV2"],
|
|
|
22 |
name="chatml",
|
23 |
system_template="<|im_start|>system\n{system_message}",
|
24 |
system_message=system_message,
|
25 |
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
26 |
sep_style=SeparatorStyle.CHATML,
|
27 |
sep="<|im_end|>",
|
28 |
)
|
|
|
32 |
name="chatml_glaive",
|
33 |
system_template="<|im_start|>system\n{system_message}",
|
34 |
system_message=system_message,
|
35 |
+
roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
|
36 |
sep_style=SeparatorStyle.CHATML,
|
37 |
sep="<|im_end|>",
|
38 |
)
|
39 |
)
|
40 |
|
41 |
|
42 |
+
def register_llama3_template(system_message=None):
|
43 |
+
system_message = system_message or "You are a helpful assistant."
|
44 |
+
register_conv_template(
|
45 |
+
Conversation(
|
46 |
+
name="llama3",
|
47 |
+
system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
|
48 |
+
system_message=system_message,
|
49 |
+
roles=("user", "assistant"),
|
50 |
+
sep_style=SeparatorStyle.LLAMA3,
|
51 |
+
sep="",
|
52 |
+
stop_str="<|eot_id|>",
|
53 |
+
stop_token_ids=[128001, 128009],
|
54 |
+
)
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
def build_loader(
|
59 |
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
|
60 |
prompter_cls: Type["ShareGPTPrompterV2"],
|
src/axolotl/prompters.py
CHANGED
@@ -263,6 +263,7 @@ CONVERSATION_ROLE_FORMAT = {
|
|
263 |
"chatml": "<|im_start|>{ROLE}",
|
264 |
"zephyr": "<|{ROLE}|>",
|
265 |
"vicuna_v1.1": "{ROLE}",
|
|
|
266 |
}
|
267 |
|
268 |
|
|
|
263 |
"chatml": "<|im_start|>{ROLE}",
|
264 |
"zephyr": "<|{ROLE}|>",
|
265 |
"vicuna_v1.1": "{ROLE}",
|
266 |
+
"llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
|
267 |
}
|
268 |
|
269 |
|
src/axolotl/utils/chat_templates.py
CHANGED
@@ -24,6 +24,7 @@ def chat_templates(user_choice: str):
|
|
24 |
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
25 |
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
26 |
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
|
|
27 |
}
|
28 |
|
29 |
if user_choice in templates:
|
|
|
24 |
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
25 |
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
26 |
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
27 |
+
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ eos_token }}{% endif %}",
|
28 |
}
|
29 |
|
30 |
if user_choice in templates:
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
@@ -143,6 +143,7 @@ class ChatTemplate(str, Enum):
|
|
143 |
inst = "inst" # pylint: disable=invalid-name
|
144 |
gemma = "gemma" # pylint: disable=invalid-name
|
145 |
cohere = "cohere" # pylint: disable=invalid-name
|
|
|
146 |
|
147 |
|
148 |
class LoftQConfig(BaseModel):
|
|
|
143 |
inst = "inst" # pylint: disable=invalid-name
|
144 |
gemma = "gemma" # pylint: disable=invalid-name
|
145 |
cohere = "cohere" # pylint: disable=invalid-name
|
146 |
+
llama3 = "llama3" # pylint: disable=invalid-name
|
147 |
|
148 |
|
149 |
class LoftQConfig(BaseModel):
|
tests/prompt_strategies/test_sharegpt.py
CHANGED
@@ -12,10 +12,12 @@ from axolotl.prompt_strategies.sharegpt import (
|
|
12 |
GlaiveShareGPTPromptTokenizingStrategy,
|
13 |
SimpleShareGPTPromptTokenizingStrategy,
|
14 |
register_chatml_template,
|
|
|
15 |
)
|
16 |
from axolotl.prompters import ShareGPTPrompterV2
|
17 |
|
18 |
register_chatml_template()
|
|
|
19 |
|
20 |
|
21 |
@pytest.fixture(name="sharegpt_dataset")
|
@@ -115,7 +117,53 @@ def fixture_tokenizer():
|
|
115 |
return tokenizer
|
116 |
|
117 |
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
"""
|
120 |
Test class for sharegpt prompter
|
121 |
"""
|
|
|
12 |
GlaiveShareGPTPromptTokenizingStrategy,
|
13 |
SimpleShareGPTPromptTokenizingStrategy,
|
14 |
register_chatml_template,
|
15 |
+
register_llama3_template,
|
16 |
)
|
17 |
from axolotl.prompters import ShareGPTPrompterV2
|
18 |
|
19 |
register_chatml_template()
|
20 |
+
register_llama3_template()
|
21 |
|
22 |
|
23 |
@pytest.fixture(name="sharegpt_dataset")
|
|
|
117 |
return tokenizer
|
118 |
|
119 |
|
120 |
+
@pytest.fixture(name="llama3_tokenizer")
|
121 |
+
def fixture_llama3_tokenizer():
|
122 |
+
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
|
123 |
+
tokenizer.eos_token = "<|eot_id|>"
|
124 |
+
|
125 |
+
return tokenizer
|
126 |
+
|
127 |
+
|
128 |
+
class TestSharegptLlama3:
|
129 |
+
"""Test class for ShareGPT style datasets with llama-3 prompts"""
|
130 |
+
|
131 |
+
def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
|
132 |
+
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
133 |
+
ShareGPTPrompterV2(
|
134 |
+
conversation="llama3",
|
135 |
+
role_key_model=None,
|
136 |
+
role_key_human=None,
|
137 |
+
),
|
138 |
+
llama3_tokenizer,
|
139 |
+
False, # train_on_inputs
|
140 |
+
2048, # sequence_len
|
141 |
+
)
|
142 |
+
|
143 |
+
dataset_wrapper = TokenizedPromptDataset(
|
144 |
+
strategy, sharegpt_dataset, process_count=1
|
145 |
+
)
|
146 |
+
|
147 |
+
input_ids = dataset_wrapper[0]["input_ids"]
|
148 |
+
|
149 |
+
# fmt: off
|
150 |
+
assert input_ids == [
|
151 |
+
128000, # bos
|
152 |
+
128006, 9125, 128007, # system header
|
153 |
+
271, 31724, 128009, # sys prompt, eot
|
154 |
+
128006, 882, 128007, # user header
|
155 |
+
271, 15339, 128009, # user prompt eot
|
156 |
+
128006, 78191, 128007, # assistant header
|
157 |
+
271, 15339, 128009, # assistant response eot
|
158 |
+
128006, 882, 128007,
|
159 |
+
271, 19045, 29474, 128009,
|
160 |
+
128006, 78191, 128007,
|
161 |
+
271, 19045, 29474, 128009,
|
162 |
+
]
|
163 |
+
# fmt: on
|
164 |
+
|
165 |
+
|
166 |
+
class TestSharegptChatML:
|
167 |
"""
|
168 |
Test class for sharegpt prompter
|
169 |
"""
|