winglian commited on
Commit
4d09b42
1 Parent(s): b5b4492

plain input/output prompt strategy w/o chat templates (#1346)

Browse files

* plain input/output prompt strategy w/o chat templates

* disable duplicate code check

* make sure to add an eos/eot token to the end of the output so it will stop

* multi turn segement support and test

src/axolotl/prompt_strategies/input_output.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module for plain input/output prompt pairs"""
2
+ from typing import Generator, Tuple
3
+
4
+ from axolotl.prompt_tokenizers import PromptTokenizingStrategy
5
+ from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
6
+
7
+
8
+ class RawInputOutputStrategy(PromptTokenizingStrategy):
9
+ """Prompt Strategy class for input/output pairs"""
10
+
11
+ def __init__(self, *args, eos_token=None, **kwargs):
12
+ super().__init__(*args, **kwargs)
13
+ self.eos_token = eos_token
14
+ if not eos_token:
15
+ self.eos_token = self.tokenizer.eos_token
16
+
17
+ def tokenize_prompt(self, prompt):
18
+ # pylint: disable=duplicate-code
19
+ input_ids = []
20
+ labels = []
21
+ for label, text in self.prompter.build_prompt(prompt["segments"]):
22
+ tokenized_output = self.tokenizer(
23
+ text, add_special_tokens=False, return_tensors=None
24
+ )["input_ids"]
25
+ input_ids += tokenized_output
26
+ if label or self.train_on_inputs:
27
+ labels += tokenized_output
28
+ else:
29
+ labels += [IGNORE_TOKEN_ID] * len(tokenized_output)
30
+
31
+ tokenized_prompt = {
32
+ "input_ids": input_ids,
33
+ "labels": labels,
34
+ "attention_mask": [1] * len(input_ids),
35
+ }
36
+
37
+ return tokenized_prompt
38
+
39
+
40
+ class RawInputOutputPrompter(Prompter):
41
+ """prompter for raw i/o data"""
42
+
43
+ def build_prompt(self, source) -> Generator[Tuple[bool, str], None, None]:
44
+ for segment in source:
45
+ yield segment["label"], segment["text"]
46
+
47
+
48
+ def load(tokenizer, cfg):
49
+ return RawInputOutputStrategy(
50
+ RawInputOutputPrompter(),
51
+ tokenizer,
52
+ cfg.train_on_inputs,
53
+ cfg.sequence_len,
54
+ )
tests/prompt_strategies/test_raw_io.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test module for raw i/o data for prompts
3
+ """
4
+ import pytest
5
+ from datasets import Dataset
6
+ from tokenizers import AddedToken
7
+ from transformers import AutoTokenizer
8
+
9
+ from axolotl.datasets import TokenizedPromptDataset
10
+ from axolotl.prompt_strategies.input_output import (
11
+ RawInputOutputPrompter,
12
+ RawInputOutputStrategy,
13
+ )
14
+
15
+
16
+ @pytest.fixture(name="segments_dataset")
17
+ def fixture_sharegpt_dataset():
18
+ return Dataset.from_list(
19
+ [
20
+ {
21
+ "segments": [
22
+ {
23
+ "label": False,
24
+ "text": "<s>hello ",
25
+ },
26
+ {
27
+ "label": True,
28
+ "text": "hi there.<eot>",
29
+ },
30
+ {
31
+ "label": False,
32
+ "text": "goodbye ",
33
+ },
34
+ {
35
+ "label": True,
36
+ "text": "farewell<eot>",
37
+ },
38
+ ]
39
+ }
40
+ ]
41
+ )
42
+
43
+
44
+ @pytest.fixture(name="tokenizer")
45
+ def fixture_tokenizer():
46
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
47
+ tokenizer.add_tokens(
48
+ [
49
+ AddedToken("<eot>", rstrip=False, lstrip=False, normalized=False),
50
+ ]
51
+ )
52
+
53
+ return tokenizer
54
+
55
+
56
+ class TestRawInputOutputPrompts:
57
+ """
58
+ Test class for raw i/o prompter
59
+ """
60
+
61
+ def test_segment_prompts(self, segments_dataset, tokenizer):
62
+ strategy = RawInputOutputStrategy(
63
+ RawInputOutputPrompter(),
64
+ tokenizer,
65
+ False, # train_on_inputs
66
+ 2048, # sequence_len
67
+ )
68
+
69
+ dataset_wrapper = TokenizedPromptDataset(
70
+ strategy, segments_dataset, process_count=1
71
+ )
72
+
73
+ input_ids = dataset_wrapper[0]["input_ids"]
74
+ labels = dataset_wrapper[0]["labels"]
75
+
76
+ assert (
77
+ tokenizer.decode(input_ids)
78
+ == "<s> hello hi there.<eot> goodbye farewell<eot>"
79
+ )
80
+ # fmt: off
81
+ assert input_ids == [
82
+ 1, # <s>
83
+ 6312, # hell
84
+ 28709, # o
85
+ 28705, #
86
+ 12014, # hi
87
+ 736, # there
88
+ 28723, # .
89
+ 32000, # <eot>
90
+ 1179, # good
91
+ 17664, # bye
92
+ 28705, #
93
+ 19111, # fare
94
+ 5458, # well
95
+ 32000, # <eot>
96
+ ]
97
+ # fmt: on
98
+
99
+ # fmt: off
100
+ assert labels == [
101
+ -100, # <s>
102
+ -100, # hell
103
+ -100, # o
104
+ -100, #
105
+ 12014, # hi
106
+ 736, # there
107
+ 28723, # .
108
+ 32000, # <eot>
109
+ -100, # good
110
+ -100, # bye
111
+ -100, #
112
+ 19111, # fare
113
+ 5458, # well
114
+ 32000, # <eot>
115
+ ]
116
+ # fmt: on