Ram and winglian committed
Commit 50421c8
Parent: b32c08f

feat: Add LLaMA-3 instruct prompt strategies for fine-tuning (#1553)


* Add prompt strategies

* Update modified URL

* Update modified URL

* Update fastchat_conversation_turns.py

* Update register function

* Remove extra \n for system prompt

* Fix return

* Fix BOS

* Update requirements, pylint

* Linting

* Linting

* fix tuples, make sure to set system message in template

* tests for llama3 tokenization

* fix conditionals for loading chat template

---------

Co-authored-by: Ram <ram@Rams-MacBook-Pro.local>
Co-authored-by: Wing Lian <wing.lian@gmail.com>

src/axolotl/cli/preprocess.py CHANGED
@@ -19,7 +19,10 @@ from axolotl.cli import (
 )
 from axolotl.common.cli import PreprocessCliArgs
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
-from axolotl.prompt_strategies.sharegpt import register_chatml_template
+from axolotl.prompt_strategies.sharegpt import (
+    register_chatml_template,
+    register_llama3_template,
+)
 
 LOG = logging.getLogger("axolotl.cli.preprocess")
 
@@ -36,13 +39,22 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         return_remaining_strings=True
     )
 
-    if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
-        LOG.info(
-            f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
-        )
-        register_chatml_template(parsed_cfg.default_system_message)
-    else:
-        register_chatml_template()
+    if parsed_cfg.chat_template == "chatml":
+        if parsed_cfg.default_system_message:
+            LOG.info(
+                f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
+            )
+            register_chatml_template(parsed_cfg.default_system_message)
+        else:
+            register_chatml_template()
+    elif parsed_cfg.chat_template == "llama3":
+        if parsed_cfg.default_system_message:
+            LOG.info(
+                f"LLaMA-3 set. Adding default system message: {parsed_cfg.default_system_message}"
+            )
+            register_llama3_template(parsed_cfg.default_system_message)
+        else:
+            register_llama3_template()
 
     if not parsed_cfg.dataset_prepared_path:
         msg = (
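
With this change, the llama3 template is selectable from a standard axolotl YAML config. A minimal excerpt as a sketch, using only the two keys the diff actually reads (chat_template and default_system_message); everything else in the config is elided:

    chat_template: llama3
    default_system_message: You are a helpful assistant.
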
src/axolotl/cli/train.py CHANGED
@@ -19,7 +19,10 @@ from axolotl.cli import (
     print_axolotl_text_art,
 )
 from axolotl.common.cli import TrainerCliArgs
-from axolotl.prompt_strategies.sharegpt import register_chatml_template
+from axolotl.prompt_strategies.sharegpt import (
+    register_chatml_template,
+    register_llama3_template,
+)
 from axolotl.train import train
 
 LOG = logging.getLogger("axolotl.cli.train")
@@ -47,6 +50,14 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     else:
         register_chatml_template()
 
+    if cfg.chat_template == "llama3" and cfg.default_system_message:
+        LOG.info(
+            f"LLaMA-3 set. Adding default system message: {cfg.default_system_message}"
+        )
+        register_llama3_template(cfg.default_system_message)
+    else:
+        register_llama3_template()
+
     if cfg.rl:  # and cfg.rl != "orpo":
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
src/axolotl/prompt_strategies/sharegpt.py CHANGED
@@ -22,7 +22,7 @@ def register_chatml_template(system_message=None):
             name="chatml",
             system_template="<|im_start|>system\n{system_message}",
             system_message=system_message,
-            roles=["<|im_start|>user", "<|im_start|>assistant"],
+            roles=("<|im_start|>user", "<|im_start|>assistant"),
             sep_style=SeparatorStyle.CHATML,
             sep="<|im_end|>",
         )
@@ -32,13 +32,29 @@ def register_chatml_template(system_message=None):
             name="chatml_glaive",
             system_template="<|im_start|>system\n{system_message}",
             system_message=system_message,
-            roles=["<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"],
+            roles=("<|im_start|>user", "<|im_start|>assistant", "<|im_start|>tool"),
             sep_style=SeparatorStyle.CHATML,
             sep="<|im_end|>",
         )
     )
 
 
+def register_llama3_template(system_message=None):
+    system_message = system_message or "You are a helpful assistant."
+    register_conv_template(
+        Conversation(
+            name="llama3",
+            system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>",
+            system_message=system_message,
+            roles=("user", "assistant"),
+            sep_style=SeparatorStyle.LLAMA3,
+            sep="",
+            stop_str="<|eot_id|>",
+            stop_token_ids=[128001, 128009],
+        )
+    )
+
+
 def build_loader(
     tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
     prompter_cls: Type["ShareGPTPrompterV2"],
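
For reference, a minimal sketch (not part of this commit) of how the newly registered conversation template renders a prompt. It assumes axolotl's pinned fastchat fork, which is what provides SeparatorStyle.LLAMA3:

    from fastchat.conversation import get_conv_template

    from axolotl.prompt_strategies.sharegpt import register_llama3_template

    register_llama3_template()  # registers "llama3" with the default system message

    conv = get_conv_template("llama3")
    conv.append_message(conv.roles[0], "hello")  # user turn
    conv.append_message(conv.roles[1], None)  # open assistant slot -> generation prompt
    print(conv.get_prompt())
    # Roughly (exact separators depend on the fork's LLAMA3 style):
    # <|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|>
    # <|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|>
    # <|start_header_id|>assistant<|end_header_id|>\n\n
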
src/axolotl/prompters.py CHANGED
@@ -263,6 +263,7 @@ CONVERSATION_ROLE_FORMAT = {
     "chatml": "<|im_start|>{ROLE}",
     "zephyr": "<|{ROLE}|>",
     "vicuna_v1.1": "{ROLE}",
+    "llama3": "<|start_header_id|>{ROLE}<|end_header_id|>",
 }
 
 
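This mapping lets the prompter expand a raw role name into a llama-3 turn header. A quick sketch of the string expansion (hypothetical usage, not in the diff):

    from axolotl.prompters import CONVERSATION_ROLE_FORMAT

    header = CONVERSATION_ROLE_FORMAT["llama3"].format(ROLE="assistant")
    assert header == "<|start_header_id|>assistant<|end_header_id|>"
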
src/axolotl/utils/chat_templates.py CHANGED
@@ -24,6 +24,7 @@ def chat_templates(user_choice: str):
         "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
         "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
         "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
+        "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% else %}{{ eos_token }}{% endif %}",
     }
 
     if user_choice in templates:
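
The new Jinja template can be exercised directly through a tokenizer. A sketch, assuming chat_templates() returns the raw template string (as the surrounding code suggests) and that the NousResearch/Meta-Llama-3-8B tokenizer is available:

    from transformers import AutoTokenizer

    from axolotl.utils.chat_templates import chat_templates

    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
    tokenizer.chat_template = chat_templates("llama3")

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "hello"},
    ]
    # add_generation_prompt=True appends the assistant header instead of eos
    print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
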
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -143,6 +143,7 @@ class ChatTemplate(str, Enum):
     inst = "inst"  # pylint: disable=invalid-name
     gemma = "gemma"  # pylint: disable=invalid-name
     cohere = "cohere"  # pylint: disable=invalid-name
+    llama3 = "llama3"  # pylint: disable=invalid-name
 
 
 class LoftQConfig(BaseModel):
tests/prompt_strategies/test_sharegpt.py CHANGED
@@ -12,10 +12,12 @@ from axolotl.prompt_strategies.sharegpt import (
     GlaiveShareGPTPromptTokenizingStrategy,
     SimpleShareGPTPromptTokenizingStrategy,
     register_chatml_template,
+    register_llama3_template,
 )
 from axolotl.prompters import ShareGPTPrompterV2
 
 register_chatml_template()
+register_llama3_template()
 
 
 @pytest.fixture(name="sharegpt_dataset")
@@ -115,7 +117,53 @@ def fixture_tokenizer():
     return tokenizer
 
 
-class TestSharegpt:
+@pytest.fixture(name="llama3_tokenizer")
+def fixture_llama3_tokenizer():
+    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
+    tokenizer.eos_token = "<|eot_id|>"
+
+    return tokenizer
+
+
+class TestSharegptLlama3:
+    """Test class for ShareGPT style datasets with llama-3 prompts"""
+
+    def test_tokenization(self, sharegpt_dataset, llama3_tokenizer):
+        strategy = SimpleShareGPTPromptTokenizingStrategy(
+            ShareGPTPrompterV2(
+                conversation="llama3",
+                role_key_model=None,
+                role_key_human=None,
+            ),
+            llama3_tokenizer,
+            False,  # train_on_inputs
+            2048,  # sequence_len
+        )
+
+        dataset_wrapper = TokenizedPromptDataset(
+            strategy, sharegpt_dataset, process_count=1
+        )
+
+        input_ids = dataset_wrapper[0]["input_ids"]
+
+        # fmt: off
+        assert input_ids == [
+            128000,  # bos
+            128006, 9125, 128007,  # system header
+            271, 31724, 128009,  # sys prompt, eot
+            128006, 882, 128007,  # user header
+            271, 15339, 128009,  # user prompt eot
+            128006, 78191, 128007,  # assistant header
+            271, 15339, 128009,  # assistant response eot
+            128006, 882, 128007,
+            271, 19045, 29474, 128009,
+            128006, 78191, 128007,
+            271, 19045, 29474, 128009,
+        ]
+        # fmt: on
+
+
+class TestSharegptChatML:
     """
     Test class for sharegpt prompter
     """