winglian commited on
Commit
409ca0f
·
unverified ·
1 Parent(s): 8662e8f

add support for defined train split (#654)

Browse files
README.md CHANGED
@@ -250,6 +250,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
250
  ```json
251
  {"article": "...", "question": "...", "answer": "..."}
252
  ```
 
 
 
 
253
  - `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
254
  ```json
255
  {"article": "...", "unanswerable_question": "..."}
@@ -356,6 +360,12 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
356
  - path: data.jsonl # or json
357
  ds_type: json # see other options below
358
  type: alpaca
 
 
 
 
 
 
359
  ```
360
 
361
  - loading
 
250
  ```json
251
  {"article": "...", "question": "...", "answer": "..."}
252
  ```
253
+ - `context_qa.load_v2`: in context question answering (alternate)
254
+ ```json
255
+ {"context": "...", "question": "...", "answer": "..."}
256
+ ```
257
  - `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
258
  ```json
259
  {"article": "...", "unanswerable_question": "..."}
 
360
  - path: data.jsonl # or json
361
  ds_type: json # see other options below
362
  type: alpaca
363
+
364
+ # dataset with splits, but no train split
365
+ dataset:
366
+ - path: knowrohit07/know_sql
367
+ type: context_qa.load_v2
368
+ train_on_split: validation
369
  ```
370
 
371
  - loading
src/axolotl/prompt_strategies/context_qa.py CHANGED
@@ -24,6 +24,15 @@ def load(tokenizer, cfg):
24
  )
25
 
26
 
 
 
 
 
 
 
 
 
 
27
  class AlpacaContextPrompter(AlpacaPrompter):
28
  """
29
  Customized system prompted for concise QA
@@ -50,6 +59,38 @@ class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
50
  )
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  class AlpacaMissingInfoContextPromptTokenizingStrategy(
54
  InstructionPromptTokenizingStrategy
55
  ):
 
24
  )
25
 
26
 
27
+ def load_v2(tokenizer, cfg):
28
+ return ContextQaV2PromptTokenizingStrategy(
29
+ ContextV2Prompter(),
30
+ tokenizer,
31
+ cfg.train_on_inputs,
32
+ cfg.sequence_len,
33
+ )
34
+
35
+
36
  class AlpacaContextPrompter(AlpacaPrompter):
37
  """
38
  Customized system prompted for concise QA
 
59
  )
60
 
61
 
62
+ class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
63
+ """
64
+ Tokenization Strategy to combine in-context article with a question and answer
65
+ """
66
+
67
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
68
+ return (
69
+ "Context: "
70
+ + prompt["context"]
71
+ + "\nQuestion: "
72
+ + prompt["question"]
73
+ + "\n",
74
+ "",
75
+ "Answer: " + prompt["answer"],
76
+ )
77
+
78
+
79
+ class ContextV2Prompter(AlpacaPrompter):
80
+ """
81
+ Customized system prompted for concise QA
82
+ """
83
+
84
+ system_prompt = ""
85
+ system_no_input_prompt = ""
86
+
87
+ def match_prompt_style(self):
88
+ # pylint: disable=duplicate-code
89
+ self.turn_format = "{instruction}\n{input}"
90
+ self.turn_no_input_format = "{instruction}"
91
+ self.system_format = "{system}"
92
+
93
+
94
  class AlpacaMissingInfoContextPromptTokenizingStrategy(
95
  InstructionPromptTokenizingStrategy
96
  ):
src/axolotl/utils/data.py CHANGED
@@ -247,6 +247,16 @@ def load_tokenized_prepared_datasets(
247
  d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
248
  if "train" in ds:
249
  ds = ds["train"]
 
 
 
 
 
 
 
 
 
 
250
  if (
251
  "input_ids" in ds.features
252
  and "attention_mask" in ds.features
 
247
  d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
248
  if "train" in ds:
249
  ds = ds["train"]
250
+ elif (
251
+ isinstance(ds, DatasetDict)
252
+ and d.train_on_split
253
+ and d.train_on_split in ds
254
+ ):
255
+ ds = ds[d.train_on_split]
256
+ elif isinstance(ds, DatasetDict):
257
+ raise ValueError(
258
+ f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
259
+ )
260
  if (
261
  "input_ids" in ds.features
262
  and "attention_mask" in ds.features