winglian committed
Commit 5aa5097
1 Parent(s): cae608f

Pretrain multipack v2 (#1470)

requirements.txt CHANGED
@@ -40,3 +40,4 @@ gcsfs
 # adlfs

 trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
+zstandard==0.22.0
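
The new zstandard pin is presumably there so that zstd-compressed streaming pretraining shards (e.g. *.jsonl.zst) can be read; the call site is not part of this diff. A minimal sketch of that kind of usage, with a hypothetical shard name and "text" field:

import io
import json

import zstandard


def iter_jsonl_zst(path):
    # Yield JSON records from a zstd-compressed JSON-lines file.
    dctx = zstandard.ZstdDecompressor()
    with open(path, "rb") as raw, dctx.stream_reader(raw) as reader:
        for line in io.TextIOWrapper(reader, encoding="utf-8"):
            if line.strip():
                yield json.loads(line)


# hypothetical shard name, for illustration only
for record in iter_jsonl_zst("shard-00000.jsonl.zst"):
    print(record.get("text", "")[:80])
    break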
src/axolotl/utils/collators.py CHANGED
@@ -217,13 +217,24 @@ class PretrainingBatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     Collator for multipack specific to the using the BatchSampler
     """

+    def __init__(self, *args, multipack_attn=True, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.multipack_attn = multipack_attn
+
     def __call__(self, features, return_tensors=None):
         chunked_data = {}
         for feature in features.keys():
             if feature == "length":
                 continue
             if feature == "attention_mask":
-                arrays = [(1) * np.array(item) for item in features[feature]]
+                if self.multipack_attn:
+                    arrays = [
+                        (i + 1) * np.array(item[feature])
+                        for i, item in enumerate(features[feature])
+                        if feature in item
+                    ]
+                else:
+                    arrays = [(1) * np.array(item) for item in features[feature]]
                 chunked_data[feature] = np.concatenate(arrays)
             else:
                 arrays = [np.array(item) for item in features[feature]]
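
For orientation (not part of the commit): multiplying each packed sample's attention mask by (i + 1) turns the concatenated mask into per-sample segment ids, so downstream packed-attention code can tell where one sample ends and the next begins and keep them from attending to each other; the (1) * ... fallback keeps a flat run of ones, losing that boundary information. A standalone numpy sketch of the difference:

import numpy as np

# three samples packed into one row (toy masks, illustration only)
packed_masks = [[1, 1, 1], [1, 1], [1, 1, 1, 1]]

segment_ids = np.concatenate(
    [(i + 1) * np.array(mask) for i, mask in enumerate(packed_masks)]
)
flat_ones = np.concatenate([(1) * np.array(mask) for mask in packed_masks])

print(segment_ids)  # [1 1 1 2 2 3 3 3 3] -> sample boundaries are recoverable
print(flat_ones)    # [1 1 1 1 1 1 1 1 1] -> samples are indistinguishable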
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -511,6 +511,14 @@ class AxolotlInputConfig(
     eval_sample_packing: Optional[bool] = None
     pad_to_sequence_len: Optional[bool] = None

+    pretrain_multipack_buffer_size: Optional[int] = 10_000
+    pretrain_multipack_attn: Optional[bool] = Field(
+        default=True,
+        metadata={
+            "help": "whether to prevent cross attention for packed sequences during pretraining",
+        },
+    )
+
     xformers_attention: Optional[bool] = None
     sdp_attention: Optional[bool] = None
     s2_attention: Optional[bool] = None
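
For illustration only (not the real AxolotlInputConfig, which carries many more fields): a trimmed-down pydantic model showing the defaults the two new options take; description= stands in here for the metadata={"help": ...} extra used in the diff.

from typing import Optional

from pydantic import BaseModel, Field


class PretrainMultipackOptions(BaseModel):  # hypothetical stand-in class
    pretrain_multipack_buffer_size: Optional[int] = 10_000
    pretrain_multipack_attn: Optional[bool] = Field(
        default=True,
        description="whether to prevent cross attention for packed sequences during pretraining",
    )


opts = PretrainMultipackOptions()
print(opts.pretrain_multipack_buffer_size)  # 10000
print(opts.pretrain_multipack_attn)         # True

# disabling the isolation behaviour, e.g. from a parsed YAML config
opts = PretrainMultipackOptions(pretrain_multipack_attn=False)
print(opts.pretrain_multipack_attn)         # False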
src/axolotl/utils/data.py CHANGED
@@ -108,6 +108,7 @@ def prepare_dataset(cfg, tokenizer):
             max_tokens=cfg.sequence_len,
             batch_size=cfg.micro_batch_size,
             seed=cfg.seed or 42,
+            buffer_size=cfg.pretrain_multipack_buffer_size or 10_000,
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
@@ -816,6 +817,7 @@ def wrap_pretraining_dataset(
         return_tensors="pt",
         padding=True,
         pad_to_multiple_of=max_tokens * batch_size,
+        multipack_attn=cfg.pretrain_multipack_attn,
     )
     encode = functools.partial(
         encode_packed_pretraining,
@@ -823,6 +825,7 @@ def wrap_pretraining_dataset(
         ds_wrapper_fn,
         max_seq_length=max_tokens,
         batch_size=batch_size,
+        multipack_attn=cfg.pretrain_multipack_attn,
     )
     # set this to 1 so downstream data_loader doesn't try to increase the batch again
     cfg.micro_batch_size = 1
@@ -861,6 +864,7 @@ def encode_packed_pretraining(
     examples: Dict[str, List],
     max_seq_length: int = 2048,
     batch_size: int = 4,
+    multipack_attn: Optional[bool] = False,
 ) -> Dict[str, List]:
     # pylint: disable=duplicate-code
     # tokenize all the examples
@@ -868,7 +872,9 @@ def encode_packed_pretraining(
     train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]

     train_dataset = process_pretraining_datasets_for_packing(
-        train_dataset, max_seq_length
+        train_dataset,
+        max_seq_length,
+        skip_position_ids=not multipack_attn,
     )

     sampler = MultipackBatchSampler(
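
A standalone sketch (toy names, not the real encode_packed_pretraining) of the functools.partial pattern these hunks extend: config-derived keyword arguments such as multipack_attn are pre-bound, so the callable handed to the streaming dataset's map() only needs the batch of examples.

import functools


def encode_batch(examples, max_seq_length=2048, batch_size=4, multipack_attn=False):
    # toy stand-in: just report what it was called with
    return {
        "num_examples": [len(examples["text"])],
        "multipack_attn": [multipack_attn],
    }


# pre-bind config-derived kwargs, mirroring the diff's functools.partial call
encode = functools.partial(
    encode_batch,
    max_seq_length=4096,
    batch_size=8,
    multipack_attn=True,  # would come from cfg.pretrain_multipack_attn
)

print(encode({"text": ["a", "b", "c"]}))
# {'num_examples': [3], 'multipack_attn': [True]}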
src/axolotl/utils/trainer.py CHANGED
@@ -172,17 +172,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     return train_dataset, eval_dataset


-def process_pretraining_datasets_for_packing(train_dataset, sequence_len):
+def process_pretraining_datasets_for_packing(
+    train_dataset, sequence_len, skip_position_ids=True
+):
     drop_long = partial(drop_long_seq, sequence_len=sequence_len)

     train_dataset = train_dataset.filter(
         drop_long,
         desc="Dropping Long Sequences",
     )
-    train_dataset = train_dataset.map(
-        add_position_ids,
-        desc="Add position_id column (Pretraining Sample Packing)",
-    )
+    if skip_position_ids:
+        train_dataset = train_dataset.map(
+            add_position_ids,
+            desc="Add position_id column (Pretraining Sample Packing)",
+        )
+
     return train_dataset


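
For context, a guess at the helper's behaviour rather than axolotl's actual add_position_ids: the map presumably attaches a position_ids column that restarts at 0 for every example, so positions reset at each sample boundary once examples are packed. A toy datasets sketch:

from datasets import Dataset


def add_position_ids_toy(sample):
    # hypothetical stand-in: positions 0..len-1 for each example
    sample["position_ids"] = list(range(len(sample["input_ids"])))
    return sample


ds = Dataset.from_dict({"input_ids": [[5, 6, 7], [8, 9]]})
ds = ds.map(add_position_ids_toy, desc="Add position_id column (toy example)")
print(ds["position_ids"])  # [[0, 1, 2], [0, 1]]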