winglian committed on
Commit
1365073
1 Parent(s): 8c2f3cb

concise multiple choice and tldr summarize

Browse files
src/axolotl/prompt_tokenizers.py CHANGED
@@ -97,7 +97,7 @@ class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingSt
97
  return (
98
  prompt["question"],
99
  "\n".join(f'- "{choice}"' for choice in prompt["choices"]),
100
- prompt["explanation"],
101
  )
102
 
103
 
@@ -119,6 +119,15 @@ class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
119
  )
120
 
121
 
 
 
 
 
 
 
 
 
 
122
  class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
123
  def parse_instruction_fields(self, prompt) -> (str, str, str):
124
  return (
 
97
  return (
98
  prompt["question"],
99
  "\n".join(f'- "{choice}"' for choice in prompt["choices"]),
100
+ prompt["solution"] if "solution" in prompt else prompt["explanation"],
101
  )
102
 
103
 
 
119
  )
120
 
121
 
122
class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """Tokenizing strategy for TL;DR summarization datasets.

    Maps a dataset row with ``"article"`` and ``"summary"`` keys onto the
    (instruction, input, response) triple the base strategy consumes.
    """

    def parse_instruction_fields(self, prompt) -> (str, str, str):
        # The article is the instruction; there is no separate input field,
        # and the reference summary is the expected response.
        article = prompt["article"]
        summary = prompt["summary"]
        return (article, "", summary)
129
+
130
+
131
  class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
132
  def parse_instruction_fields(self, prompt) -> (str, str, str):
133
  return (
src/axolotl/prompters.py CHANGED
@@ -39,6 +39,14 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
39
  prompt_input = "Choose the answer that best answers the question. Explain your reasoning.\n\n### Question:\n{instruction}\n\n### Choices:\n{input}\n\n### Response:\n"
40
 
41
 
 
 
 
 
 
 
 
 
42
  class CompletionPrompter(AlpacaPrompter):
43
  def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]:
44
  yield instruction
 
39
  prompt_input = "Choose the answer that best answers the question. Explain your reasoning.\n\n### Question:\n{instruction}\n\n### Choices:\n{input}\n\n### Response:\n"
40
 
41
 
42
class MultipleChoiceConcisePrompter(AlpacaPrompter):
    """Prompter for multiple-choice questions that asks for a concise answer.

    Overrides only ``prompt_input``; ``prompt_no_input`` is inherited from
    ``AlpacaPrompter`` unchanged.
    NOTE(review): assumes every multiple-choice row supplies choices via
    ``{input}`` -- confirm against the dataset loader.
    """

    prompt_input = (
        "Choose the answer that best answers the question. Be concise in your response.\n\n"
        "USER: {instruction}\n{input}\nASSISTANT:\n"
    )
44
+
45
+
46
class SummarizeTLDRPrompter(AlpacaPrompter):
    """Prompter that asks the model for a TL;DR summary of an article.

    Overrides only ``prompt_no_input`` because the matching tokenizing
    strategy passes an empty ``input`` field.
    NOTE(review): this template ends with "ASSISTANT:" without a trailing
    newline, unlike the concise multiple-choice template -- confirm that
    difference is intentional.
    """

    prompt_no_input = (
        "USER: Summarize the following article as a TL;DR.\n"
        "{instruction}\nASSISTANT:"
    )
48
+
49
+
50
  class CompletionPrompter(AlpacaPrompter):
51
  def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]:
52
  yield instruction
src/axolotl/utils/data.py CHANGED
@@ -19,7 +19,9 @@ from axolotl.prompt_tokenizers import (
19
  AlpacaReflectionPTStrategy,
20
  ShareGPTPromptTokenizingStrategy,
21
  JeopardyPromptTokenizingStrategy,
22
- CompletionPromptTokenizingStrategy, AlpacaMultipleChoicePromptTokenizingStrategy,
 
 
23
  )
24
  from axolotl.prompters import (
25
  AlpacaPrompter,
@@ -27,7 +29,9 @@ from axolotl.prompters import (
27
  ReflectAlpacaPrompter,
28
  ShareGPTPrompter,
29
  JeopardyPrompter,
30
- CompletionPrompter, MultipleChoiceExplainPrompter,
 
 
31
  )
32
 
33
 
@@ -94,6 +98,18 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
94
  )
95
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
96
  datasets.append(ds_wrapper)
 
 
 
 
 
 
 
 
 
 
 
 
97
  elif d.type == "jeopardy":
98
  ds_strategy = JeopardyPromptTokenizingStrategy(
99
  JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
 
19
  AlpacaReflectionPTStrategy,
20
  ShareGPTPromptTokenizingStrategy,
21
  JeopardyPromptTokenizingStrategy,
22
+ CompletionPromptTokenizingStrategy,
23
+ AlpacaMultipleChoicePromptTokenizingStrategy,
24
+ SummarizeTLDRPromptTokenizingStrategy,
25
  )
26
  from axolotl.prompters import (
27
  AlpacaPrompter,
 
29
  ReflectAlpacaPrompter,
30
  ShareGPTPrompter,
31
  JeopardyPrompter,
32
+ CompletionPrompter,
33
+ MultipleChoiceExplainPrompter,
34
+ SummarizeTLDRPrompter, MultipleChoiceConcisePrompter,
35
  )
36
 
37
 
 
98
  )
99
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
100
  datasets.append(ds_wrapper)
101
+ elif d.type == "concisechoice":
102
+ ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
103
+ MultipleChoiceConcisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
104
+ )
105
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
106
+ datasets.append(ds_wrapper)
107
+ elif d.type == "summarizetldr":
108
+ ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
109
+ SummarizeTLDRPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
110
+ )
111
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
112
+ datasets.append(ds_wrapper)
113
  elif d.type == "jeopardy":
114
  ds_strategy = JeopardyPromptTokenizingStrategy(
115
  JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len