support user defined prompters, pretokenized datasets in config, local parquet, local arrow files (#348)
Browse files* support user defined prompters, pretokenized datasets in config, local parquet, local arrow files
* fix user defined dataset types
* fix for system prompts
* fix tests
* fix checks for parquet and arrow
* aha moment that d.data_files isn't used
* add documentation for ds_type to add support for parquet and arrow
README.md
CHANGED
@@ -392,6 +392,7 @@ datasets:
|
|
392 |
- path: vicgalle/alpaca-gpt4
|
393 |
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
394 |
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
|
|
395 |
data_files: # path to source data files
|
396 |
shards: # number of shards to split data into
|
397 |
name: # name of dataset configuration to load
|
|
|
392 |
- path: vicgalle/alpaca-gpt4
|
393 |
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
394 |
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
395 |
+
ds_type: # Optional[str] (json|arrow|parquet) defines the datatype when path is a file
|
396 |
data_files: # path to source data files
|
397 |
shards: # number of shards to split data into
|
398 |
name: # name of dataset configuration to load
|
src/axolotl/prompt_strategies/__init__.py
CHANGED
@@ -2,8 +2,10 @@
|
|
2 |
|
3 |
import importlib
|
4 |
|
|
|
5 |
|
6 |
-
|
|
|
7 |
try:
|
8 |
load_fn = "load"
|
9 |
if strategy.split(".")[-1].startswith("load_"):
|
@@ -11,6 +13,9 @@ def load(strategy, tokenizer, cfg):
|
|
11 |
strategy = ".".join(strategy.split(".")[:-1])
|
12 |
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
|
13 |
func = getattr(mod, load_fn)
|
14 |
-
|
|
|
|
|
|
|
15 |
except Exception: # pylint: disable=broad-exception-caught
|
16 |
return None
|
|
|
2 |
|
3 |
import importlib
|
4 |
|
5 |
+
from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
|
6 |
|
7 |
+
|
8 |
+
def load(strategy, tokenizer, cfg, ds_cfg):
|
9 |
try:
|
10 |
load_fn = "load"
|
11 |
if strategy.split(".")[-1].startswith("load_"):
|
|
|
13 |
strategy = ".".join(strategy.split(".")[:-1])
|
14 |
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
|
15 |
func = getattr(mod, load_fn)
|
16 |
+
load_kwargs = {}
|
17 |
+
if strategy == "user_defined":
|
18 |
+
load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
|
19 |
+
return func(tokenizer, cfg, **load_kwargs)
|
20 |
except Exception: # pylint: disable=broad-exception-caught
|
21 |
return None
|
src/axolotl/prompt_strategies/alpaca_w_system.py
CHANGED
@@ -57,6 +57,8 @@ class SystemDataPrompter(AlpacaPrompter):
|
|
57 |
Alpaca Style Prompter that uses system prompts from the dataset
|
58 |
"""
|
59 |
|
|
|
|
|
60 |
def build_prompt_w_system(
|
61 |
self,
|
62 |
system: str,
|
|
|
57 |
Alpaca Style Prompter that uses system prompts from the dataset
|
58 |
"""
|
59 |
|
60 |
+
system_format: str = "### System:\n{system}\n\n"
|
61 |
+
|
62 |
def build_prompt_w_system(
|
63 |
self,
|
64 |
system: str,
|
src/axolotl/prompt_strategies/user_defined.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
User Defined prompts with configuration from the YML config
|
3 |
+
"""
|
4 |
+
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from functools import partial
|
7 |
+
from typing import Optional, Tuple
|
8 |
+
|
9 |
+
from axolotl.prompt_strategies.alpaca_w_system import (
|
10 |
+
InstructionWSystemPromptTokenizingStrategy,
|
11 |
+
SystemDataPrompter,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class UserDefinedDatasetConfig:
|
17 |
+
"""
|
18 |
+
dataclass configuration representing a userdefined dataset type
|
19 |
+
"""
|
20 |
+
|
21 |
+
system_prompt: str = ""
|
22 |
+
field_system: str = "system"
|
23 |
+
field_instruction: str = "instruction"
|
24 |
+
field_input: str = "input"
|
25 |
+
field_output: str = "output"
|
26 |
+
format: str = "{instruction} {input} "
|
27 |
+
no_input_format: str = "{instruction} "
|
28 |
+
system_format: str = "{system}"
|
29 |
+
|
30 |
+
def __getitem__(self, item):
|
31 |
+
return getattr(self, item)
|
32 |
+
|
33 |
+
|
34 |
+
class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
|
35 |
+
"""
|
36 |
+
Prompt Tokenization Strategy for user defined prompts
|
37 |
+
"""
|
38 |
+
|
39 |
+
|
40 |
+
def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None):
|
41 |
+
if not ds_cfg:
|
42 |
+
raise ValueError("Missing dataset prompt configuration")
|
43 |
+
|
44 |
+
system_prompt = ""
|
45 |
+
if ds_cfg.system_prompt:
|
46 |
+
system_prompt = ds_cfg.system_prompt
|
47 |
+
|
48 |
+
def parse_instruction_fields(
|
49 |
+
field_instruction,
|
50 |
+
field_input,
|
51 |
+
field_output,
|
52 |
+
field_system,
|
53 |
+
system_prompt,
|
54 |
+
prompt,
|
55 |
+
) -> Tuple[str, str, str, str]:
|
56 |
+
return (
|
57 |
+
prompt[field_instruction],
|
58 |
+
prompt[field_input] if field_input in prompt else "",
|
59 |
+
prompt[field_output] if field_output in prompt else "",
|
60 |
+
prompt[field_system] if field_system in prompt else system_prompt,
|
61 |
+
)
|
62 |
+
|
63 |
+
turn_format = ds_cfg.format
|
64 |
+
turn_no_input_format = ds_cfg.no_input_format
|
65 |
+
system_format = ds_cfg.system_format
|
66 |
+
|
67 |
+
class UserDefinedPrompter(SystemDataPrompter):
|
68 |
+
"""
|
69 |
+
Prompter for user defined prompts
|
70 |
+
"""
|
71 |
+
|
72 |
+
def match_prompt_style(self):
|
73 |
+
self.turn_format = turn_format
|
74 |
+
self.turn_no_input_format = turn_no_input_format
|
75 |
+
self.system_format = system_format
|
76 |
+
|
77 |
+
prompter = UserDefinedPrompter()
|
78 |
+
|
79 |
+
strat = UserDefinedPromptTokenizationStrategy(
|
80 |
+
prompter,
|
81 |
+
tokenizer,
|
82 |
+
cfg.train_on_inputs,
|
83 |
+
cfg.sequence_len,
|
84 |
+
)
|
85 |
+
|
86 |
+
setattr(
|
87 |
+
strat,
|
88 |
+
"parse_instruction_fields",
|
89 |
+
partial(
|
90 |
+
parse_instruction_fields,
|
91 |
+
ds_cfg.field_instruction,
|
92 |
+
ds_cfg.field_input,
|
93 |
+
ds_cfg.field_output,
|
94 |
+
ds_cfg.field_system,
|
95 |
+
system_prompt,
|
96 |
+
),
|
97 |
+
)
|
98 |
+
return strat
|
src/axolotl/prompters.py
CHANGED
@@ -26,7 +26,7 @@ class AlpacaPrompter:
|
|
26 |
|
27 |
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
28 |
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
29 |
-
system_format: str
|
30 |
turn_format: str
|
31 |
turn_no_input_format: str
|
32 |
prompt_style: Optional[PromptStyle] = None
|
@@ -63,13 +63,17 @@ class AlpacaPrompter:
|
|
63 |
# returns the full prompt from instruction and optional input
|
64 |
# if a label (=response, =output) is provided, it's also appended.
|
65 |
if input:
|
66 |
-
res =
|
67 |
-
|
68 |
-
|
|
|
|
|
69 |
else:
|
70 |
-
res =
|
71 |
-
|
72 |
-
|
|
|
|
|
73 |
if output:
|
74 |
res = f"{res}{output}"
|
75 |
yield res
|
|
|
26 |
|
27 |
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
28 |
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
29 |
+
system_format: str = "{system}"
|
30 |
turn_format: str
|
31 |
turn_no_input_format: str
|
32 |
prompt_style: Optional[PromptStyle] = None
|
|
|
63 |
# returns the full prompt from instruction and optional input
|
64 |
# if a label (=response, =output) is provided, it's also appended.
|
65 |
if input:
|
66 |
+
res = (
|
67 |
+
self.system_format.format(system=self.system_prompt)
|
68 |
+
if self.system_prompt
|
69 |
+
else ""
|
70 |
+
) + self.turn_format.format(instruction=instruction, input=input)
|
71 |
else:
|
72 |
+
res = (
|
73 |
+
self.system_format.format(system=self.system_no_input_prompt)
|
74 |
+
if self.system_prompt
|
75 |
+
else ""
|
76 |
+
) + self.turn_no_input_format.format(instruction=instruction)
|
77 |
if output:
|
78 |
res = f"{res}{output}"
|
79 |
yield res
|
src/axolotl/utils/data.py
CHANGED
@@ -41,6 +41,7 @@ from axolotl.prompters import (
|
|
41 |
ShareGPTPrompter,
|
42 |
SummarizeTLDRPrompter,
|
43 |
)
|
|
|
44 |
from axolotl.utils.distributed import is_main_process, zero_first
|
45 |
from axolotl.utils.trainer import (
|
46 |
calculate_total_num_steps,
|
@@ -160,8 +161,15 @@ def load_tokenized_prepared_datasets(
|
|
160 |
split=None,
|
161 |
)
|
162 |
elif local_path.is_file():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
ds = load_dataset(
|
164 |
-
|
165 |
name=d.name,
|
166 |
data_files=d.path,
|
167 |
streaming=False,
|
@@ -198,13 +206,27 @@ def load_tokenized_prepared_datasets(
|
|
198 |
)
|
199 |
else:
|
200 |
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
|
|
|
|
201 |
d_type = d.type
|
202 |
-
|
203 |
-
|
204 |
-
|
|
|
205 |
if "train" in ds:
|
206 |
ds = ds["train"]
|
207 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
209 |
datasets.append(ds_wrapper)
|
210 |
elif d_base_type == "alpaca":
|
|
|
41 |
ShareGPTPrompter,
|
42 |
SummarizeTLDRPrompter,
|
43 |
)
|
44 |
+
from axolotl.utils.dict import DictDefault
|
45 |
from axolotl.utils.distributed import is_main_process, zero_first
|
46 |
from axolotl.utils.trainer import (
|
47 |
calculate_total_num_steps,
|
|
|
161 |
split=None,
|
162 |
)
|
163 |
elif local_path.is_file():
|
164 |
+
ds_type = "json"
|
165 |
+
if d.ds_type:
|
166 |
+
ds_type = d.ds_type
|
167 |
+
elif ".parquet" in d.path:
|
168 |
+
ds_type = "parquet"
|
169 |
+
elif ".arrow" in d.path:
|
170 |
+
ds_type = "arrow"
|
171 |
ds = load_dataset(
|
172 |
+
ds_type,
|
173 |
name=d.name,
|
174 |
data_files=d.path,
|
175 |
streaming=False,
|
|
|
206 |
)
|
207 |
else:
|
208 |
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
209 |
+
|
210 |
+
d_base_type = d_prompt_style = None
|
211 |
d_type = d.type
|
212 |
+
if isinstance(d_type, str):
|
213 |
+
d_type_split = d_type.split(":")
|
214 |
+
d_base_type = d_type_split[0]
|
215 |
+
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
216 |
if "train" in ds:
|
217 |
ds = ds["train"]
|
218 |
+
if (
|
219 |
+
"input_ids" in ds.features
|
220 |
+
and "attention_mask" in ds.features
|
221 |
+
and "labels" in ds.features
|
222 |
+
):
|
223 |
+
# dataset is already tokenized, just drop it straight in
|
224 |
+
datasets.append(ds)
|
225 |
+
elif isinstance(d.type, DictDefault):
|
226 |
+
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
|
227 |
+
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
228 |
+
datasets.append(ds_wrapper)
|
229 |
+
elif ds_strategy := load(d.type, tokenizer, cfg, d):
|
230 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
231 |
datasets.append(ds_wrapper)
|
232 |
elif d_base_type == "alpaca":
|