winglian committed
Commit 32580c1
Parent: 802f966

Vram fix attempt (#1164) [skip ci]

* revert order of filter/drop_long step and handle calc for max_input_len only during preprocessing

* revert some changes to preparing for packing to allow more flexibility

* prepare dataset for packing during pre-processing step

* prepare dataset hash based on sample packing too

* enclose none check

* just cast straight to string for ds hash
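The last two bullets are about how optional config values reach the dataset hash: casting each value straight to `str` means unset flags serialize as `"None"` instead of raising during string concatenation. A quick illustration (the `cfg` namespace below is just a stand-in, not axolotl's config object):

```python
from types import SimpleNamespace

cfg = SimpleNamespace(sequence_len=4096, sample_packing=True, eval_sample_packing=None)

# str() renders booleans and None uniformly, so the hash input is always a valid string
key = str(cfg.sequence_len) + "@" + str(cfg.sample_packing) + "@" + str(cfg.eval_sample_packing)
print(key)  # 4096@True@None

# naive concatenation would fail for unset values:
# "4096" + "@" + cfg.eval_sample_packing  ->  TypeError
```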

src/axolotl/utils/data.py CHANGED

```diff
@@ -116,6 +116,12 @@ def load_tokenized_prepared_datasets(
         (
             str(cfg.sequence_len)
             + "@"
+            + str(cfg.sample_packing)
+            + "@"
+            + str(cfg.eval_sample_packing)
+            + "@"
+            + str(cfg.group_by_length)
+            + "@"
             + "|".join(
                 sorted(
                     [
@@ -162,7 +168,7 @@ def load_tokenized_prepared_datasets(
     LOG.info("Loading raw datasets...")
     if not cfg.is_preprocess:
         LOG.warning(
-            "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset"
+            "Processing datasets during training can lead to VRAM instability. Please pre-process your dataset."
         )
 
     if cfg.seed:
```
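For context, this string is the key that identifies the prepared-dataset cache, so folding `sample_packing`, `eval_sample_packing`, and `group_by_length` into it forces a fresh pre-processing pass whenever any of those settings change, instead of silently reusing a dataset prepared under different packing settings. A minimal sketch of the idea (the `prepared_ds_path` helper, the md5 step, and `cfg.dataset_prepared_path` are illustrative assumptions, not axolotl's exact code):

```python
from hashlib import md5
from pathlib import Path


def prepared_ds_path(cfg, dataset_names):
    # Hypothetical sketch: any config value folded into this string changes
    # the hash, which changes the cache directory, which forces the dataset
    # to be re-tokenized / re-packed on the next run.
    ds_hash = md5(
        (
            str(cfg.sequence_len)
            + "@"
            + str(cfg.sample_packing)
            + "@"
            + str(cfg.eval_sample_packing)
            + "@"
            + str(cfg.group_by_length)
            + "@"
            + "|".join(sorted(dataset_names))
        ).encode("utf-8")
    ).hexdigest()
    return Path(cfg.dataset_prepared_path) / ds_hash
```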
src/axolotl/utils/samplers/utils.py CHANGED

```diff
@@ -7,11 +7,11 @@ import numpy as np
 def get_dataset_lengths(dataset):
     if "length" in dataset.data.column_names:
         lengths = np.array(dataset.data.column("length"))
+    elif "position_ids" in dataset.data.column_names:
+        position_ids = dataset.data.column("position_ids")
+        lengths = np.array([x[-1] + 1 for x in position_ids])
     else:
-        lengths = (
-            dataset.data.column("position_ids")
-            .to_pandas()
-            .apply(lambda x: x[-1] + 1)
-            .values
-        )
+        input_ids = dataset.data.column("input_ids")
+        lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
+        return lengths
     return lengths
```
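Why this matters: samples prepared for packing carry a `position_ids` column of the form `[0, 1, ..., n - 1]`, so `x[-1] + 1` recovers each sample's length directly instead of going through the old pandas round trip, and datasets without that column now fall back to the length of `input_ids`. A rough, self-contained check of the new branching (the `FakeData`/`FakeDataset` stand-ins below are hypothetical, only mimicking the `dataset.data` interface the function touches):

```python
import numpy as np


class FakeData:
    """Hypothetical stand-in for dataset.data (a pyarrow table)."""

    def __init__(self, columns):
        self._columns = columns
        self.column_names = list(columns)

    def column(self, name):
        return self._columns[name]


class FakeDataset:
    def __init__(self, columns):
        self.data = FakeData(columns)


def get_dataset_lengths(dataset):
    # same branching as the patched helper above
    if "length" in dataset.data.column_names:
        lengths = np.array(dataset.data.column("length"))
    elif "position_ids" in dataset.data.column_names:
        position_ids = dataset.data.column("position_ids")
        lengths = np.array([x[-1] + 1 for x in position_ids])
    else:
        input_ids = dataset.data.column("input_ids")
        lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
    return lengths


# prepared-for-packing sample: position_ids end at 11 -> length 12
packed = FakeDataset({"position_ids": [list(range(12))]})
print(get_dataset_lengths(packed))  # [12]

# plain tokenized samples: fall back to len(input_ids)
plain = FakeDataset({"input_ids": [[1, 2, 3], [4, 5, 6, 7]]})
print(get_dataset_lengths(plain))  # [3 4]
```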
src/axolotl/utils/trainer.py CHANGED

```diff
@@ -109,6 +109,33 @@ def drop_long_seq(sample, sequence_len=2048):
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
     drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
     with zero_first(is_main_process()):
+        if cfg.is_preprocess:
+            max_input_len = np.max(get_dataset_lengths(train_dataset))
+            LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
+
+        # Phi doesn't want the attention_mask feature when training
+        if (
+            "CodeGenTokenizer" in tokenizer.__class__.__name__
+            or (cfg.is_mistral_derived_model and cfg.flash_attention)
+            or cfg.model_config_type == "mamba"
+        ):
+            LOG.info("dropping attention_mask column")
+            train_dataset = train_dataset.remove_columns("attention_mask")
+            if eval_dataset:
+                eval_dataset = eval_dataset.remove_columns("attention_mask")
+
+        train_dataset = train_dataset.filter(
+            drop_long,
+            num_proc=cfg.dataset_processes,
+            load_from_cache_file=not cfg.is_preprocess,
+        )
+        if eval_dataset:
+            eval_dataset = eval_dataset.filter(
+                drop_long,
+                num_proc=cfg.dataset_processes,
+                load_from_cache_file=not cfg.is_preprocess,
+            )
+
         if cfg.group_by_length:
             train_dataset = train_dataset.map(
                 add_length,
@@ -130,33 +157,6 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
                 load_from_cache_file=not cfg.is_preprocess,
             )
 
-        if cfg.group_by_length or cfg.sample_packing:
-            max_input_len = np.max(get_dataset_lengths(train_dataset))
-            LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
-
-        train_dataset = train_dataset.filter(
-            drop_long,
-            num_proc=cfg.dataset_processes,
-            load_from_cache_file=not cfg.is_preprocess,
-        )
-        if eval_dataset:
-            eval_dataset = eval_dataset.filter(
-                drop_long,
-                num_proc=cfg.dataset_processes,
-                load_from_cache_file=not cfg.is_preprocess,
-            )
-
-        # Phi doesn't want the attention_mask feature when training
-        if (
-            "CodeGenTokenizer" in tokenizer.__class__.__name__
-            or (cfg.is_mistral_derived_model and cfg.flash_attention)
-            or cfg.model_config_type == "mamba"
-        ):
-            LOG.info("dropping attention_mask column")
-            train_dataset = train_dataset.remove_columns("attention_mask")
-            if eval_dataset:
-                eval_dataset = eval_dataset.remove_columns("attention_mask")
-
     return train_dataset, eval_dataset
```
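Taken together, the trainer change moves the attention_mask drop and the `drop_long` filter ahead of the `group_by_length`/packing preparation, and restricts the `max_input_len` scan to the pre-processing step (`cfg.is_preprocess`) so it stays off the training path. A rough sketch of the reordered flow against a plain `datasets.Dataset` (the `SimpleNamespace` config and the body of `drop_long_seq` below are illustrative assumptions of the general shape, not axolotl's exact code):

```python
from functools import partial
from types import SimpleNamespace

import numpy as np
from datasets import Dataset


def drop_long_seq(sample, sequence_len=2048):
    # assumed predicate body: keep samples that fit within the context window
    return len(sample["input_ids"]) <= sequence_len


cfg = SimpleNamespace(sequence_len=2048, dataset_processes=1, is_preprocess=True)
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)

train_dataset = Dataset.from_dict(
    {"input_ids": [list(range(10)), list(range(5000))]}
)

# 1. only during pre-processing: report the longest raw sample
if cfg.is_preprocess:
    max_input_len = np.max([len(x) for x in train_dataset["input_ids"]])
    print(f"max_input_len: {max_input_len}")  # 5000

# 2. drop over-long samples *before* any grouping/packing prep
train_dataset = train_dataset.filter(
    drop_long,
    num_proc=cfg.dataset_processes,
    load_from_cache_file=not cfg.is_preprocess,
)
print(len(train_dataset))  # 1 -- the 5000-token sample was dropped
```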