winglian commited on
Commit
3d4984b
·
unverified ·
1 Parent(s): ff7f18d

update prompts for open orca to match the paper (#317)

Browse files
src/axolotl/prompt_strategies/alpaca_w_system.py CHANGED
@@ -66,15 +66,34 @@ class SystemDataPrompter(AlpacaPrompter):
66
  ) -> Generator[str, None, None]:
67
  # returns the full prompt from instruction and optional input
68
  # if a label (=response, =output) is provided, it's also appended.
 
69
  if input:
70
- res = system + self.turn_format.format(instruction=instruction, input=input)
 
 
71
  else:
72
- res = system + self.turn_no_input_format.format(instruction=instruction)
 
 
73
  if output:
74
  res = f"{res}{output}"
75
  yield res
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
79
  """
80
  Tokenizing strategy for OpenOrca datasets
@@ -113,7 +132,7 @@ def load_chat(tokenizer, cfg):
113
 
114
  def load_open_orca(tokenizer, cfg):
115
  return OpenOrcaPromptTokenizingStrategy(
116
- SystemDataPrompter(PromptStyle.INSTRUCT.value),
117
  tokenizer,
118
  cfg.train_on_inputs,
119
  cfg.sequence_len,
 
66
  ) -> Generator[str, None, None]:
67
  # returns the full prompt from instruction and optional input
68
  # if a label (=response, =output) is provided, it's also appended.
69
+ formatted_sys_prompt = f"### System:\n{system}\n\n" if system else ""
70
  if input:
71
+ res = formatted_sys_prompt + self.turn_format.format(
72
+ instruction=instruction, input=input
73
+ )
74
  else:
75
+ res = formatted_sys_prompt + self.turn_no_input_format.format(
76
+ instruction=instruction
77
+ )
78
  if output:
79
  res = f"{res}{output}"
80
  yield res
81
 
82
 
83
+ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
84
+ """
85
+ Alpaca Style Prompter that uses system prompts from the dataset, with OpenOrca prompts
86
+ """
87
+
88
+ def match_prompt_style(self):
89
+ if self.prompt_style == PromptStyle.INSTRUCT.value:
90
+ self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
91
+ self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
92
+ if self.prompt_style == PromptStyle.CHAT.value:
93
+ self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
94
+ self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
95
+
96
+
97
  class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
98
  """
99
  Tokenizing strategy for OpenOrca datasets
 
132
 
133
  def load_open_orca(tokenizer, cfg):
134
  return OpenOrcaPromptTokenizingStrategy(
135
+ OpenOrcaSystemDataPrompter(PromptStyle.INSTRUCT.value),
136
  tokenizer,
137
  cfg.train_on_inputs,
138
  cfg.sequence_len,
tests/test_prompt_tokenizers.py CHANGED
@@ -130,8 +130,9 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
130
  "output": "Hi! How can I help?",
131
  }
132
  example = strat.tokenize_prompt(sample)
133
- assert example["input_ids"][0:3] == [1, 671, 20118] # <s>use cot
134
- assert example["input_ids"][3] == 11889 # USER
 
135
 
136
 
137
  if __name__ == "__main__":
 
130
  "output": "Hi! How can I help?",
131
  }
132
  example = strat.tokenize_prompt(sample)
133
+ assert example["input_ids"][0:4] == [1, 835, 2184, 29901] # "<s>### System:"
134
+ assert example["input_ids"][5:7] == [1509, 20118] # "use cot"
135
+ assert example["input_ids"][9] == 11889 # USER
136
 
137
 
138
  if __name__ == "__main__":
tests/test_prompters.py CHANGED
@@ -70,7 +70,7 @@ class AlpacaPrompterTest(unittest.TestCase):
70
  )
71
  )
72
  assert "use cot" in res
73
- assert res.startswith("use cot")
74
  assert "### Instruction:" not in res
75
  assert "### Input:" not in res
76
  assert "alpacas" in res
 
70
  )
71
  )
72
  assert "use cot" in res
73
+ assert res.startswith("### System:")
74
  assert "### Instruction:" not in res
75
  assert "### Input:" not in res
76
  assert "alpacas" in res