monsoon-nlp committed
Commit
586bd8d
1 Parent(s): 86b7d22

fix pretraining_ on odd datasets (#1463)


* can configure name of split of pretraining dataset

* streaming data and dataset map

* text column customized

* allow text_column to be set in pretrain

* pretrain type

* load a bit of the dataset

* fix dataset where splits have separate configs

* ok name param here is the config

* whitespace

src/axolotl/prompt_strategies/pretrain.py CHANGED
@@ -20,10 +20,11 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
     def supports_batched(self):
         return True
 
-    def __init__(self, *args, max_length=None, **kwargs):
+    def __init__(self, *args, max_length=None, text_column="text", **kwargs):
         super().__init__(*args, **kwargs)
         if max_length:
             self.max_length = max_length
+        self.text_column = text_column
 
     def _tokenize(
         self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
@@ -44,7 +45,7 @@ class PretrainTokenizationStrategy(PromptTokenizingStrategy):
         return res
 
     def tokenize_prompt(self, prompt):
-        return self._tokenize(prompt["text"])
+        return self._tokenize(prompt[self.text_column])
 
 
 def load(tokenizer, cfg):
@@ -53,6 +54,7 @@ def load(tokenizer, cfg):
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
+        text_column=cfg.pretraining_dataset[0]["text_column"] or "text",
         max_length=cfg.sequence_len * 64,
     )
     return strat
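The or "text" fallback in load() keeps the old behaviour when a config entry leaves text_column unset. A tiny standalone sketch of that fallback (the entry dict is hypothetical; only the expression mirrors the diff):

# Hypothetical config entry with text_column left unset (None):
entry = {"path": "my-org/raw-corpus", "text_column": None}
text_column = entry["text_column"] or "text"
print(text_column)  # -> "text", so tokenize_prompt still reads prompt["text"]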
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -61,7 +61,11 @@ class RemappedParameters(BaseModel):
 class PretrainingDataset(BaseModel):
     """pretraining dataset configuration subset"""
 
+    name: Optional[str] = None
     path: Optional[str] = None
+    split: Optional[str] = "train"
+    text_column: Optional[str] = "text"
+    type: Optional[str] = "pretrain"
 
 
 class UserDefinedPrompterType(BaseModel):
@@ -448,7 +452,7 @@ class AxolotlInputConfig(
     dataset_shard_idx: Optional[int] = None
 
     pretraining_dataset: Optional[  # type: ignore
-        conlist(Union[SFTDataset, PretrainingDataset], min_length=1)
+        conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
     ] = Field(
         default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
    )
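The new fields all default to the values the code previously hard-coded, so existing configs keep validating; listing PretrainingDataset before SFTDataset in the Union presumably lets a plain dict entry be tried against the pretraining model first. A standalone copy of the extended model, to show the defaults without importing axolotl (the path and column names are placeholders):

from typing import Optional

from pydantic import BaseModel


class PretrainingDataset(BaseModel):
    """pretraining dataset configuration subset"""

    name: Optional[str] = None
    path: Optional[str] = None
    split: Optional[str] = "train"
    text_column: Optional[str] = "text"
    type: Optional[str] = "pretrain"


# Fields omitted from the config fall back to the defaults above.
entry = PretrainingDataset(path="my-org/raw-corpus", text_column="content")
print(entry.split, entry.type)  # -> train pretrain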
src/axolotl/utils/data.py CHANGED
@@ -82,12 +82,15 @@ def prepare_dataset(cfg, tokenizer):
             )
     else:
         path = cfg.pretraining_dataset
+        split = "train"
         name = None
         if isinstance(cfg.pretraining_dataset, list) and isinstance(
             cfg.pretraining_dataset[0], dict
         ):
             path = cfg.pretraining_dataset[0]["path"]
             name = cfg.pretraining_dataset[0]["name"]
+            if "split" in cfg.pretraining_dataset[0]:
+                split = cfg.pretraining_dataset[0]["split"]
 
         ds_wrapper_partial = functools.partial(
             get_dataset_wrapper,
@@ -98,7 +101,7 @@ def prepare_dataset(cfg, tokenizer):
         )
 
         train_dataset = wrap_pretraining_dataset(
-            load_dataset(path, streaming=True, split="train", name=name),
+            load_dataset(path, streaming=True, split=split, name=name),
             tokenizer,
             cfg,
             ds_wrapper_partial,
@@ -831,14 +834,23 @@ def wrap_pretraining_dataset(
     else:
         LOG.debug("NOT shuffling merged pretraining datasets")
 
+    # remove all the existing columns after mapping since they end up having
+    # a different length than the encoded/tokenized column
+    # this is empty during streaming/pretraining
+    remove_columns = []
+    if dataset.features is None:
+        for first_row in dataset:
+            remove_columns = first_row.keys()
+            break
+    else:
+        remove_columns = dataset.features.keys()
+
     dataset = dataset.map(
         encode,
         batched=True,
         batch_size=buffer_size,
         # input_columns="text",
-        # remove all the existing columns after mapping since they end up having
-        # a different length than the encoded/tokenized column
-        remove_columns=dataset.features.keys(),
+        remove_columns=remove_columns,
     )
     return dataset
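The data.py change is the core of the fix: a streamed IterableDataset can report dataset.features as None, so the columns to drop after tokenization are read from the first streamed row instead of from the schema. A standalone sketch of that probing step, assuming a streaming load of a placeholder dataset:

from datasets import load_dataset

# Placeholder dataset id and config; any streaming dataset behaves the same way.
dataset = load_dataset("my-org/raw-corpus", name="es", split="train", streaming=True)

remove_columns = []
if dataset.features is None:
    # Streaming datasets may expose no schema until a row is fetched,
    # so peek at the first row to learn the column names.
    for first_row in dataset:
        remove_columns = list(first_row.keys())
        break
else:
    remove_columns = list(dataset.features.keys())

# Tokenize in batches and drop the original columns, which would otherwise
# end up with a different length than the encoded columns:
# dataset = dataset.map(encode, batched=True, remove_columns=remove_columns)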