winglian commited on
Commit
8d20e0a
1 Parent(s): de8ed22

initial wip to get sys prompt from dataset

Browse files
src/axolotl/prompt_strategies/alpaca_chat.py CHANGED
@@ -45,8 +45,10 @@ class NoSystemPrompter(AlpacaPrompter):
45
  Null Prompter with no system prompts
46
  """
47
 
48
- prompt_input = "{instruction} {input} "
49
- prompt_no_input = "{instruction} "
 
 
50
 
51
  def __init__(self): # pylint: disable=super-init-not-called
52
  pass
 
45
  Null Prompter with no system prompts
46
  """
47
 
48
+ system_prompt = ""
49
+ system_no_input_prompt = ""
50
+ turn_format = "{instruction} {input} "
51
+ turn_no_input_format = "{instruction} "
52
 
53
  def __init__(self): # pylint: disable=super-init-not-called
54
  pass
src/axolotl/prompt_tokenizers.py CHANGED
@@ -87,7 +87,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
87
  Tokenizing strategy for instruction-based prompts.
88
  """
89
 
90
- def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
 
 
91
  raise NotImplementedError
92
 
93
  def tokenize_prompt(self, prompt):
 
87
  Tokenizing strategy for instruction-based prompts.
88
  """
89
 
90
+ def parse_instruction_fields(
91
+ self, prompt
92
+ ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
93
  raise NotImplementedError
94
 
95
  def tokenize_prompt(self, prompt):
src/axolotl/prompters.py CHANGED
@@ -24,6 +24,8 @@ class AlpacaPrompter:
24
 
25
  system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
26
  system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
 
 
27
  prompt_style: Optional[PromptStyle] = None
28
 
29
  def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
@@ -32,23 +34,13 @@ class AlpacaPrompter:
32
 
33
  def match_prompt_style(self):
34
  if self.prompt_style == PromptStyle.INSTRUCT.value:
35
- self.prompt_input = (
36
- self.system_prompt
37
- + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
38
- )
39
- self.prompt_no_input = (
40
- self.system_no_input_prompt
41
- + "### Instruction:\n{instruction}\n\n### Response:\n"
42
  )
43
- self.response_split = "### Response:"
44
  if self.prompt_style == PromptStyle.CHAT.value:
45
- self.prompt_input = (
46
- self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
47
- )
48
- self.prompt_no_input = (
49
- self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
50
- )
51
- self.response_split = "ASSISTANT:"
52
 
53
  def build_prompt(
54
  self,
@@ -59,15 +51,39 @@ class AlpacaPrompter:
59
  # returns the full prompt from instruction and optional input
60
  # if a label (=response, =output) is provided, it's also appended.
61
  if input:
62
- res = self.prompt_input.format(instruction=instruction, input=input)
 
 
63
  else:
64
- res = self.prompt_no_input.format(instruction=instruction)
 
 
65
  if output:
66
  res = f"{res}{output}"
67
  yield res
68
 
69
- def get_response(self, output: str) -> str:
70
- return output.split(self.response_split)[1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
 
73
  class UnpromptedPrompter(AlpacaPrompter):
@@ -93,7 +109,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
93
  """
94
 
95
  system_prompt = (
96
- "Choose the answer that best answers the question. Explain your reasoning."
 
 
 
97
  )
98
 
99
 
@@ -102,7 +121,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
102
  Prompter for multiple choice concise
103
  """
104
 
105
- prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
 
 
 
 
 
106
 
107
 
108
  class SummarizeTLDRPrompter(AlpacaPrompter):
@@ -110,9 +134,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
110
  Prompter for summarize TLDR
111
  """
112
 
113
- prompt_no_input = (
114
- "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
115
- )
 
 
 
116
 
117
 
118
  class CompletionPrompter:
@@ -128,9 +155,6 @@ class CompletionPrompter:
128
  ) -> Generator[str, None, None]:
129
  yield instruction
130
 
131
- def get_response(self, output: str) -> str:
132
- return output.strip()
133
-
134
 
135
  class GPTeacherPrompter(AlpacaPrompter):
136
  """
@@ -210,9 +234,6 @@ class ReflectAlpacaPrompter:
210
  res = f"{res}{label}"
211
  yield res
212
 
213
- def get_response(self, output: str) -> str:
214
- return output.split(self.response_split)[1].strip()
215
-
216
 
217
  class SeparatorStyle(Enum):
218
  """Different separator style."""
@@ -289,12 +310,6 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
289
  sep2=" ",
290
  )
291
 
292
- # def match_prompt_style(self):
293
- # if self.prompt_style == PromptStyle.chat.value:
294
- # self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
295
- # self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
296
- # self.response_split = "ASSISTANT:"
297
-
298
  def build_prompt(self, source) -> Generator[str, None, None]:
299
  # ignore the system prompt if provided
300
  if source[0]["from"] == "system":
 
24
 
25
  system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
26
  system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
27
+ turn_format: str
28
+ turn_no_input_format: str
29
  prompt_style: Optional[PromptStyle] = None
30
 
31
  def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
 
34
 
35
  def match_prompt_style(self):
36
  if self.prompt_style == PromptStyle.INSTRUCT.value:
37
+ self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
38
+ self.turn_no_input_format = (
39
+ "### Instruction:\n{instruction}\n\n### Response:\n"
 
 
 
 
40
  )
 
41
  if self.prompt_style == PromptStyle.CHAT.value:
42
+ self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
43
+ self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
 
 
 
 
 
44
 
45
  def build_prompt(
46
  self,
 
51
  # returns the full prompt from instruction and optional input
52
  # if a label (=response, =output) is provided, it's also appended.
53
  if input:
54
+ res = self.system_prompt + self.turn_format.format(
55
+ instruction=instruction, input=input
56
+ )
57
  else:
58
+ res = self.system_no_input_prompt + self.turn_no_input_format.format(
59
+ instruction=instruction
60
+ )
61
  if output:
62
  res = f"{res}{output}"
63
  yield res
64
 
65
+
66
+ class SystemDataPrompter(AlpacaPrompter):
67
+ """
68
+ Alpaca Style Prompter that uses system prompts from the dataset
69
+ """
70
+
71
+ def build_prompt_w_system(
72
+ self,
73
+ system: str,
74
+ instruction: str,
75
+ input: Union[None, str] = None, # pylint: disable=redefined-builtin
76
+ output: Union[None, str] = None,
77
+ ) -> Generator[str, None, None]:
78
+ # returns the full prompt from instruction and optional input
79
+ # if a label (=response, =output) is provided, it's also appended.
80
+ if input:
81
+ res = system + self.turn_format.format(instruction=instruction, input=input)
82
+ else:
83
+ res = system + self.turn_no_input_format.format(instruction=instruction)
84
+ if output:
85
+ res = f"{res}{output}"
86
+ yield res
87
 
88
 
89
  class UnpromptedPrompter(AlpacaPrompter):
 
109
  """
110
 
111
  system_prompt = (
112
+ "Choose the answer that best answers the question. Explain your reasoning.\n"
113
+ )
114
+ system_no_input_prompt = (
115
+ "Choose the answer that best answers the question. Explain your reasoning.\n"
116
  )
117
 
118
 
 
121
  Prompter for multiple choice concise
122
  """
123
 
124
+ system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
125
+ system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
126
+
127
+ def match_prompt_style(self):
128
+ self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
129
+ self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
130
 
131
 
132
  class SummarizeTLDRPrompter(AlpacaPrompter):
 
134
  Prompter for summarize TLDR
135
  """
136
 
137
+ system_prompt = ""
138
+ system_no_input_prompt = ""
139
+
140
+ def match_prompt_style(self):
141
+ self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
142
+ self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
143
 
144
 
145
  class CompletionPrompter:
 
155
  ) -> Generator[str, None, None]:
156
  yield instruction
157
 
 
 
 
158
 
159
  class GPTeacherPrompter(AlpacaPrompter):
160
  """
 
234
  res = f"{res}{label}"
235
  yield res
236
 
 
 
 
237
 
238
  class SeparatorStyle(Enum):
239
  """Different separator style."""
 
310
  sep2=" ",
311
  )
312
 
 
 
 
 
 
 
313
  def build_prompt(self, source) -> Generator[str, None, None]:
314
  # ignore the system prompt if provided
315
  if source[0]["from"] == "system":
tests/test_prompters.py CHANGED
@@ -2,7 +2,13 @@
2
 
3
  import unittest
4
 
5
- from axolotl.prompters import AlpacaPrompter, PromptStyle
 
 
 
 
 
 
6
 
7
 
8
  class AlpacaPrompterTest(unittest.TestCase):
@@ -55,3 +61,64 @@ class AlpacaPrompterTest(unittest.TestCase):
55
  assert "### Response:" not in res
56
  assert "USER:" in res
57
  assert "ASSISTANT:" in res
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import unittest
4
 
5
+ from axolotl.prompters import (
6
+ AlpacaPrompter,
7
+ MultipleChoiceExplainPrompter,
8
+ PromptStyle,
9
+ SystemDataPrompter,
10
+ UnpromptedPrompter,
11
+ )
12
 
13
 
14
  class AlpacaPrompterTest(unittest.TestCase):
 
61
  assert "### Response:" not in res
62
  assert "USER:" in res
63
  assert "ASSISTANT:" in res
64
+
65
+ def test_system_prompt(self):
66
+ prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
67
+ res = next(
68
+ prompter.build_prompt_w_system(
69
+ "use cot", "tell me a joke about the following", "alpacas"
70
+ )
71
+ )
72
+ assert "use cot" in res
73
+ assert res.startswith("use cot")
74
+ assert "### Instruction:" not in res
75
+ assert "### Input:" not in res
76
+ assert "alpacas" in res
77
+ assert "### Response:" not in res
78
+ assert "USER:" in res
79
+ assert "ASSISTANT:" in res
80
+
81
+
82
+ class UnpromptedPrompterTest(unittest.TestCase):
83
+ """
84
+ Test class for UnpromptedPrompter with no system prompts
85
+ """
86
+
87
+ def test_prompt_style_w_none(self):
88
+ prompter = UnpromptedPrompter(prompt_style=None)
89
+ res = next(prompter.build_prompt("tell me a joke"))
90
+ assert "### Instruction:" in res
91
+ assert "tell me a joke" in res
92
+ assert res.startswith("###")
93
+
94
+ def test_prompt_style_w_instruct(self):
95
+ prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
96
+ res = next(
97
+ prompter.build_prompt("tell me a joke about the following", "alpacas")
98
+ )
99
+ assert "### Instruction:" in res
100
+ assert "tell me a joke" in res
101
+ assert res.startswith("###")
102
+
103
+ def test_prompt_style_w_chat(self):
104
+ prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
105
+ res = next(
106
+ prompter.build_prompt("tell me a joke about the following", "alpacas")
107
+ )
108
+ assert "USER:" in res
109
+ assert "tell me a joke" in res
110
+ assert res.startswith("USER:")
111
+
112
+
113
+ class MultipleChoiceExplainPrompterTest(unittest.TestCase):
114
+ """
115
+ Test class for MultipleChoiceExplainPrompter
116
+ """
117
+
118
+ def test_prompt_style_w_chat(self):
119
+ prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
120
+ res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
121
+ assert "USER:" in res
122
+ assert "choose one" in res
123
+ assert "Choose the answer that best answers the question." in res
124
+ assert "- A\n- B\n- C" in res