winglian committed on
Commit
3a38271
·
1 Parent(s): 8d20e0a

add tests and support for loader for sys prompt data

Browse files
src/axolotl/prompt_strategies/alpaca_w_system.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prompt strategies loader for alpaca instruction datasets with system prompts
3
+ """
4
+ from typing import Generator, Tuple, Union
5
+
6
+ from axolotl.prompt_tokenizers import PromptTokenizingStrategy
7
+ from axolotl.prompters import AlpacaPrompter, PromptStyle
8
+
9
+
10
+ class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
11
+ """
12
+ Tokenizing strategy for instruction-based prompts.
13
+ """
14
+
15
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
16
+ return (
17
+ prompt["instruction"],
18
+ prompt["input"] if "input" in prompt else "",
19
+ prompt["output"],
20
+ prompt["system"],
21
+ )
22
+
23
+ def tokenize_prompt(self, prompt):
24
+ (
25
+ instruction,
26
+ input, # pylint: disable=redefined-builtin
27
+ response,
28
+ system,
29
+ ) = self.parse_instruction_fields(prompt)
30
+ user_prompt = next(
31
+ iter(
32
+ self.prompter.build_prompt_w_system(
33
+ system,
34
+ instruction,
35
+ input,
36
+ )
37
+ )
38
+ )
39
+ tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
40
+ if not self.train_on_inputs:
41
+ user_prompt_len = len(tokenized_prompt["input_ids"])
42
+ # TODO this could be sped up using numpy array slicing
43
+ tokenized_prompt["labels"] = [-100] * user_prompt_len
44
+ tokenized_res_prompt = self._tokenize(
45
+ response, strip_bos_token=True, add_eos_token=True
46
+ )
47
+ tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
48
+ tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
49
+ tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
50
+
51
+ return tokenized_prompt
52
+
53
+
54
+ class SystemDataPrompter(AlpacaPrompter):
55
+ """
56
+ Alpaca Style Prompter that uses system prompts from the dataset
57
+ """
58
+
59
+ def build_prompt_w_system(
60
+ self,
61
+ system: str,
62
+ instruction: str,
63
+ input: Union[None, str] = None, # pylint: disable=redefined-builtin
64
+ output: Union[None, str] = None,
65
+ ) -> Generator[str, None, None]:
66
+ # returns the full prompt from instruction and optional input
67
+ # if a label (=response, =output) is provided, it's also appended.
68
+ if input:
69
+ res = system + self.turn_format.format(instruction=instruction, input=input)
70
+ else:
71
+ res = system + self.turn_no_input_format.format(instruction=instruction)
72
+ if output:
73
+ res = f"{res}{output}"
74
+ yield res
75
+
76
+
77
+ def load(tokenizer, cfg):
78
+ return InstructionWSystemPromptTokenizingStrategy(
79
+ SystemDataPrompter(PromptStyle.CHAT.value),
80
+ tokenizer,
81
+ cfg.train_on_inputs,
82
+ cfg.sequence_len,
83
+ )
src/axolotl/prompters.py CHANGED
@@ -63,29 +63,6 @@ class AlpacaPrompter:
63
  yield res
64
 
65
 
66
- class SystemDataPrompter(AlpacaPrompter):
67
- """
68
- Alpaca Style Prompter that uses system prompts from the dataset
69
- """
70
-
71
- def build_prompt_w_system(
72
- self,
73
- system: str,
74
- instruction: str,
75
- input: Union[None, str] = None, # pylint: disable=redefined-builtin
76
- output: Union[None, str] = None,
77
- ) -> Generator[str, None, None]:
78
- # returns the full prompt from instruction and optional input
79
- # if a label (=response, =output) is provided, it's also appended.
80
- if input:
81
- res = system + self.turn_format.format(instruction=instruction, input=input)
82
- else:
83
- res = system + self.turn_no_input_format.format(instruction=instruction)
84
- if output:
85
- res = f"{res}{output}"
86
- yield res
87
-
88
-
89
  class UnpromptedPrompter(AlpacaPrompter):
90
  """
