|
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class""" |
|
|
|
import logging |
|
from typing import Any, Dict, Optional |
|
|
|
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template |
|
|
|
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy |
|
from axolotl.prompters import ShareGPTPrompterV2 |
|
from axolotl.utils.tokenization import ( |
|
chatml_to_conversation, |
|
merge_consecutive_messages, |
|
) |
|
|
|
# Module logger; it is looked up under the shared "axolotl" name rather than
# this module's __name__. NOTE(review): not referenced by any code in this file.
LOG = logging.getLogger(name="axolotl")
|
|
|
|
|
def register_chatml_template(system_message=None):
    """Register the ``chatml`` and ``chatml_glaive`` fastchat conversation templates.

    Both templates use ``<|im_start|>system\\n{system_message}`` as the system
    template, ``SeparatorStyle.CHATML`` and ``<|im_end|>`` as separator.
    ``chatml_glaive`` adds a third ``<|im_start|>tool`` role.

    Args:
        system_message: System prompt stored on both templates. Any falsy value
            (``None`` or ``""``) falls back to ``"You are a helpful assistant."``.

    Both templates are registered with ``override=True``. Calling this function
    again (for example with a different system message) therefore replaces the
    earlier registration instead of failing fastchat's "already registered"
    assertion.
    """
    system_message = system_message or "You are a helpful assistant."

    # Template name -> roles; registration order matches the original
    # (chatml first, then chatml_glaive).
    roles_by_template = {
        "chatml": ["<|im_start|>user", "<|im_start|>assistant"],
        "chatml_glaive": [
            "<|im_start|>user",
            "<|im_start|>assistant",
            "<|im_start|>tool",
        ],
    }

    for template_name, roles in roles_by_template.items():
        register_conv_template(
            Conversation(
                name=template_name,
                system_template="<|im_start|>system\n{system_message}",
                system_message=system_message,
                roles=roles,
                sep_style=SeparatorStyle.CHATML,
                sep="<|im_end|>",
            ),
            override=True,
        )
|
|
|
|
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """Build a SimpleShareGPTPromptTokenizingStrategy from a dataset config.

    Args:
        tokenizer: Passed straight through to the strategy.
        cfg: Training config; ``cfg.train_on_inputs`` and ``cfg.sequence_len``
            are read.
        ds_cfg: Optional per-dataset options. The keys read are
            ``conversation``, ``field_human``, ``field_model``, ``roles`` and
            ``strict``. Any key that is absent becomes ``None``, except
            ``strict``, which keeps the strategy's default.

    Returns:
        The configured SimpleShareGPTPromptTokenizingStrategy.
    """
    ds_cfg = ds_cfg or {}

    roles = ds_cfg.get("roles")
    # Config objects may provide to_dict() (the previous code always called it).
    # Plain dicts do not, so convert only when the method exists and pass
    # anything else through unchanged.
    if roles is not None and hasattr(roles, "to_dict"):
        roles = roles.to_dict()

    strategy = SimpleShareGPTPromptTokenizingStrategy(
        ShareGPTPrompterV2(
            conversation=ds_cfg.get("conversation"),
            role_key_model=ds_cfg.get("field_model"),
            role_key_human=ds_cfg.get("field_human"),
            roles=roles,
        ),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    if "strict" in ds_cfg:
        strategy.strict = ds_cfg["strict"]
    return strategy
|
|
|
|
|
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """Build an UltrachatShareGPTPromptTokenizingStrategy.

    ``ds_cfg["conversation"]``, when present, is passed to the prompter as the
    conversation template name. ``ds_cfg["strict"]``, when present, is copied
    onto the strategy.
    """
    options = ds_cfg if ds_cfg else {}
    template_name = options["conversation"] if "conversation" in options else None

    prompter = ShareGPTPrompterV2(conversation=template_name)
    strategy = UltrachatShareGPTPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
    if "strict" in options:
        strategy.strict = options["strict"]
    return strategy
|
|
|
|
|
def load_role(tokenizer, cfg):
    """Build a SimpleRoleShareGPTPromptTokenizingStrategy with a default prompter."""
    prompter = ShareGPTPrompterV2()
    return SimpleRoleShareGPTPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
|
|
|
|
|
def load_guanaco(tokenizer, cfg):
    """Build a GuanacoShareGPTPromptTokenizingStrategy with a default prompter."""
    prompter = ShareGPTPrompterV2()
    return GuanacoShareGPTPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
|
|
|
|
|
def load_glaive(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """Build a GlaiveShareGPTPromptTokenizingStrategy.

    The conversation template defaults to ``"chatml_glaive"``. A
    ``ds_cfg["conversation"]`` entry overrides it. No other ``ds_cfg`` key is
    read here.
    """
    template_name = "chatml_glaive"
    if ds_cfg and "conversation" in ds_cfg:
        template_name = ds_cfg["conversation"]

    prompter = ShareGPTPrompterV2(conversation=template_name)
    return GlaiveShareGPTPromptTokenizingStrategy(
        prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )
|
|
|
|
|
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
    """
    basic sharegpt strategy to grab conversations from the sample row

    Each row is read through its ``conversations`` list. With ``strict``
    enabled, that list is returned untouched. Otherwise every turn is
    normalised to ``{"from": ..., "value": ...}``. The role and text keys are
    detected from the first turn, and common role names are mapped onto the
    sharegpt vocabulary.
    """

    _strict = False

    # Role names rewritten to sharegpt names; any other role passes through as-is.
    _ROLE_MAP = {
        "user": "human",
        "human": "human",
        "assistant": "gpt",
        "gpt": "gpt",
        "system": "system",
    }

    @property
    def strict(self):
        """When True, ``get_conversation_thread`` returns the turns unmodified."""
        return self._strict

    @strict.setter
    def strict(self, strict):
        self._strict = strict

    def get_conversation_thread(self, prompt):
        """Return the row's turns as a list of ``{"from", "value"}`` dicts.

        Args:
            prompt: A sample row that has a ``conversations`` list of turn dicts.

        Returns:
            The raw ``conversations`` list when ``strict`` is set or the list is
            empty. Otherwise, a new list of normalised turns.
        """
        conversations = prompt["conversations"]
        # An empty thread has nothing to normalise. Handling it here avoids the
        # IndexError from conversations[0] and matches what strict mode returns.
        if self.strict or not conversations:
            return conversations

        # Key names are detected from the first turn only and then applied to
        # every turn.
        first_turn = conversations[0]
        role_key = "role" if "role" in first_turn else "from"
        if "text" in first_turn:
            value_key = "text"
        elif "content" in first_turn:
            value_key = "content"
        else:
            value_key = "value"

        role_map = self._ROLE_MAP
        return [
            {
                "from": role_map.get(turn[role_key], turn[role_key]),
                "value": turn[value_key],
            }
            for turn in conversations
        ]
|
|
|
|
|
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
    """
    sharegpt strategy for rows whose turns name the speaker under ``role``
    (instead of ``from``) and keep the text under ``value``
    """

    def get_conversation_thread(self, prompt):
        """Copy each turn's ``role`` into ``from`` and keep its ``value``."""
        return [
            {"from": message["role"], "value": message["value"]}
            for message in prompt["conversations"]
        ]
