|
""" |
|
User Defined prompts with configuration from the YML config |
|
""" |
|
|
|
from dataclasses import dataclass |
|
from functools import partial |
|
from typing import Optional, Tuple |
|
|
|
from axolotl.prompt_strategies.alpaca_w_system import ( |
|
InstructionWSystemPromptTokenizingStrategy, |
|
SystemDataPrompter, |
|
) |
|
|
|
|
|
@dataclass |
|
class UserDefinedDatasetConfig: |
|
""" |
|
dataclass configuration representing a userdefined dataset type |
|
""" |
|
|
|
system_prompt: str = "" |
|
field_system: str = "system" |
|
field_instruction: str = "instruction" |
|
field_input: str = "input" |
|
field_output: str = "output" |
|
format: str = "{instruction} {input} " |
|
no_input_format: str = "{instruction} " |
|
system_format: str = "{system}" |
|
|
|
def __getitem__(self, item): |
|
return getattr(self, item) |
|
|
|
|
|
class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy): |
|
""" |
|
Prompt Tokenization Strategy for user defined prompts |
|
""" |
|
|
|
|
|
def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None): |
|
if not ds_cfg: |
|
raise ValueError("Missing dataset prompt configuration") |
|
|
|
system_prompt = "" |
|
if ds_cfg.system_prompt: |
|
system_prompt = ds_cfg.system_prompt |
|
|
|
def parse_instruction_fields( |
|
field_instruction, |
|
field_input, |
|
field_output, |
|
field_system, |
|
system_prompt, |
|
prompt, |
|
) -> Tuple[str, str, str, str]: |
|
return ( |
|
prompt[field_instruction], |
|
prompt[field_input] if field_input in prompt else "", |
|
prompt[field_output] if field_output in prompt else "", |
|
prompt[field_system] if field_system in prompt else system_prompt, |
|
) |
|
|
|
turn_format = ds_cfg.format |
|
turn_no_input_format = ds_cfg.no_input_format |
|
system_format = ds_cfg.system_format |
|
|
|
class UserDefinedPrompter(SystemDataPrompter): |
|
""" |
|
Prompter for user defined prompts |
|
""" |
|
|
|
def match_prompt_style(self): |
|
self.turn_format = turn_format |
|
self.turn_no_input_format = turn_no_input_format |
|
self.system_format = system_format |
|
|
|
prompter = UserDefinedPrompter() |
|
|
|
strat = UserDefinedPromptTokenizationStrategy( |
|
prompter, |
|
tokenizer, |
|
cfg.train_on_inputs, |
|
cfg.sequence_len, |
|
) |
|
|
|
setattr( |
|
strat, |
|
"parse_instruction_fields", |
|
partial( |
|
parse_instruction_fields, |
|
ds_cfg.field_instruction, |
|
ds_cfg.field_input, |
|
ds_cfg.field_output, |
|
ds_cfg.field_system, |
|
system_prompt, |
|
), |
|
) |
|
return strat |
|
|