winglian committed on
Commit
d2e7f27
1 Parent(s): d21318d

support user defined prompters, pretokenized datasets in config, local parquet, local arrow files (#348)

Browse files

* support user defined prompters, pretokenized datasets in config, local parquet, local arrow files

* fix user defined dataset types

* fix for system prompts

* fix tests

* fix checks for parquet and arrow

* aha moment that d.data_files isn't used

* add documentation for ds_type to add support for parquet and arrow

README.md CHANGED
@@ -392,6 +392,7 @@ datasets:
392
  - path: vicgalle/alpaca-gpt4
393
  # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
394
  type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
 
395
  data_files: # path to source data files
396
  shards: # number of shards to split data into
397
  name: # name of dataset configuration to load
 
392
  - path: vicgalle/alpaca-gpt4
393
  # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
394
  type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
395
+ ds_type: # Optional[str] (json|arrow|parquet) format used to load `path` when it is a local file; defaults to json unless the path contains .parquet or .arrow
396
  data_files: # path to source data files
397
  shards: # number of shards to split data into
398
  name: # name of dataset configuration to load
src/axolotl/prompt_strategies/__init__.py CHANGED
@@ -2,8 +2,10 @@
2
 
3
  import importlib
4
 
 
5
 
6
- def load(strategy, tokenizer, cfg):
 
7
  try:
8
  load_fn = "load"
9
  if strategy.split(".")[-1].startswith("load_"):
@@ -11,6 +13,9 @@ def load(strategy, tokenizer, cfg):
11
  strategy = ".".join(strategy.split(".")[:-1])
12
  mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
13
  func = getattr(mod, load_fn)
14
- return func(tokenizer, cfg)
 
 
 
15
  except Exception: # pylint: disable=broad-exception-caught
16
  return None
 
2
 
3
  import importlib
4
 
5
+ from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig
6
 
7
+
8
+ def load(strategy, tokenizer, cfg, ds_cfg):
9
  try:
10
  load_fn = "load"
11
  if strategy.split(".")[-1].startswith("load_"):
 
13
  strategy = ".".join(strategy.split(".")[:-1])
14
  mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
15
  func = getattr(mod, load_fn)
16
+ load_kwargs = {}
17
+ if strategy == "user_defined":
18
+ load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
19
+ return func(tokenizer, cfg, **load_kwargs)
20
  except Exception: # pylint: disable=broad-exception-caught
21
  return None
src/axolotl/prompt_strategies/alpaca_w_system.py CHANGED
@@ -57,6 +57,8 @@ class SystemDataPrompter(AlpacaPrompter):
57
  Alpaca Style Prompter that uses system prompts from the dataset
58
  """
59
 
 
 
60
  def build_prompt_w_system(
61
  self,
62
  system: str,
 
57
  Alpaca Style Prompter that uses system prompts from the dataset
58
  """
59
 
60
+ system_format: str = "### System:\n{system}\n\n"
61
+
62
  def build_prompt_w_system(
63
  self,
64
  system: str,
src/axolotl/prompt_strategies/user_defined.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ User Defined prompts with configuration from the YML config
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+ from functools import partial
7
+ from typing import Optional, Tuple
8
+
9
+ from axolotl.prompt_strategies.alpaca_w_system import (
10
+ InstructionWSystemPromptTokenizingStrategy,
11
+ SystemDataPrompter,
12
+ )
13
+
14
+
15
+ @dataclass
16
+ class UserDefinedDatasetConfig:
17
+ """
18
+ dataclass configuration representing a user-defined dataset type
19
+ """
20
+
21
+ system_prompt: str = ""
22
+ field_system: str = "system"
23
+ field_instruction: str = "instruction"
24
+ field_input: str = "input"
25
+ field_output: str = "output"
26
+ format: str = "{instruction} {input} "
27
+ no_input_format: str = "{instruction} "
28
+ system_format: str = "{system}"
29
+
30
+ def __getitem__(self, item):
31
+ return getattr(self, item)
32
+
33
+
34
+ class UserDefinedPromptTokenizationStrategy(InstructionWSystemPromptTokenizingStrategy):
35
+ """
36
+ Prompt Tokenization Strategy for user defined prompts
37
+ """
38
+
39
+
40
+ def load(tokenizer, cfg, ds_cfg: Optional[UserDefinedDatasetConfig] = None):
41
+ if not ds_cfg:
42
+ raise ValueError("Missing dataset prompt configuration")
43
+
44
+ system_prompt = ""
45
+ if ds_cfg.system_prompt:
46
+ system_prompt = ds_cfg.system_prompt
47
+
48
+ def parse_instruction_fields(
49
+ field_instruction,
50
+ field_input,
51
+ field_output,
52
+ field_system,
53
+ system_prompt,
54
+ prompt,
55
+ ) -> Tuple[str, str, str, str]:
56
+ return (
57
+ prompt[field_instruction],
58
+ prompt[field_input] if field_input in prompt else "",
59
+ prompt[field_output] if field_output in prompt else "",
60
+ prompt[field_system] if field_system in prompt else system_prompt,
61
+ )
62
+
63
+ turn_format = ds_cfg.format
64
+ turn_no_input_format = ds_cfg.no_input_format
65
+ system_format = ds_cfg.system_format
66
+
67
+ class UserDefinedPrompter(SystemDataPrompter):
68
+ """
69
+ Prompter for user defined prompts
70
+ """
71
+
72
+ def match_prompt_style(self):
73
+ self.turn_format = turn_format
74
+ self.turn_no_input_format = turn_no_input_format
75
+ self.system_format = system_format
76
+
77
+ prompter = UserDefinedPrompter()
78
+
79
+ strat = UserDefinedPromptTokenizationStrategy(
80
+ prompter,
81
+ tokenizer,
82
+ cfg.train_on_inputs,
83
+ cfg.sequence_len,
84
+ )
85
+
86
+ setattr(
87
+ strat,
88
+ "parse_instruction_fields",
89
+ partial(
90
+ parse_instruction_fields,
91
+ ds_cfg.field_instruction,
92
+ ds_cfg.field_input,
93
+ ds_cfg.field_output,
94
+ ds_cfg.field_system,
95
+ system_prompt,
96
+ ),
97
+ )
98
+ return strat
src/axolotl/prompters.py CHANGED
@@ -26,7 +26,7 @@ class AlpacaPrompter:
26
 
27
  system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
28
  system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
29
- system_format: str
30
  turn_format: str
31
  turn_no_input_format: str
32
  prompt_style: Optional[PromptStyle] = None
@@ -63,13 +63,17 @@ class AlpacaPrompter:
63
  # returns the full prompt from instruction and optional input
64
  # if a label (=response, =output) is provided, it's also appended.
65
  if input:
66
- res = self.system_prompt + self.turn_format.format(
67
- instruction=instruction, input=input
68
- )
 
 
69
  else:
70
- res = self.system_no_input_prompt + self.turn_no_input_format.format(
71
- instruction=instruction
72
- )
 
 
73
  if output:
74
  res = f"{res}{output}"
75
  yield res
 
26
 
27
  system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
28
  system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
29
+ system_format: str = "{system}"
30
  turn_format: str
31
  turn_no_input_format: str
32
  prompt_style: Optional[PromptStyle] = None
 
63
  # returns the full prompt from instruction and optional input
64
  # if a label (=response, =output) is provided, it's also appended.
65
  if input:
66
+ res = (
67
+ self.system_format.format(system=self.system_prompt)
68
+ if self.system_prompt
69
+ else ""
70
+ ) + self.turn_format.format(instruction=instruction, input=input)
71
  else:
72
+ res = (
73
+ self.system_format.format(system=self.system_no_input_prompt)
74
+ if self.system_prompt
75
+ else ""
76
+ ) + self.turn_no_input_format.format(instruction=instruction)
77
  if output:
78
  res = f"{res}{output}"
79
  yield res
src/axolotl/utils/data.py CHANGED
@@ -41,6 +41,7 @@ from axolotl.prompters import (
41
  ShareGPTPrompter,
42
  SummarizeTLDRPrompter,
43
  )
 
44
  from axolotl.utils.distributed import is_main_process, zero_first
45
  from axolotl.utils.trainer import (
46
  calculate_total_num_steps,
@@ -160,8 +161,15 @@ def load_tokenized_prepared_datasets(
160
  split=None,
161
  )
162
  elif local_path.is_file():
 
 
 
 
 
 
 
163
  ds = load_dataset(
164
- "json",
165
  name=d.name,
166
  data_files=d.path,
167
  streaming=False,
@@ -198,13 +206,27 @@ def load_tokenized_prepared_datasets(
198
  )
199
  else:
200
  ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
 
 
201
  d_type = d.type
202
- d_type_split = d_type.split(":")
203
- d_base_type = d_type_split[0]
204
- d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
 
205
  if "train" in ds:
206
  ds = ds["train"]
207
- if ds_strategy := load(d.type, tokenizer, cfg):
 
 
 
 
 
 
 
 
 
 
 
208
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
209
  datasets.append(ds_wrapper)
210
  elif d_base_type == "alpaca":
 
41
  ShareGPTPrompter,
42
  SummarizeTLDRPrompter,
43
  )
44
+ from axolotl.utils.dict import DictDefault
45
  from axolotl.utils.distributed import is_main_process, zero_first
46
  from axolotl.utils.trainer import (
47
  calculate_total_num_steps,
 
161
  split=None,
162
  )
163
  elif local_path.is_file():
164
+ ds_type = "json"
165
+ if d.ds_type:
166
+ ds_type = d.ds_type
167
+ elif ".parquet" in d.path:
168
+ ds_type = "parquet"
169
+ elif ".arrow" in d.path:
170
+ ds_type = "arrow"
171
  ds = load_dataset(
172
+ ds_type,
173
  name=d.name,
174
  data_files=d.path,
175
  streaming=False,
 
206
  )
207
  else:
208
  ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
209
+
210
+ d_base_type = d_prompt_style = None
211
  d_type = d.type
212
+ if isinstance(d_type, str):
213
+ d_type_split = d_type.split(":")
214
+ d_base_type = d_type_split[0]
215
+ d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
216
  if "train" in ds:
217
  ds = ds["train"]
218
+ if (
219
+ "input_ids" in ds.features
220
+ and "attention_mask" in ds.features
221
+ and "labels" in ds.features
222
+ ):
223
+ # dataset is already tokenized, just drop it straight in
224
+ datasets.append(ds)
225
+ elif isinstance(d.type, DictDefault):
226
+ ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
227
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
228
+ datasets.append(ds_wrapper)
229
+ elif ds_strategy := load(d.type, tokenizer, cfg, d):
230
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
231
  datasets.append(ds_wrapper)
232
  elif d_base_type == "alpaca":