Merge pull request #214 from OpenAccess-AI-Collective/fix-tokenizing-labels
Browse files
src/axolotl/prompt_strategies/alpaca_chat.py
CHANGED
@@ -20,11 +20,36 @@ def load(tokenizer, cfg):
|
|
20 |
|
21 |
class AlpacaConcisePrompter(AlpacaPrompter):
|
22 |
"""
|
23 |
-
Alpaca Prompter extending the system prompt to ask for concise answers
|
24 |
"""
|
25 |
|
26 |
-
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context.
|
27 |
-
system_no_input_prompt = "Below is an instruction that describes a task.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
|
30 |
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
@@ -64,7 +89,7 @@ def load_concise(tokenizer, cfg):
|
|
64 |
|
65 |
def load_qa(tokenizer, cfg):
|
66 |
return AlpacaQAPromptTokenizingStrategy(
|
67 |
-
|
68 |
tokenizer,
|
69 |
cfg.train_on_inputs,
|
70 |
cfg.sequence_len,
|
@@ -73,7 +98,7 @@ def load_qa(tokenizer, cfg):
|
|
73 |
|
74 |
def load_camel_ai(tokenizer, cfg):
|
75 |
return CamelAIPromptTokenizingStrategy(
|
76 |
-
|
77 |
tokenizer,
|
78 |
cfg.train_on_inputs,
|
79 |
cfg.sequence_len,
|
|
|
20 |
|
21 |
class AlpacaConcisePrompter(AlpacaPrompter):
    """
    Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
    """

    # USER/ASSISTANT phrasing with an explicit request for concise responses;
    # the trailing "\n\n" separates the system preamble from the turn templates.
    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
    # Variant used when the example has no `input` field.
    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
28 |
+
|
29 |
+
|
30 |
+
class AlpacaChatPrompter(AlpacaPrompter):
    """
    Alpaca Chat Prompter extending the system prompt for chat-instruct answers
    """

    # NOTE(review): these two prompts are byte-identical to
    # AlpacaConcisePrompter's and still ask for "concise" answers — confirm
    # this duplication is intentional for the plain chat prompter.
    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"

    def __init__(self):  # pylint: disable=super-init-not-called
        # Deliberately skips AlpacaPrompter.__init__ and forces the CHAT
        # prompt style; match_prompt_style() presumably rebuilds the turn
        # templates for that style — confirm against AlpacaPrompter.
        self.prompt_style = PromptStyle.CHAT.value
        self.match_prompt_style()
|
41 |
+
|
42 |
+
|
43 |
+
class NoSystemPrompter(AlpacaPrompter):
    """
    Null Prompter with no system prompts
    """

    # Bare templates: instruction (and optional input) only — no system
    # preamble and no role markers.
    prompt_input = "{instruction} {input} "
    prompt_no_input = "{instruction} "

    def __init__(self):  # pylint: disable=super-init-not-called
        # Intentionally a no-op: skipping AlpacaPrompter.__init__ presumably
        # keeps the class-level templates above verbatim instead of having
        # them rewritten for a prompt style — confirm against AlpacaPrompter.
        pass
|
53 |
|
54 |
|
55 |
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
|
|
89 |
|
90 |
def load_qa(tokenizer, cfg):
|
91 |
return AlpacaQAPromptTokenizingStrategy(
|
92 |
+
AlpacaChatPrompter(),
|
93 |
tokenizer,
|
94 |
cfg.train_on_inputs,
|
95 |
cfg.sequence_len,
|
|
|
98 |
|
99 |
def load_camel_ai(tokenizer, cfg):
|
100 |
return CamelAIPromptTokenizingStrategy(
|
101 |
+
AlpacaChatPrompter(),
|
102 |
tokenizer,
|
103 |
cfg.train_on_inputs,
|
104 |
cfg.sequence_len,
|
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -96,25 +96,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|
96 |
input, # pylint: disable=redefined-builtin
|
97 |
response,
|
98 |
) = self.parse_instruction_fields(prompt)
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
self.prompter.build_prompt(
|
105 |
-
instruction,
|
106 |
-
input,
|
107 |
-
)
|
108 |
)
|
109 |
)
|
110 |
-
|
111 |
-
|
|
|
|
|
112 |
# TODO this could be sped up using numpy array slicing
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
116 |
|
117 |
-
return
|
118 |
|
119 |
def _build_full_prompt(
|
120 |
self, instruction, input, response # pylint: disable=redefined-builtin
|
|
|
96 |
input, # pylint: disable=redefined-builtin
|
97 |
response,
|
98 |
) = self.parse_instruction_fields(prompt)
|
99 |
+
user_prompt = next(
|
100 |
+
iter(
|
101 |
+
self.prompter.build_prompt(
|
102 |
+
instruction,
|
103 |
+
input,
|
|
|
|
|
|
|
|
|
104 |
)
|
105 |
)
|
106 |
+
)
|
107 |
+
tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
|
108 |
+
if not self.train_on_inputs:
|
109 |
+
user_prompt_len = len(tokenized_prompt["input_ids"])
|
110 |
# TODO this could be sped up using numpy array slicing
|
111 |
+
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
112 |
+
tokenized_res_prompt = self._tokenize(
|
113 |
+
response, strip_bos_token=True, add_eos_token=True
|
114 |
+
)
|
115 |
+
tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
|
116 |
+
tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
|
117 |
+
tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
|
118 |
|
119 |
+
return tokenized_prompt
|
120 |
|
121 |
def _build_full_prompt(
|
122 |
self, instruction, input, response # pylint: disable=redefined-builtin
|
tests/test_prompt_tokenizers.py
CHANGED
@@ -6,8 +6,12 @@ from pathlib import Path
|
|
6 |
|
7 |
from transformers import AutoTokenizer
|
8 |
|
9 |
-
from axolotl.
|
10 |
-
from axolotl.
|
|
|
|
|
|
|
|
|
11 |
|
12 |
logging.basicConfig(level="INFO")
|
13 |
|
@@ -29,7 +33,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|
29 |
)
|
30 |
|
31 |
def test_sharegpt_integration(self):
|
32 |
-
print(Path(__file__).parent)
|
33 |
with open(
|
34 |
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
35 |
) as fin:
|
@@ -53,6 +56,45 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|
53 |
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
54 |
self.assertEqual(example[fields], tokenized_conversation[fields])
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
if __name__ == "__main__":
|
58 |
unittest.main()
|
|
|
6 |
|
7 |
from transformers import AutoTokenizer
|
8 |
|
9 |
+
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
10 |
+
from axolotl.prompt_tokenizers import (
|
11 |
+
AlpacaPromptTokenizingStrategy,
|
12 |
+
ShareGPTPromptTokenizingStrategy,
|
13 |
+
)
|
14 |
+
from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
|
15 |
|
16 |
logging.basicConfig(level="INFO")
|
17 |
|
|
|
33 |
)
|
34 |
|
35 |
def test_sharegpt_integration(self):
|
|
|
36 |
with open(
|
37 |
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
38 |
) as fin:
|
|
|
56 |
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
57 |
self.assertEqual(example[fields], tokenized_conversation[fields])
|
58 |
|
59 |
+
def test_no_sys_prompt(self):
|
60 |
+
"""
|
61 |
+
tests the interface between the user and assistant parts
|
62 |
+
"""
|
63 |
+
prompter = NoSystemPrompter()
|
64 |
+
# pylint: disable=duplicate-code
|
65 |
+
strat = AlpacaPromptTokenizingStrategy(
|
66 |
+
prompter,
|
67 |
+
self.tokenizer,
|
68 |
+
False,
|
69 |
+
2048,
|
70 |
+
)
|
71 |
+
sample = {
|
72 |
+
"instruction": "hello cruel. lorem ipsum dolor sit amet.",
|
73 |
+
"output": "world!",
|
74 |
+
}
|
75 |
+
example = strat.tokenize_prompt(sample)
|
76 |
+
world_idx = example["input_ids"].index(3186)
|
77 |
+
assert example["labels"][world_idx] == 3186
|
78 |
+
assert example["labels"][world_idx - 1] == -100
|
79 |
+
|
80 |
+
def test_alpaca(self):
|
81 |
+
"""
|
82 |
+
tests the interface between the user and assistant parts
|
83 |
+
"""
|
84 |
+
# pylint: disable=duplicate-code
|
85 |
+
prompter = AlpacaPrompter()
|
86 |
+
strat = AlpacaPromptTokenizingStrategy(
|
87 |
+
prompter,
|
88 |
+
self.tokenizer,
|
89 |
+
False,
|
90 |
+
2048,
|
91 |
+
)
|
92 |
+
sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
|
93 |
+
example = strat.tokenize_prompt(sample)
|
94 |
+
world_idx = example["input_ids"].index(6324)
|
95 |
+
assert example["labels"][world_idx] == 6324
|
96 |
+
assert example["labels"][world_idx - 1] == -100
|
97 |
+
|
98 |
|
99 |
if __name__ == "__main__":
|
100 |
unittest.main()
|