winglian committed
Commit c7cf381
1 Parent(s): 8c2e05a

Pretrain transforms (#1261)

* wip for pretraining/iterable data with arbitrary prompt strategies

* more fixes, wip

* more fixes for custom pretraining

* iterable ds wrapper not needed

* remove extra features

* chore: lint

* update pretraining example yml

* fix order for partials

* fixup for tests

examples/tiny-llama/pretrain.yml CHANGED
@@ -12,6 +12,7 @@ max_steps: 200
 pretraining_dataset:
   path: c4
   name: en
+  type: pretrain
 dataset_prepared_path:
 val_set_size: 0.0
 output_dir: ./model-out
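The new type: key selects which prompt strategy tokenizes the streamed pretraining rows. Assuming type: is resolved, like other dataset types, to a module under axolotl.prompt_strategies exposing load(tokenizer, cfg) (the convention the new pretrain.py below follows), a custom transform could look like the sketch below; the module name my_pretrain and the lowercasing step are hypothetical and not part of this commit:

# hypothetical module: src/axolotl/prompt_strategies/my_pretrain.py,
# referenced from the config as `type: my_pretrain`
from axolotl.prompt_strategies.pretrain import (
    PretrainTokenizationStrategy,
    PretrainTokenizer,
)


class LowercasePretrainStrategy(PretrainTokenizationStrategy):
    """example transform: lowercase raw text before strided tokenization"""

    def tokenize_prompt(self, prompt):
        texts = prompt["text"]
        if isinstance(texts, list):  # batched map hands over a list of rows
            texts = [text.lower() for text in texts]
        else:
            texts = texts.lower()
        return self._tokenize(texts)


def load(tokenizer, cfg):
    # mirror the arguments pretrain.load() passes to the base strategy
    return LowercasePretrainStrategy(
        PretrainTokenizer(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
        max_length=cfg.sequence_len * 64,
    )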
src/axolotl/datasets.py CHANGED
@@ -31,7 +31,7 @@ class TokenizedPromptDataset(Dataset):
     def __init__(  # pylint: disable=super-init-not-called
         self,
         prompt_tokenizer: PromptTokenizingStrategy,
-        dataset: IterableDataset,
+        dataset: Dataset,
         process_count: Optional[int] = None,
         keep_in_memory: Optional[bool] = False,
         **kwargs,
src/axolotl/prompt_strategies/pretrain.py ADDED
@@ -0,0 +1,58 @@
+"""pretraining prompt strategies"""
+from typing import Generator
+
+from transformers import BatchEncoding
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+
+
+class PretrainTokenizer:
+    """basic tokenization class for pretraining"""
+
+    def build_prompt(self, prompt) -> Generator[str, None, None]:
+        yield prompt
+
+
+class PretrainTokenizationStrategy(PromptTokenizingStrategy):
+    """handles tokenization for pretraining with strides"""
+
+    @property
+    def supports_batched(self):
+        return True
+
+    def __init__(self, *args, max_length=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        if max_length:
+            self.max_length = max_length
+
+    def _tokenize(
+        self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
+    ) -> BatchEncoding:
+        res = self.tokenizer(
+            prompt,
+            truncation=True,
+            max_length=self.max_length - 1,
+            add_special_tokens=True,
+            return_overflowing_tokens=True,
+            stride=256,
+        )
+        res["input_ids"] = [
+            seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
+        ]
+        res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
+
+        return res
+
+    def tokenize_prompt(self, prompt):
+        return self._tokenize(prompt["text"])
+
+
+def load(tokenizer, cfg):
+    strat = PretrainTokenizationStrategy(
+        PretrainTokenizer(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+        max_length=cfg.sequence_len * 64,
+    )
+    return strat
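As a quick sanity check, the new strategy can be exercised outside the trainer. A minimal sketch, assuming a local huggyllama/llama-7b tokenizer as in the updated test; the sample text and config values are illustrative only:

from transformers import AutoTokenizer

from axolotl.prompt_strategies.pretrain import load
from axolotl.utils.dict import DictDefault

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = "</s>"

# only the fields read by load() are populated here
cfg = DictDefault({"train_on_inputs": True, "sequence_len": 2048})
strategy = load(tokenizer, cfg)

# the strategy is batched: it expects a dict with a "text" column, the same
# shape as a batch streamed from the pretraining dataset; rows longer than
# max_length overflow into additional windows with a 256-token stride
batch = strategy.tokenize_prompt({"text": ["some long pretraining document ..."]})
print(len(batch["input_ids"]))    # number of windows produced
print(batch["input_ids"][0][-1])  # every window ends with the EOS token id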
src/axolotl/utils/data.py CHANGED
@@ -4,7 +4,7 @@ import hashlib
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import yaml
@@ -88,12 +88,21 @@ def prepare_dataset(cfg, tokenizer):
         path = cfg.pretraining_dataset[0]["path"]
         name = cfg.pretraining_dataset[0]["name"]
 
-        train_dataset = load_pretraining_dataset(
-            path,
+        ds_wrapper_partial = functools.partial(
+            get_dataset_wrapper,
+            cfg.pretraining_dataset[0],
             tokenizer,
             cfg,
-            name=name,
+            cfg.pretraining_dataset[0]["type"] or "pretrain",
+        )
+
+        train_dataset = wrap_pretraining_dataset(
+            load_dataset(path, streaming=True, split="train", name=name),
+            tokenizer,
+            cfg,
+            ds_wrapper_partial,
             max_tokens=cfg.sequence_len,
+            batch_size=cfg.micro_batch_size,
             seed=cfg.seed or 42,
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
@@ -383,9 +392,9 @@ def load_tokenized_prepared_datasets(
 
         dataset_wrapper, dataset_prompter = get_dataset_wrapper(
             config_dataset=config_dataset,
-            dataset=ds,
             tokenizer=tokenizer,
             cfg=cfg,
+            dataset=ds,
             d_base_type=d_base_type,
             d_prompt_style=d_prompt_style,
         )
@@ -496,7 +505,12 @@ def load_prepare_datasets(
 
 
 def get_dataset_wrapper(
-    config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
+    config_dataset,
+    tokenizer,
+    cfg,
+    d_base_type,
+    dataset,
+    d_prompt_style=None,
 ):
     dataset_wrapper = None
     dataset_prompter = None
@@ -507,7 +521,8 @@ def get_dataset_wrapper(
     }
 
     if (
-        "input_ids" in dataset.features
+        isinstance(dataset, Dataset)
+        and "input_ids" in dataset.features
         and "attention_mask" in dataset.features
        and "labels" in dataset.features
     ):
@@ -765,69 +780,60 @@ def encode_pretraining(
     return ret
 
 
-def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
+def wrap_pretraining_dataset(
+    dataset,
+    tokenizer,
+    cfg,
+    ds_wrapper_fn,
+    max_tokens=2048,
+    batch_size=1,
+    seed=42,
+    buffer_size=10_000,
+):
     if cfg.sample_packing:
         collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
             tokenizer,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
+            pad_to_multiple_of=max_tokens * batch_size,
         )
         encode = functools.partial(
             encode_packed_pretraining,
-            tokenizer,
             collate_fn,
+            ds_wrapper_fn,
             max_seq_length=max_tokens,
-            batch_size=cfg.micro_batch_size,
+            batch_size=batch_size,
         )
         # set this to 1 so downstream data_loader doesn't try to increase the batch again
         cfg.micro_batch_size = 1
     else:
         encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
 
-    dataset = load_dataset(path, streaming=True, split="train", name=name)
-    dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
+    dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
     dataset = dataset.map(
         encode,
         batched=True,
-        batch_size=10_000,
-        input_columns="text",
+        batch_size=buffer_size,
+        # input_columns="text",
         # remove all the existing columns after mapping since they end up having
         # a different length than the encoded/tokenized column
         remove_columns=dataset.features.keys(),
-        desc="Encoding Pretraining",
     )
     return dataset
 
 
 def encode_packed_pretraining(
-    tokenizer: PreTrainedTokenizerBase,
     collate_fn,
-    examples: List[str],
+    ds_wrapper: Callable,
+    examples: Dict[str, List],
     max_seq_length: int = 2048,
     batch_size: int = 4,
 ) -> Dict[str, List]:
     # pylint: disable=duplicate-code
     # tokenize all the examples
     # rows get split with stride (overlap)
-    res = tokenizer(
-        examples,
-        truncation=True,
-        max_length=max_seq_length - 1,
-        add_special_tokens=True,
-        return_overflowing_tokens=True,
-        stride=256,
-    )
-
-    input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
-    attention_mask = [seq + [1] for seq in res["attention_mask"]]
-
-    tokenized_examples = {
-        "input_ids": input_ids,
-        "attention_mask": attention_mask,
-    }
+    train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
 
-    train_dataset = Dataset.from_dict(tokenized_examples)
     train_dataset = process_pretraining_datasets_for_packing(
         train_dataset, max_seq_length
     )
@@ -845,7 +851,14 @@ def encode_packed_pretraining(
     for batch in sampler:
         for data in batch:
             features = train_dataset[data]
-            features["labels"] = features["input_ids"].copy()
+            if "num_truncated_tokens" in features:
+                del features["num_truncated_tokens"]
+            if "num_truncated_tokens" in features:
+                del features["num_truncated_tokens"]
+            if "overflow_to_sample_mapping" in features:
+                del features["overflow_to_sample_mapping"]
+            if "labels" not in features:
+                features["labels"] = features["input_ids"].copy()
             collated_features = collate_fn(features)
 
             for feature in features.keys():
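One contract worth noting: get_dataset_wrapper returns a (dataset_wrapper, dataset_prompter) pair, which is why encode_packed_pretraining indexes [0] after rebuilding an in-memory Dataset from each streamed batch. A minimal sketch of that hand-off in isolation; the tokenizer, config values, and sample rows are illustrative, and the full packed path is exercised by the updated test below:

import functools

from datasets import Dataset
from transformers import AutoTokenizer

from axolotl.utils.data import get_dataset_wrapper
from axolotl.utils.dict import DictDefault

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = "</s>"

cfg = DictDefault(
    {
        "pretraining_dataset": [{"path": "c4", "name": "en", "type": "pretrain"}],
        "sample_packing": True,
        "sequence_len": 2048,
        "micro_batch_size": 2,
    }
)

# bind everything except the dataset itself, as prepare_dataset() now does
ds_wrapper_partial = functools.partial(
    get_dataset_wrapper,
    cfg.pretraining_dataset[0],
    tokenizer,
    cfg,
    cfg.pretraining_dataset[0]["type"] or "pretrain",
)

# inside encode_packed_pretraining, each streamed batch of raw rows becomes an
# in-memory Dataset, runs through the bound wrapper, and element [0] (the
# tokenized dataset) is kept; element [1] is the prompter, unused there
examples = {"text": ["first raw document ...", "second raw document ..."]}
tokenized = ds_wrapper_partial(Dataset.from_dict(examples))[0]
# tokenized now carries input_ids / attention_mask rows ready for packing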
tests/test_packed_pretraining.py CHANGED
@@ -1,14 +1,14 @@
 """Module for testing streaming dataset sequence packing"""
+import functools
 import unittest
-from functools import partial
 
 import torch
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from transformers import AutoTokenizer
 
-from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
-from axolotl.utils.data import encode_packed_pretraining
+from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
+from axolotl.utils.dict import DictDefault
 
 
 class TestPretrainingPacking(unittest.TestCase):
@@ -20,8 +20,6 @@ class TestPretrainingPacking(unittest.TestCase):
         # pylint: disable=duplicate-code
         self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
         self.tokenizer.pad_token = "</s>"
-        self.max_seq_length = 2048
-        self.batch_size = 2
 
     def test_packing_stream_dataset(self):
         # pylint: disable=duplicate-code
@@ -31,30 +29,43 @@ class TestPretrainingPacking(unittest.TestCase):
             streaming=True,
         )["train"]
 
-        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
-            self.tokenizer,
-            return_tensors="pt",
-            padding=True,
-            pad_to_multiple_of=self.max_seq_length,
+        cfg = DictDefault(
+            {
+                "pretraining_dataset": [
+                    {
+                        "path": "c4",
+                        "name": "en",
+                        "type": "pretrain",
+                    }
+                ],
+                "sample_packing": True,
+                "pad_to_sequence_len": True,
+                "sequence_len": 2048,
+                "micro_batch_size": 2,
+            }
         )
 
-        encode = partial(
-            encode_packed_pretraining,
+        ds_wrapper_partial = functools.partial(
+            get_dataset_wrapper,
+            cfg.pretraining_dataset[0],
             self.tokenizer,
-            collate_fn,
-            max_seq_length=self.max_seq_length,
-            batch_size=self.batch_size,
+            cfg,
+            cfg.pretraining_dataset[0]["type"] or "pretrain",
         )
 
-        dataset = dataset.map(
-            encode,
-            batched=True,
-            input_columns="text",
-            remove_columns=dataset.features.keys(),
+        original_bsz = cfg.micro_batch_size
+        train_dataset = wrap_pretraining_dataset(
+            dataset,
+            self.tokenizer,
+            cfg,
+            ds_wrapper_partial,
+            max_tokens=cfg.sequence_len,
+            batch_size=cfg.micro_batch_size,
+            seed=cfg.seed or 42,
        )
 
         trainer_loader = DataLoader(
-            dataset,
+            train_dataset,
             batch_size=1,
             collate_fn=None,
             drop_last=True,
@@ -64,16 +75,16 @@ class TestPretrainingPacking(unittest.TestCase):
             if idx > 10:
                 break
             assert data["input_ids"].shape == torch.Size(
-                [1, self.batch_size * self.max_seq_length]
+                [1, original_bsz * cfg.sequence_len]
             )
             assert data["position_ids"].shape == torch.Size(
-                [1, self.batch_size * self.max_seq_length]
+                [1, original_bsz * cfg.sequence_len]
             )
             assert data["labels"].shape == torch.Size(
-                [1, self.batch_size * self.max_seq_length]
+                [1, original_bsz * cfg.sequence_len]
             )
             assert data["attention_mask"].shape == torch.Size(
-                [1, self.batch_size * self.max_seq_length]
+                [1, original_bsz * cfg.sequence_len]
             )
             idx += 1
 