winglian commited on
Commit
1925eaf
2 Parent(s): 6f84980 1ab3bf3

Merge pull request #214 from OpenAccess-AI-Collective/fix-tokenizing-labels

Browse files
src/axolotl/prompt_strategies/alpaca_chat.py CHANGED
@@ -20,11 +20,36 @@ def load(tokenizer, cfg):
20
 
21
  class AlpacaConcisePrompter(AlpacaPrompter):
22
  """
23
- Alpaca Prompter extending the system prompt to ask for concise answers
24
  """
25
 
26
- system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that concisely and appropriately completes the request.\n\n"
27
- system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately and concisely completes the request.\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
 
30
  class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
@@ -64,7 +89,7 @@ def load_concise(tokenizer, cfg):
64
 
65
  def load_qa(tokenizer, cfg):
66
  return AlpacaQAPromptTokenizingStrategy(
67
- AlpacaPrompter(PromptStyle.CHAT.value),
68
  tokenizer,
69
  cfg.train_on_inputs,
70
  cfg.sequence_len,
@@ -73,7 +98,7 @@ def load_qa(tokenizer, cfg):
73
 
74
  def load_camel_ai(tokenizer, cfg):
75
  return CamelAIPromptTokenizingStrategy(
76
- AlpacaPrompter(PromptStyle.CHAT.value),
77
  tokenizer,
78
  cfg.train_on_inputs,
79
  cfg.sequence_len,
 
20
 
21
  class AlpacaConcisePrompter(AlpacaPrompter):
22
  """
23
+ Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
24
  """
25
 
26
+ system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
27
+ system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
28
+
29
+
30
+ class AlpacaChatPrompter(AlpacaPrompter):
31
+ """
32
+ Alpaca Chat Prompter extending the system prompt for chat-instruct answers
33
+ """
34
+
35
+ system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
36
+ system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
37
+
38
+ def __init__(self): # pylint: disable=super-init-not-called
39
+ self.prompt_style = PromptStyle.CHAT.value
40
+ self.match_prompt_style()
41
+
42
+
43
+ class NoSystemPrompter(AlpacaPrompter):
44
+ """
45
+ Null Prompter with no system prompts
46
+ """
47
+
48
+ prompt_input = "{instruction} {input} "
49
+ prompt_no_input = "{instruction} "
50
+
51
+ def __init__(self): # pylint: disable=super-init-not-called
52
+ pass
53
 
54
 
55
  class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
 
89
 
90
  def load_qa(tokenizer, cfg):
91
  return AlpacaQAPromptTokenizingStrategy(
92
+ AlpacaChatPrompter(),
93
  tokenizer,
94
  cfg.train_on_inputs,
95
  cfg.sequence_len,
 
98
 
99
  def load_camel_ai(tokenizer, cfg):
100
  return CamelAIPromptTokenizingStrategy(
101
+ AlpacaChatPrompter(),
102
  tokenizer,
103
  cfg.train_on_inputs,
104
  cfg.sequence_len,
src/axolotl/prompt_tokenizers.py CHANGED
@@ -96,25 +96,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
96
  input, # pylint: disable=redefined-builtin
97
  response,
98
  ) = self.parse_instruction_fields(prompt)
99
- full_prompt = self._build_full_prompt(instruction, input, response)
100
- tokenized_full_prompt = self._tokenize(full_prompt)
101
- if not self.train_on_inputs:
102
- user_prompt = next(
103
- iter(
104
- self.prompter.build_prompt(
105
- instruction,
106
- input,
107
- )
108
  )
109
  )
110
- tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
111
- user_prompt_len = len(tokenized_user_prompt["input_ids"])
 
 
112
  # TODO this could be sped up using numpy array slicing
113
- tokenized_full_prompt["labels"] = [
114
- -100
115
- ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
 
 
 
 
116
 
117
- return tokenized_full_prompt
118
 
119
  def _build_full_prompt(
120
  self, instruction, input, response # pylint: disable=redefined-builtin
 
96
  input, # pylint: disable=redefined-builtin
97
  response,
98
  ) = self.parse_instruction_fields(prompt)
99
+ user_prompt = next(
100
+ iter(
101
+ self.prompter.build_prompt(
102
+ instruction,
103
+ input,
 
 
 
 
104
  )
105
  )
106
+ )
107
+ tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
108
+ if not self.train_on_inputs:
109
+ user_prompt_len = len(tokenized_prompt["input_ids"])
110
  # TODO this could be sped up using numpy array slicing
111
+ tokenized_prompt["labels"] = [-100] * user_prompt_len
112
+ tokenized_res_prompt = self._tokenize(
113
+ response, strip_bos_token=True, add_eos_token=True
114
+ )
115
+ tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
116
+ tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
117
+ tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
118
 
119
+ return tokenized_prompt
120
 
121
  def _build_full_prompt(
122
  self, instruction, input, response # pylint: disable=redefined-builtin
tests/test_prompt_tokenizers.py CHANGED
@@ -6,8 +6,12 @@ from pathlib import Path
6
 
7
  from transformers import AutoTokenizer
8
 
9
- from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
10
- from axolotl.prompters import ShareGPTPrompter
 
 
 
 
11
 
12
  logging.basicConfig(level="INFO")
13
 
@@ -29,7 +33,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
29
  )
30
 
31
  def test_sharegpt_integration(self):
32
- print(Path(__file__).parent)
33
  with open(
34
  Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
35
  ) as fin:
@@ -53,6 +56,45 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
53
  self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
54
  self.assertEqual(example[fields], tokenized_conversation[fields])
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  if __name__ == "__main__":
58
  unittest.main()
 
6
 
7
  from transformers import AutoTokenizer
8
 
9
+ from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
10
+ from axolotl.prompt_tokenizers import (
11
+ AlpacaPromptTokenizingStrategy,
12
+ ShareGPTPromptTokenizingStrategy,
13
+ )
14
+ from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
15
 
16
  logging.basicConfig(level="INFO")
17
 
 
33
  )
34
 
35
  def test_sharegpt_integration(self):
 
36
  with open(
37
  Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
38
  ) as fin:
 
56
  self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
57
  self.assertEqual(example[fields], tokenized_conversation[fields])
58
 
59
+ def test_no_sys_prompt(self):
60
+ """
61
+ tests the interface between the user and assistant parts
62
+ """
63
+ prompter = NoSystemPrompter()
64
+ # pylint: disable=duplicate-code
65
+ strat = AlpacaPromptTokenizingStrategy(
66
+ prompter,
67
+ self.tokenizer,
68
+ False,
69
+ 2048,
70
+ )
71
+ sample = {
72
+ "instruction": "hello cruel. lorem ipsum dolor sit amet.",
73
+ "output": "world!",
74
+ }
75
+ example = strat.tokenize_prompt(sample)
76
+ world_idx = example["input_ids"].index(3186)
77
+ assert example["labels"][world_idx] == 3186
78
+ assert example["labels"][world_idx - 1] == -100
79
+
80
+ def test_alpaca(self):
81
+ """
82
+ tests the interface between the user and assistant parts
83
+ """
84
+ # pylint: disable=duplicate-code
85
+ prompter = AlpacaPrompter()
86
+ strat = AlpacaPromptTokenizingStrategy(
87
+ prompter,
88
+ self.tokenizer,
89
+ False,
90
+ 2048,
91
+ )
92
+ sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
93
+ example = strat.tokenize_prompt(sample)
94
+ world_idx = example["input_ids"].index(6324)
95
+ assert example["labels"][world_idx] == 6324
96
+ assert example["labels"][world_idx - 1] == -100
97
+
98
 
99
  if __name__ == "__main__":
100
  unittest.main()