91
  Prompter for alpaca no system prompt
 
63
  yield res
64
 
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  class UnpromptedPrompter(AlpacaPrompter):
67
  """
68
  Prompter for alpaca no system prompt
src/axolotl/utils/tokenization.py CHANGED
@@ -34,3 +34,5 @@ def check_example_labels(example, tokenizer):
34
 
35
  logging.info(" ".join(colored_tokens))
36
  logging.info("\n\n\n")
 
 
 
34
 
35
  logging.info(" ".join(colored_tokens))
36
  logging.info("\n\n\n")
37
+
38
+ return " ".join(colored_tokens)
tests/test_prompt_tokenizers.py CHANGED
@@ -7,11 +7,15 @@ from pathlib import Path
7
  from transformers import AutoTokenizer
8
 
9
  from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
 
 
 
 
10
  from axolotl.prompt_tokenizers import (
11
  AlpacaPromptTokenizingStrategy,
12
  ShareGPTPromptTokenizingStrategy,
13
  )
14
- from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
15
 
16
  logging.basicConfig(level="INFO")
17
 
@@ -96,5 +100,39 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
96
  assert example["labels"][world_idx - 1] == -100
97
 
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  if __name__ == "__main__":
100
  unittest.main()
 
7
  from transformers import AutoTokenizer
8
 
9
  from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
10
+ from axolotl.prompt_strategies.alpaca_w_system import (
11
+ InstructionWSystemPromptTokenizingStrategy,
12
+ SystemDataPrompter,
13
+ )
14
  from axolotl.prompt_tokenizers import (
15
  AlpacaPromptTokenizingStrategy,
16
  ShareGPTPromptTokenizingStrategy,
17
  )
18
+ from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
19
 
20
  logging.basicConfig(level="INFO")
21
 
 
100
  assert example["labels"][world_idx - 1] == -100
101
 
102
 
103
+ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
104
+ """
105
+ Test class for prompt tokenization strategies with sys prompt from the dataset
106
+ """
107
+
108
+ def setUp(self) -> None:
109
+ # pylint: disable=duplicate-code
110
+ self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
111
+ self.tokenizer.add_special_tokens(
112
+ {
113
+ "bos_token": "<s>",
114
+ "eos_token": "</s>",
115
+ "unk_token": "<unk>",
116
+ }
117
+ )
118
+
119
+ def test_system_alpaca(self):
120
+ prompter = SystemDataPrompter(PromptStyle.CHAT.value)
121
+ strat = InstructionWSystemPromptTokenizingStrategy(
122
+ prompter,
123
+ self.tokenizer,
124
+ False,
125
+ 2048,
126
+ )
127
+ sample = {
128
+ "system": "use cot",
129
+ "instruction": "hello!",
130
+ "output": "Hi! How can I help?",
131
+ }
132
+ example = strat.tokenize_prompt(sample)
133
+ assert example["input_ids"][0:3] == [1, 671, 20118] # <s>use cot
134
+ assert example["input_ids"][3] == 11889 # USER
135
+
136
+
137
  if __name__ == "__main__":
138
  unittest.main()
tests/test_prompters.py CHANGED
@@ -2,11 +2,11 @@
2
 
3
  import unittest
4
 
 
5
  from axolotl.prompters import (
6
  AlpacaPrompter,
7
  MultipleChoiceExplainPrompter,
8
  PromptStyle,
9
- SystemDataPrompter,
10
  UnpromptedPrompter,
11
  )
12
 
 
2
 
3
  import unittest
4
 
5
+ from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
6
  from axolotl.prompters import (
7
  AlpacaPrompter,
8
  MultipleChoiceExplainPrompter,
9
  PromptStyle,
 
10
  UnpromptedPrompter,
11
  )
12