|
|
|
|
|
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
    """
    sharegpt strategy that rewrites oasst-style turns (``role``/``text`` keys,
    ``prompter``/``assistant`` roles) into sharegpt ``from``/``value`` turns
    """

    def get_conversation_thread(self, prompt):
        """Map ``prompter`` -> ``human`` and ``assistant`` -> ``gpt``.

        Any other role raises KeyError.
        """
        messages = prompt["conversations"]
        speaker_names = {"prompter": "human", "assistant": "gpt"}
        return [
            {"from": speaker_names[message["role"]], "value": message["text"]}
            for message in messages
        ]
|
|
|
|
|
class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
    """
    sharegpt strategy that rewrites ultrachat rows (a ``messages`` list with
    ``role``/``content`` keys) into sharegpt ``from``/``value`` turns

    The inherited ``strict`` flag is not consulted by this override.
    """

    def get_conversation_thread(self, prompt):
        """Map ``user`` -> ``human`` and ``assistant`` -> ``gpt``.

        Any other role raises KeyError.
        """
        messages = prompt["messages"]
        speaker_names = {"user": "human", "assistant": "gpt"}
        return [
            {"from": speaker_names[message["role"]], "value": message["content"]}
            for message in messages
        ]
|
|
|
|
|
class GlaiveShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
    """
    sharegpt strategy that rewrites glaive rows into sharegpt turns

    The inherited ``strict`` flag is not consulted by this override.
    """

    def get_conversation_thread(self, prompt):
        """Convert the row with ``chatml_to_conversation``, then pass the result
        through ``merge_consecutive_messages``.

        NOTE(review): presumably the merge collapses adjacent turns from the
        same speaker; confirm against ``axolotl.utils.tokenization``.
        """
        return merge_consecutive_messages(chatml_to_conversation(prompt))
|
|