fix packing so that concatenated sequences reset the attention
- src/axolotl/datasets.py (+5 -0)
- tests/fixtures/alpaca/alpaca.json (+12 -0)
- tests/test_packed_dataset.py (+64 -0)
src/axolotl/datasets.py (CHANGED)
```diff
@@ -127,6 +127,11 @@ class ConstantLengthDataset(IterableDataset):
                 input_ids = example["input_ids"]
                 attention_mask = example["attention_mask"]
                 labels = example["labels"]
+                if (
+                    buffer["input_ids"]
+                    and input_ids[0] == self.tokenizer.bos_token_id
+                ):
+                    attention_mask[0] = 0
 
                 if add_concat_token:
                     input_ids.append(self.concat_token_id)
```
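The change itself is small: when the buffer already holds a previous example and the incoming example begins with the BOS token, that BOS position gets `attention_mask[0] = 0`, so attention is "reset" at each boundary inside a packed sequence instead of flowing across concatenated examples. A minimal standalone sketch of the rule follows; the `BOS` constant and `pack()` helper are illustrative stand-ins, not axolotl's API:

```python
# Illustrative sketch of the masking rule added above; BOS and pack()
# are hypothetical stand-ins, not axolotl code.
BOS = 1  # assumed BOS token id (1 for LLaMA tokenizers)

def pack(examples):
    """Concatenate tokenized examples, zeroing the mask on later BOS tokens."""
    packed_ids, packed_mask = [], []
    for ids in examples:
        mask = [1] * len(ids)
        # Same condition as the diff: something is already buffered and the
        # new example starts with BOS -> do not attend to that BOS.
        if packed_ids and ids[0] == BOS:
            mask[0] = 0
        packed_ids += ids
        packed_mask += mask
    return packed_ids, packed_mask

ids, mask = pack([[BOS, 5, 6, 2], [BOS, 7, 8, 2]])
assert mask[0] == 1  # the first example keeps its BOS attended
assert mask[4] == 0  # the second example's BOS is masked out
```

The first example's BOS keeps a mask of 1 because nothing precedes it in the pack, which is exactly what the new test asserts.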
tests/fixtures/alpaca/alpaca.json (ADDED)
```diff
@@ -0,0 +1,12 @@
+[
+    {
+        "instruction": "You will be given a series of words. Output these words in reverse order, with each word on its own line.",
+        "input": "Words: ['Hello', 'world'].",
+        "output": "['world', 'Hello']"
+    },
+    {
+        "instruction": "In this task, you're given a short description of an event. Your job is to order the steps involved in the event from first to last. Note that there may be multiple correct answers for each event.",
+        "input": "Description: A man walks into a bar and orders a drink. He pays for his drink and leaves the bar.",
+        "output": "1. The man walks into the bar.\n2. He orders a drink.\n3. He pays for his drink.\n4. He leaves the bar."
+    }
+]
```
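These two Alpaca-format records (instruction / input / output) are deliberately short: both fit inside a single 2048-token pack, which is what lets the test below find a second BOS token within one packed example and check its mask.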
tests/test_packed_dataset.py (ADDED)
```diff
@@ -0,0 +1,64 @@
+"""Module for testing dataset sequence packing"""
+
+import unittest
+from pathlib import Path
+
+from datasets import Dataset, load_dataset
+from transformers import AutoTokenizer
+
+from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
+from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter
+
+
+class TestPacking(unittest.TestCase):
+    """
+    Test class for packing dataset sequences
+    """
+
+    def setUp(self) -> None:
+        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.add_special_tokens(
+            {
+                "bos_token": "<s>",
+                "eos_token": "</s>",
+                "unk_token": "<unk>",
+            }
+        )
+
+    def test_resets_attention(self):
+        prompter = AlpacaPrompter("chat")
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        raw_dataset = load_dataset(
+            "json",
+            data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
+        )["train"]
+        dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, raw_dataset)))
+
+        constant_len_dataset = ConstantLengthDataset(
+            self.tokenizer,
+            [dataset],
+            seq_length=2048,
+        )
+        packed_dataset = Dataset.from_list(list(constant_len_dataset))
+        example = packed_dataset[0]
+        next_bos_index = (
+            example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
+        )  # add one since we sliced
+
+        # first example doesn't have mask reset
+        assert example["input_ids"][0] == self.tokenizer.bos_token_id
+        assert example["attention_mask"][0] == 1
+
+        # but subsequent one does
+        assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
+        assert example["attention_mask"][next_bos_index] == 0
+
+
+if __name__ == "__main__":
+    unittest.main()
```
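Assuming axolotl and its dependencies are installed (the test pulls the `huggyllama/llama-7b` tokenizer from the Hub on first run), the test runs standalone via its `unittest.main()` entry point: `python tests/test_packed_dataset.py`.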