|
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class""" |
|
from typing import Any, Dict, Optional |
|
|
|
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template |
|
|
|
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy |
|
from axolotl.prompters import ShareGPTPrompterV2 |
|
|
|
|
|
def register_chatml_template(system_message=None):
    """
    Register a fastchat conversation template under the name "chatml".

    A falsy ``system_message`` (``None`` or an empty string) is replaced by
    the stock "You are a helpful assistant." prompt before registration.
    """
    if not system_message:
        system_message = "You are a helpful assistant."

    chatml_template = Conversation(
        name="chatml",
        system_template="<|im_start|>system\n{system_message}",
        system_message=system_message,
        roles=["<|im_start|>user", "<|im_start|>assistant"],
        sep_style=SeparatorStyle.CHATML,
        sep="<|im_end|>",
    )
    register_conv_template(chatml_template)
|
|
|
|
|
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """
    Build a SimpleShareGPTPromptTokenizingStrategy for a dataset.

    The optional ``ds_cfg`` mapping may carry ``conversation``,
    ``field_human`` and ``field_model``, which are forwarded to the prompter
    (``None`` when absent), and ``strict``, which is copied onto the returned
    strategy only when the key is present.
    """
    # Treat a missing or empty dataset config as "no options given".
    options = ds_cfg if ds_cfg else {}

    prompter = ShareGPTPrompterV2(
        conversation=options.get("conversation"),
        role_key_model=options.get("field_model"),
        role_key_human=options.get("field_human"),
    )
    strategy = SimpleShareGPTPromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    if "strict" in options:
        strategy.strict = options["strict"]
    return strategy
|
|
|
|
|
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
    """
    Build an UltrachatShareGPTPromptTokenizingStrategy for a dataset.

    ``ds_cfg["conversation"]`` (``None`` when absent) is forwarded to the
    prompter, and ``ds_cfg["strict"]`` is copied onto the returned strategy
    only when the key is present.
    """
    # Treat a missing or empty dataset config as "no options given".
    options = ds_cfg if ds_cfg else {}

    prompter = ShareGPTPrompterV2(
        conversation=options.get("conversation"),
    )
    strategy = UltrachatShareGPTPromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    if "strict" in options:
        strategy.strict = options["strict"]
    return strategy
|
|
|
|
|
def load_role(tokenizer, cfg):
    """
    Build a SimpleRoleShareGPTPromptTokenizingStrategy.

    Uses a default-constructed ShareGPTPrompterV2 together with the
    ``train_on_inputs`` and ``sequence_len`` settings read from ``cfg``.
    """
    prompter = ShareGPTPrompterV2()
    strategy_args = (
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    return SimpleRoleShareGPTPromptTokenizingStrategy(*strategy_args)
|
|
|
|
|
def load_guanaco(tokenizer, cfg):
    """
    Build a GuanacoShareGPTPromptTokenizingStrategy.

    Uses a default-constructed ShareGPTPrompterV2 together with the
    ``train_on_inputs`` and ``sequence_len`` settings read from ``cfg``.
    """
    prompter = ShareGPTPrompterV2()
    strategy_args = (
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
    return GuanacoShareGPTPromptTokenizingStrategy(*strategy_args)
|
|
|
|
|
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
    """
    sharegpt strategy that reads the turns from the row's "conversations" field
    """

    # Class-level default; the setter below stores a per-instance override.
    _strict = False

    @property
    def strict(self):
        """When truthy, get_conversation_thread returns the raw turns unchanged."""
        return self._strict

    @strict.setter
    def strict(self, strict):
        self._strict = strict

    def get_conversation_thread(self, prompt):
        """
        Return the turns of ``prompt["conversations"]`` in from/value form.

        In strict mode the list is handed back untouched. Otherwise the key
        names are detected from the first turn only ("role" is preferred over
        "from"; "text", then "content", are preferred over "value") and each
        role name is translated through a fixed map onto human / gpt / system.
        A role missing from that map raises KeyError.
        """
        raw_turns = prompt["conversations"]
        if self.strict:
            return raw_turns

        # The first turn decides which key names apply to the whole thread.
        first_turn = raw_turns[0]
        role_key = "role" if "role" in first_turn else "from"
        if "text" in first_turn:
            value_key = "text"
        elif "content" in first_turn:
            value_key = "content"
        else:
            value_key = "value"

        canonical_roles = {
            "user": "human",
            "human": "human",
            "assistant": "gpt",
            "gpt": "gpt",
            "system": "system",
        }
        return [
            {"from": canonical_roles[turn[role_key]], "value": turn[value_key]}
            for turn in raw_turns
        ]
|
|
|
|
|
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
    """
    sharegpt strategy for rows whose turns name the speaker under "role" rather than "from"
    """

    def get_conversation_thread(self, prompt):
        """
        Return the turns of ``prompt["conversations"]`` with "role" re-keyed as "from".

        The role string and the "value" text are copied through unchanged.
        """
        rekeyed = []
        for message in prompt["conversations"]:
            rekeyed.append({"from": message["role"], "value": message["value"]})
        return rekeyed
|
|
|
|
|
class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
    """
    sharegpt strategy that converts oasst-style rows into sharegpt turns
    """

    def get_conversation_thread(self, prompt):
        """
        Return the turns of ``prompt["conversations"]`` in from/value form.

        Each turn's "role" ("prompter" or "assistant") is mapped to "human" or
        "gpt", and its "text" becomes the "value". Any other role raises
        KeyError.
        """
        oasst_to_sharegpt = {"prompter": "human", "assistant": "gpt"}
        return [
            {"from": oasst_to_sharegpt[message["role"]], "value": message["text"]}
            for message in prompt["conversations"]
        ]
|
|
|
|
|
class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
    """
    sharegpt strategy that converts ultrachat-style rows into sharegpt turns
    """

    def get_conversation_thread(self, prompt):
        """
        Return the turns of ``prompt["messages"]`` in from/value form.

        Each message's "role" ("user" or "assistant") is mapped to "human" or
        "gpt", and its "content" becomes the "value". Any other role raises
        KeyError. This override never reads the ``strict`` flag.
        """
        ultrachat_to_sharegpt = {"user": "human", "assistant": "gpt"}
        turns = []
        for message in prompt["messages"]:
            turns.append(
                {
                    "from": ultrachat_to_sharegpt[message["role"]],
                    "value": message["content"],
                }
            )
        return turns
|
|