mhenrichsen committed · Commit f8ae59b · Parent(s): 4f4d638

Adds chat templates (#1022)
Browse files:
- README.md +3 -0
- src/axolotl/utils/chat_templates.py +29 -0
- src/axolotl/utils/models.py +7 -0
README.md CHANGED

@@ -589,6 +589,9 @@ datasets:
 # For `completion` datasets only, uses the provided field instead of `text` column
 field:
 
+# Saves the desired chat template to the tokenizer_config.json for easier inferencing
+# Currently supports chatml and inst (mistral/mixtral)
+chat_template: chatml
 # Axolotl attempts to save the dataset as an arrow after packing the data together so
 # subsequent training attempts load faster, relative path
 dataset_prepared_path: data/last_run_prepared
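The new README comment promises "easier inferencing": because the template is written into tokenizer_config.json, an inference script no longer needs to hard-code the prompt format. A minimal sketch of that payoff, assuming a hypothetical output directory ./outputs/my-run produced by a run trained with chat_template: chatml (the path is illustrative):

from transformers import AutoTokenizer

# Load the tokenizer saved by the training run; the chat template rides along
# in its tokenizer_config.json.
tokenizer = AutoTokenizer.from_pretrained("./outputs/my-run")

messages = [
    {"role": "user", "content": "Hello!"},
]

# apply_chat_template reads the stored template, so the ChatML markup does not
# have to be reproduced by hand in the inference script.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)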
src/axolotl/utils/chat_templates.py ADDED

@@ -0,0 +1,29 @@
+"""
+This module provides functionality for selecting chat templates based on user choices.
+These templates are used for formatting messages in a conversation.
+"""
+
+
+def chat_templates(user_choice: str):
+    """
+    Finds the correct chat_template for the tokenizer_config.
+
+    Args:
+        user_choice (str): The user's choice of template.
+
+    Returns:
+        str: The chosen template string.
+
+    Raises:
+        ValueError: If the user_choice is not found in the templates.
+    """
+
+    templates = {
+        "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # I don't know what this one is called. Used by Mistral/Mixtral.
+        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+    }
+
+    if user_choice in templates:
+        return templates[user_choice]
+
+    raise ValueError(f"Template '{user_choice}' not found.")
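To see what a template string from this module expands to, it can be rendered directly with Jinja2, the same engine transformers uses for chat templates. A minimal sketch, assuming axolotl and jinja2 are importable:

from jinja2 import Template

from axolotl.utils.chat_templates import chat_templates

# Render the "chatml" template against a short conversation.
rendered = Template(chat_templates("chatml")).render(
    messages=[
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi there."},
    ],
    add_generation_prompt=True,
)
print(rendered)
# <|im_start|>user
# Hello!<|im_end|>
# <|im_start|>assistant
# Hi there.<|im_end|>
# <|im_start|>assistant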
src/axolotl/utils/models.py CHANGED

@@ -26,6 +26,7 @@ from transformers.deepspeed import is_deepspeed_zero3_enabled
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
+from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.dict import DictDefault
 
 LOG = logging.getLogger("axolotl")

@@ -186,6 +187,12 @@ def load_tokenizer(cfg):
     LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
     LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
 
+    if cfg.chat_template:
+        tokenizer.chat_template = chat_templates(cfg.chat_template)
+    else:
+        LOG.info(
+            "No Chat template selected. Consider adding a chat template for easier inference."
+        )
     return tokenizer
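The net effect of the load_tokenizer change: the selected template is attached to the tokenizer object, so a later save_pretrained writes it into tokenizer_config.json. A standalone sketch of the same mechanism, assuming transformers and axolotl are installed; "NousResearch/Llama-2-7b-hf" and the output path are only examples:

from transformers import AutoTokenizer

from axolotl.utils.chat_templates import chat_templates

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")

# Mirrors what load_tokenizer now does when cfg.chat_template is set.
tokenizer.chat_template = chat_templates("chatml")

# Persist it: tokenizer_config.json in the target directory now carries a
# "chat_template" entry that inference tooling picks up automatically.
tokenizer.save_pretrained("./tokenizer-with-chatml")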