Pretrain transforms (#1261)
* wip for pretraining/iterable data with arbitrary prompt strategies
* more fixes, wip
* more fixes for custom pretraining
* iterable ds wrapper not needed
* remove extra features
* chore: lint
* update pretraining example yml
* fix order for partials
* fixup for tests
- examples/tiny-llama/pretrain.yml +1 -0
- src/axolotl/datasets.py +1 -1
- src/axolotl/prompt_strategies/pretrain.py +58 -0
- src/axolotl/utils/data.py +49 -36
- tests/test_packed_pretraining.py +36 -25
examples/tiny-llama/pretrain.yml  CHANGED

@@ -12,6 +12,7 @@ max_steps: 200
 pretraining_dataset:
   path: c4
   name: en
+  type: pretrain
 dataset_prepared_path:
 val_set_size: 0.0
 output_dir: ./model-out
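The new `type` key selects which prompt strategy module tokenizes the streamed pretraining data; the loader in data.py below falls back to "pretrain" when it is unset. As a rough sketch (not part of the diff), the same entry expressed as the in-memory config, written directly in the list-of-entries form that data.py indexes via cfg.pretraining_dataset[0]:

from axolotl.utils.dict import DictDefault

# Equivalent of the YAML above; "type" names the prompt strategy module
# (axolotl.prompt_strategies.pretrain).
cfg = DictDefault(
    {
        "pretraining_dataset": [
            {"path": "c4", "name": "en", "type": "pretrain"},
        ],
        "sequence_len": 2048,
        "micro_batch_size": 2,
        "sample_packing": True,
    }
)

dataset_cfg = cfg.pretraining_dataset[0]
strategy_type = dataset_cfg["type"] or "pretrain"  # mirrors the fallback in data.py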
src/axolotl/datasets.py  CHANGED

@@ -31,7 +31,7 @@ class TokenizedPromptDataset(Dataset):
     def __init__(  # pylint: disable=super-init-not-called
         self,
         prompt_tokenizer: PromptTokenizingStrategy,
-        dataset: IterableDataset,
+        dataset: Dataset,
         process_count: Optional[int] = None,
         keep_in_memory: Optional[bool] = False,
         **kwargs,
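The type hint changes because the wrapper no longer receives a streaming/iterable dataset: in the packed-pretraining path below, each streamed batch is first materialized with Dataset.from_dict and only then handed to the wrapper. A minimal illustration (hypothetical data, not from the diff):

from datasets import Dataset

# Each batched map() call over the streaming dataset yields a dict of columns;
# encode_packed_pretraining (see data.py below) turns it into a map-style Dataset
# before passing it to the get_dataset_wrapper partial.
examples = {"text": ["first pretraining document ...", "second document ..."]}
batch_ds = Dataset.from_dict(examples)  # a regular Dataset, hence the new type hint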
src/axolotl/prompt_strategies/pretrain.py  ADDED

@@ -0,0 +1,58 @@
+"""pretraining prompt strategies"""
+from typing import Generator
+
+from transformers import BatchEncoding
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+
+
+class PretrainTokenizer:
+    """basic tokenization class for pretraining"""
+
+    def build_prompt(self, prompt) -> Generator[str, None, None]:
+        yield prompt
+
+
+class PretrainTokenizationStrategy(PromptTokenizingStrategy):
+    """handles tokenization for pretraining with strides"""
+
+    @property
+    def supports_batched(self):
+        return True
+
+    def __init__(self, *args, max_length=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        if max_length:
+            self.max_length = max_length
+
+    def _tokenize(
+        self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
+    ) -> BatchEncoding:
+        res = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.max_length - 1,
+            add_special_tokens=True,
+            return_overflowing_tokens=True,
+            stride=256,
+        )
+        res["input_ids"] = [
+            seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
+        ]
+        res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
+
+        return res
+
+    def tokenize_prompt(self, prompt):
+        return self._tokenize(prompt["text"])
+
+
+def load(tokenizer, cfg):
+    strat = PretrainTokenizationStrategy(
+        PretrainTokenizer(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+        max_length=cfg.sequence_len * 64,
+    )
+    return strat
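For orientation, a minimal sketch (not part of the PR) of how this strategy could be exercised on its own; load() only reads train_on_inputs and sequence_len from the config, and the tokenizer name is borrowed from the updated test:

from transformers import AutoTokenizer

from axolotl.prompt_strategies.pretrain import load
from axolotl.utils.dict import DictDefault

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
cfg = DictDefault({"train_on_inputs": True, "sequence_len": 2048})

strategy = load(tokenizer, cfg)

# supports_batched is True, so tokenize_prompt takes a column batch; long documents
# are split into overlapping windows (stride=256, max_length=sequence_len * 64 - 1),
# and each window gets an EOS token appended.
encoded = strategy.tokenize_prompt({"text": ["a very long pretraining document ..."]})
print(len(encoded["input_ids"]))  # number of windows produced for the batch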
src/axolotl/utils/data.py  CHANGED

@@ -4,7 +4,7 @@ import hashlib
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import yaml
@@ -88,12 +88,21 @@ def prepare_dataset(cfg, tokenizer):
         path = cfg.pretraining_dataset[0]["path"]
         name = cfg.pretraining_dataset[0]["name"]
 
-        train_dataset = load_pretraining_dataset(
-            path,
+        ds_wrapper_partial = functools.partial(
+            get_dataset_wrapper,
+            cfg.pretraining_dataset[0],
             tokenizer,
             cfg,
-            name=name,
+            cfg.pretraining_dataset[0]["type"] or "pretrain",
+        )
+
+        train_dataset = wrap_pretraining_dataset(
+            load_dataset(path, streaming=True, split="train", name=name),
+            tokenizer,
+            cfg,
+            ds_wrapper_partial,
             max_tokens=cfg.sequence_len,
+            batch_size=cfg.micro_batch_size,
             seed=cfg.seed or 42,
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
@@ -383,9 +392,9 @@ def load_tokenized_prepared_datasets(
 
             dataset_wrapper, dataset_prompter = get_dataset_wrapper(
                 config_dataset=config_dataset,
-                dataset=ds,
                 tokenizer=tokenizer,
                 cfg=cfg,
+                dataset=ds,
                 d_base_type=d_base_type,
                 d_prompt_style=d_prompt_style,
             )
@@ -496,7 +505,12 @@ def load_prepare_datasets(
 
 
 def get_dataset_wrapper(
-    config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
+    config_dataset,
+    tokenizer,
+    cfg,
+    d_base_type,
+    dataset,
+    d_prompt_style=None,
 ):
     dataset_wrapper = None
     dataset_prompter = None
@@ -507,7 +521,8 @@ def get_dataset_wrapper(
     }
 
     if (
-        "input_ids" in dataset.features
+        isinstance(dataset, Dataset)
+        and "input_ids" in dataset.features
         and "attention_mask" in dataset.features
         and "labels" in dataset.features
     ):
@@ -765,69 +780,60 @@ def encode_pretraining(
     return ret
 
 
-def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
+def wrap_pretraining_dataset(
+    dataset,
+    tokenizer,
+    cfg,
+    ds_wrapper_fn,
+    max_tokens=2048,
+    batch_size=1,
+    seed=42,
+    buffer_size=10_000,
+):
     if cfg.sample_packing:
         collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
             tokenizer,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
+            pad_to_multiple_of=max_tokens * batch_size,
         )
         encode = functools.partial(
            encode_packed_pretraining,
-            tokenizer,
             collate_fn,
+            ds_wrapper_fn,
             max_seq_length=max_tokens,
-            batch_size=cfg.micro_batch_size,
+            batch_size=batch_size,
        )
         # set this to 1 so downstream data_loader doesn't try to increase the batch again
         cfg.micro_batch_size = 1
     else:
         encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
 
-    dataset = load_dataset(path, streaming=True, split="train", name=name)
-    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
+    dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
     dataset = dataset.map(
         encode,
         batched=True,
-        batch_size=10_000,
-        input_columns="text",
+        batch_size=buffer_size,
+        # input_columns="text",
         # remove all the existing columns after mapping since they end up having
         # a different length than the encoded/tokenized column
         remove_columns=dataset.features.keys(),
-        desc="Encoding Pretraining",
     )
     return dataset
 
 
 def encode_packed_pretraining(
-    tokenizer: PreTrainedTokenizerBase,
     collate_fn,
-    examples: List[str],
+    ds_wrapper: Callable,
+    examples: Dict[str, List],
     max_seq_length: int = 2048,
     batch_size: int = 4,
 ) -> Dict[str, List]:
     # pylint: disable=duplicate-code
     # tokenize all the examples
     # rows get split with stride (overlap)
-    res = tokenizer(
-        examples,
-        truncation=True,
-        max_length=max_seq_length - 1,
-        add_special_tokens=True,
-        return_overflowing_tokens=True,
-        stride=256,
-    )
-
-    input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
-    attention_mask = [seq + [1] for seq in res["attention_mask"]]
-
-    tokenized_examples = {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-    }
+    train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
 
-    train_dataset = Dataset.from_dict(tokenized_examples)
     train_dataset = process_pretraining_datasets_for_packing(
         train_dataset, max_seq_length
     )
@@ -845,7 +851,14 @@ def encode_packed_pretraining(
     for batch in sampler:
         for data in batch:
             features = train_dataset[data]
-            features["labels"] = features["input_ids"].copy()
+            if "num_truncated_tokens" in features:
+                del features["num_truncated_tokens"]
+            if "num_truncated_tokens" in features:
+                del features["num_truncated_tokens"]
+            if "overflow_to_sample_mapping" in features:
+                del features["overflow_to_sample_mapping"]
+            if "labels" not in features:
+                features["labels"] = features["input_ids"].copy()
             collated_features = collate_fn(features)
 
             for feature in features.keys():
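Taken together, the pretraining path in prepare_dataset now builds a get_dataset_wrapper partial and hands it, along with the streaming dataset, to wrap_pretraining_dataset, which shuffles the stream and maps the packed (or plain) encode function over it in buffer_size-sized batches. A consolidated sketch of that wiring, assembled from the hunks above (the helper name build_pretraining_dataset is illustrative, not from the PR):

import functools

from datasets import load_dataset

from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset


def build_pretraining_dataset(cfg, tokenizer):
    # Illustrative helper mirroring the prepare_dataset hunk above.
    path = cfg.pretraining_dataset[0]["path"]
    name = cfg.pretraining_dataset[0]["name"]

    ds_wrapper_partial = functools.partial(
        get_dataset_wrapper,
        cfg.pretraining_dataset[0],
        tokenizer,
        cfg,
        cfg.pretraining_dataset[0]["type"] or "pretrain",
    )

    return wrap_pretraining_dataset(
        load_dataset(path, streaming=True, split="train", name=name),
        tokenizer,
        cfg,
        ds_wrapper_partial,
        max_tokens=cfg.sequence_len,
        batch_size=cfg.micro_batch_size,
        seed=cfg.seed or 42,
    )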
tests/test_packed_pretraining.py  CHANGED

@@ -1,14 +1,14 @@
 """Module for testing streaming dataset sequence packing"""
+import functools
 import unittest
-from functools import partial
 
 import torch
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from transformers import AutoTokenizer
 
-from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
-from axolotl.utils.data import encode_packed_pretraining
+from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
+from axolotl.utils.dict import DictDefault
 
 
 class TestPretrainingPacking(unittest.TestCase):
@@ -20,8 +20,6 @@ class TestPretrainingPacking(unittest.TestCase):
         # pylint: disable=duplicate-code
         self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
         self.tokenizer.pad_token = "</s>"
-        self.max_seq_length = 2048
-        self.batch_size = 2
 
     def test_packing_stream_dataset(self):
         # pylint: disable=duplicate-code
@@ -31,30 +29,43 @@ class TestPretrainingPacking(unittest.TestCase):
             streaming=True,
         )["train"]
 
-        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
-            self.tokenizer,
-            return_tensors="pt",
-            padding=True,
-            pad_to_multiple_of=self.max_seq_length * self.batch_size,
+        cfg = DictDefault(
+            {
+                "pretraining_dataset": [
+                    {
+                        "path": "c4",
+                        "name": "en",
+                        "type": "pretrain",
+                    }
+                ],
+                "sample_packing": True,
+                "pad_to_sequence_len": True,
+                "sequence_len": 2048,
+                "micro_batch_size": 2,
+            }
         )
 
-        encode = partial(
-            encode_packed_pretraining,
+        ds_wrapper_partial = functools.partial(
+            get_dataset_wrapper,
+            cfg.pretraining_dataset[0],
             self.tokenizer,
-            collate_fn,
-            max_seq_length=self.max_seq_length,
-            batch_size=self.batch_size,
+            cfg,
+            cfg.pretraining_dataset[0]["type"] or "pretrain",
         )
 
-        dataset = dataset.map(
-            encode,
-            batched=True,
-            input_columns="text",
-            remove_columns=dataset.features.keys(),
+        original_bsz = cfg.micro_batch_size
+        train_dataset = wrap_pretraining_dataset(
+            dataset,
+            self.tokenizer,
+            cfg,
+            ds_wrapper_partial,
+            max_tokens=cfg.sequence_len,
+            batch_size=cfg.micro_batch_size,
+            seed=cfg.seed or 42,
         )
 
         trainer_loader = DataLoader(
-            dataset,
+            train_dataset,
             batch_size=1,
             collate_fn=None,
             drop_last=True,
@@ -64,16 +75,16 @@ class TestPretrainingPacking(unittest.TestCase):
             if idx > 10:
                 break
             assert data["input_ids"].shape == torch.Size(
-                [1, self.batch_size * self.max_seq_length]
+                [1, original_bsz * cfg.sequence_len]
             )
             assert data["position_ids"].shape == torch.Size(
-                [1, self.batch_size * self.max_seq_length]
+                [1, original_bsz * cfg.sequence_len]
             )
             assert data["labels"].shape == torch.Size(
-                [1, self.batch_size * self.max_seq_length]
+                [1, original_bsz * cfg.sequence_len]
             )
             assert data["attention_mask"].shape == torch.Size(
-                [1, self.batch_size * self.max_seq_length]
+                [1, original_bsz * cfg.sequence_len]
             )
             idx += 1
 