winglian committed
Commit 9b8585d · Parent: 8eb5811

fix packing so that concatenated sequences reset the attention mask

src/axolotl/datasets.py CHANGED
@@ -127,6 +127,11 @@ class ConstantLengthDataset(IterableDataset):
                 input_ids = example["input_ids"]
                 attention_mask = example["attention_mask"]
                 labels = example["labels"]
+                if (
+                    buffer["input_ids"]
+                    and input_ids[0] == self.tokenizer.bos_token_id
+                ):
+                    attention_mask[0] = 0
 
                 if add_concat_token:
                     input_ids.append(self.concat_token_id)
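
For readers skimming the diff, here is a minimal standalone sketch of what the boundary reset does (the helper name and toy token IDs are illustrative assumptions, not axolotl internals): when tokenized examples are packed into one row, the BOS token of every example after the first gets its attention-mask entry zeroed, so attention does not bleed across packed-sequence boundaries.

```python
# Hypothetical illustration of the mask reset at packed-sequence boundaries.
# BOS = 1 and the toy sequences below are assumptions for demonstration only.
BOS = 1


def pack_with_mask_reset(examples):
    """Concatenate tokenized examples; zero the attention-mask entry on the
    BOS of every example after the first so attention resets at boundaries."""
    input_ids, attention_mask = [], []
    for example in examples:
        ids = list(example["input_ids"])
        mask = list(example["attention_mask"])
        # mirrors the fix above: if the buffer already holds tokens and this
        # example starts with BOS, zero out the mask for that BOS token
        if input_ids and ids and ids[0] == BOS:
            mask[0] = 0
        input_ids.extend(ids)
        attention_mask.extend(mask)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


packed = pack_with_mask_reset([
    {"input_ids": [1, 5, 6, 2], "attention_mask": [1, 1, 1, 1]},
    {"input_ids": [1, 7, 8, 2], "attention_mask": [1, 1, 1, 1]},
])
# only the second example's BOS (index 4) has its mask zeroed
assert packed["attention_mask"] == [1, 1, 1, 1, 0, 1, 1, 1]
```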
tests/fixtures/alpaca/alpaca.json ADDED
@@ -0,0 +1,12 @@
+[
+    {
+        "instruction": "You will be given a series of words. Output these words in reverse order, with each word on its own line.",
+        "input": "Words: ['Hello', 'world'].",
+        "output": "['world', 'Hello']"
+    },
+    {
+        "instruction": "In this task, you're given a short description of an event. Your job is to order the steps involved in the event from first to last. Note that there may be multiple correct answers for each event.",
+        "input": "Description: A man walks into a bar and orders a drink. He pays for his drink and leaves the bar.",
+        "output": "1. The man walks into the bar.\n2. He orders a drink.\n3. He pays for his drink.\n4. He leaves the bar."
+    }
+]
tests/test_packed_dataset.py ADDED
@@ -0,0 +1,64 @@
+"""Module for testing dataset sequence packing"""
+
+import unittest
+from pathlib import Path
+
+from datasets import Dataset, load_dataset
+from transformers import AutoTokenizer
+
+from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
+from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter
+
+
+class TestPacking(unittest.TestCase):
+    """
+    Test class for packing dataset sequences
+    """
+
+    def setUp(self) -> None:
+        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.add_special_tokens(
+            {
+                "bos_token": "<s>",
+                "eos_token": "</s>",
+                "unk_token": "<unk>",
+            }
+        )
+
+    def test_resets_attention(self):
+        prompter = AlpacaPrompter("chat")
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        dataset = load_dataset(
+            "json",
+            data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
+        )["train"]
+        dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dataset)))
+
+        constant_len_dataset = ConstantLengthDataset(
+            self.tokenizer,
+            [dataset],
+            seq_length=2048,
+        )
+        packed_dataset = Dataset.from_list(list(constant_len_dataset))
+        example = packed_dataset[0]
+        next_bos_index = (
+            example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
+        )  # add one since we sliced
+
+        # the first example doesn't have its mask reset
+        assert example["input_ids"][0] == self.tokenizer.bos_token_id
+        assert example["attention_mask"][0] == 1
+
+        # but subsequent ones do
+        assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
+        assert example["attention_mask"][next_bos_index] == 0
+
+
+if __name__ == "__main__":
+    unittest.main()
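
To exercise the new test locally, a quick discovery snippet like the one below should work (assuming the repository root as the working directory and an importable `tests` directory; the file also supports direct execution via its `__main__` block):

```python
# Hypothetical local run of the new packing test via unittest discovery.
import unittest

suite = unittest.defaultTestLoader.discover("tests", pattern="test_packed_dataset.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```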