multipack w batch sampler (#795)
Browse files* test batch sampler w varying batch lens
* wip
* multipack batchsampler wip
* wip
* fix for prepare data loader to get correct # of steps based on gpues
* lint and clean up
* calculate len estimate
* fix total num steps calc
* add options for dataloader_num_workers and dataloader_pin_memory
* remove gitbook
* support prefetch_factor for dataloader optimization
* fix the kwarg
- gitbook/README.md +0 -1
- gitbook/SUMMARY.md +0 -4
- gitbook/small-dev-details.md +0 -3
- src/axolotl/core/trainer_builder.py +105 -47
- src/axolotl/utils/collators.py +27 -0
- src/axolotl/utils/data.py +2 -2
- src/axolotl/utils/samplers/__init__.py +4 -0
- src/axolotl/utils/samplers/multipack.py +193 -0
- src/axolotl/utils/trainer.py +34 -37
gitbook/README.md
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
# Page
|
|
|
|
gitbook/SUMMARY.md
DELETED
@@ -1,4 +0,0 @@
|
|
1 |
-
# Table of contents
|
2 |
-
|
3 |
-
* [Page](README.md)
|
4 |
-
* [Small dev details](small-dev-details.md)
|
|
|
|
|
|
|
|
|
|
gitbook/small-dev-details.md
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
# Small dev details
|
2 |
-
|
3 |
-
/
|
|
|
|
|
|
|
|
src/axolotl/core/trainer_builder.py
CHANGED
@@ -6,7 +6,6 @@ import abc
|
|
6 |
import importlib
|
7 |
import logging
|
8 |
import math
|
9 |
-
import os
|
10 |
import sys
|
11 |
from abc import abstractmethod
|
12 |
from dataclasses import dataclass, field
|
@@ -18,9 +17,9 @@ import torch
|
|
18 |
import transformers
|
19 |
from datasets import Dataset
|
20 |
from torch.optim.lr_scheduler import OneCycleLR
|
21 |
-
from torch.utils.data import DataLoader,
|
22 |
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
23 |
-
from transformers.
|
24 |
|
25 |
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
26 |
from axolotl.utils.callbacks import (
|
@@ -31,8 +30,9 @@ from axolotl.utils.callbacks import (
|
|
31 |
bench_eval_callback_factory,
|
32 |
log_prediction_callback_factory,
|
33 |
)
|
34 |
-
from axolotl.utils.collators import
|
35 |
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
|
|
36 |
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
37 |
|
38 |
try:
|
@@ -102,6 +102,10 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|
102 |
bench_source_max_len: int = field(
|
103 |
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
104 |
)
|
|
|
|
|
|
|
|
|
105 |
|
106 |
|
107 |
class AxolotlTrainer(Trainer):
|
@@ -145,46 +149,69 @@ class AxolotlTrainer(Trainer):
|
|
145 |
return self.lr_scheduler
|
146 |
|
147 |
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
148 |
-
if self.args.
|
149 |
-
return
|
150 |
-
self.train_dataset,
|
151 |
-
|
152 |
-
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
)
|
155 |
return super()._get_train_sampler()
|
156 |
|
157 |
def _get_eval_sampler(
|
158 |
self, eval_dataset: Dataset
|
159 |
) -> Optional[torch.utils.data.Sampler]:
|
160 |
-
if
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
|
|
|
|
|
|
170 |
)
|
171 |
return super()._get_eval_sampler(eval_dataset)
|
172 |
|
173 |
-
def get_train_dataloader(self) ->
|
174 |
if self.args.sample_packing:
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
)
|
189 |
return super().get_train_dataloader()
|
190 |
|
@@ -197,18 +224,29 @@ class AxolotlTrainer(Trainer):
|
|
197 |
)
|
198 |
|
199 |
eval_sampler = self._get_eval_sampler(eval_dataset)
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
)
|
213 |
return super().get_eval_dataloader(eval_dataset)
|
214 |
|
@@ -229,6 +267,8 @@ class AxolotlTrainer(Trainer):
|
|
229 |
"num_workers": self.args.dataloader_num_workers,
|
230 |
"pin_memory": self.args.dataloader_pin_memory,
|
231 |
}
|
|
|
|
|
232 |
|
233 |
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
|
234 |
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
@@ -493,6 +533,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
493 |
"sample_packing_efficiency"
|
494 |
] = self.cfg.sample_packing_eff_est
|
495 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
496 |
if self.cfg.eval_steps:
|
497 |
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
498 |
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
@@ -672,7 +725,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
672 |
train_dataset=self.train_dataset,
|
673 |
eval_dataset=self.eval_dataset,
|
674 |
args=training_args,
|
675 |
-
data_collator=
|
676 |
self.tokenizer,
|
677 |
return_tensors="pt",
|
678 |
**data_collator_kwargs,
|
@@ -690,4 +743,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
690 |
for callback in self.get_post_trainer_create_callbacks(trainer):
|
691 |
trainer.add_callback(callback)
|
692 |
|
|
|
|
|
|
|
|
|
|
|
693 |
return trainer
|
|
|
6 |
import importlib
|
7 |
import logging
|
8 |
import math
|
|
|
9 |
import sys
|
10 |
from abc import abstractmethod
|
11 |
from dataclasses import dataclass, field
|
|
|
17 |
import transformers
|
18 |
from datasets import Dataset
|
19 |
from torch.optim.lr_scheduler import OneCycleLR
|
20 |
+
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
21 |
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
22 |
+
from transformers.trainer_utils import seed_worker
|
23 |
|
24 |
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
25 |
from axolotl.utils.callbacks import (
|
|
|
30 |
bench_eval_callback_factory,
|
31 |
log_prediction_callback_factory,
|
32 |
)
|
33 |
+
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
|
34 |
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
35 |
+
from axolotl.utils.samplers import MultipackBatchSampler
|
36 |
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
37 |
|
38 |
try:
|
|
|
102 |
bench_source_max_len: int = field(
|
103 |
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
104 |
)
|
105 |
+
dataloader_prefetch_factor: Optional[int] = field(
|
106 |
+
default=None,
|
107 |
+
metadata={"help": "prefetch_factor argument to the dataloader"},
|
108 |
+
)
|
109 |
|
110 |
|
111 |
class AxolotlTrainer(Trainer):
|
|
|
149 |
return self.lr_scheduler
|
150 |
|
151 |
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
152 |
+
if self.args.sample_packing:
|
153 |
+
return MultipackBatchSampler(
|
154 |
+
RandomSampler(self.train_dataset),
|
155 |
+
self.args.train_batch_size,
|
156 |
+
drop_last=True,
|
157 |
+
batch_max_len=self._train_batch_size * self.args.max_seq_length,
|
158 |
+
lengths=(
|
159 |
+
self.train_dataset.data.column("position_ids")
|
160 |
+
.to_pandas()
|
161 |
+
.apply(lambda x: x[-1] + 1)
|
162 |
+
.values
|
163 |
+
),
|
164 |
+
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
165 |
)
|
166 |
return super()._get_train_sampler()
|
167 |
|
168 |
def _get_eval_sampler(
|
169 |
self, eval_dataset: Dataset
|
170 |
) -> Optional[torch.utils.data.Sampler]:
|
171 |
+
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
172 |
+
return MultipackBatchSampler(
|
173 |
+
SequentialSampler(eval_dataset),
|
174 |
+
self.args.per_device_eval_batch_size,
|
175 |
+
drop_last=True,
|
176 |
+
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
|
177 |
+
lengths=(
|
178 |
+
eval_dataset.data.column("position_ids")
|
179 |
+
.to_pandas()
|
180 |
+
.apply(lambda x: x[-1] + 1)
|
181 |
+
.values
|
182 |
+
),
|
183 |
+
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
184 |
)
|
185 |
return super()._get_eval_sampler(eval_dataset)
|
186 |
|
187 |
+
def get_train_dataloader(self) -> DataLoader:
|
188 |
if self.args.sample_packing:
|
189 |
+
train_dataset = self.train_dataset
|
190 |
+
train_dataset = train_dataset.remove_columns(["length"])
|
191 |
+
data_collator = self.data_collator
|
192 |
+
dataloader_params = {
|
193 |
+
"batch_size": self._train_batch_size,
|
194 |
+
"collate_fn": data_collator,
|
195 |
+
"num_workers": self.args.dataloader_num_workers,
|
196 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
197 |
+
}
|
198 |
+
if self.args.dataloader_prefetch_factor:
|
199 |
+
dataloader_params[
|
200 |
+
"prefetch_factor"
|
201 |
+
] = self.args.dataloader_prefetch_factor
|
202 |
+
|
203 |
+
sampler = self._get_train_sampler()
|
204 |
+
if isinstance(sampler, BatchSampler):
|
205 |
+
dataloader_params["batch_sampler"] = sampler
|
206 |
+
del dataloader_params["batch_size"]
|
207 |
+
else:
|
208 |
+
dataloader_params["sampler"] = sampler
|
209 |
+
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
210 |
+
dataloader_params["worker_init_fn"] = seed_worker
|
211 |
+
|
212 |
+
self.accelerator.even_batches = False
|
213 |
+
return self.accelerator.prepare_data_loader(
|
214 |
+
DataLoader(train_dataset, **dataloader_params)
|
215 |
)
|
216 |
return super().get_train_dataloader()
|
217 |
|
|
|
224 |
)
|
225 |
|
226 |
eval_sampler = self._get_eval_sampler(eval_dataset)
|
227 |
+
eval_dataset = eval_dataset.remove_columns(["length"])
|
228 |
+
data_collator = self.data_collator
|
229 |
+
dataloader_params = {
|
230 |
+
"batch_size": self.args.eval_batch_size,
|
231 |
+
"collate_fn": data_collator,
|
232 |
+
"num_workers": self.args.dataloader_num_workers,
|
233 |
+
"pin_memory": self.args.dataloader_pin_memory,
|
234 |
+
}
|
235 |
+
if self.args.dataloader_prefetch_factor:
|
236 |
+
dataloader_params[
|
237 |
+
"prefetch_factor"
|
238 |
+
] = self.args.dataloader_prefetch_factor
|
239 |
+
|
240 |
+
if isinstance(eval_sampler, BatchSampler):
|
241 |
+
dataloader_params["batch_sampler"] = eval_sampler
|
242 |
+
del dataloader_params["batch_size"]
|
243 |
+
else:
|
244 |
+
dataloader_params["sampler"] = eval_sampler
|
245 |
+
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
246 |
+
|
247 |
+
self.accelerator.even_batches = False
|
248 |
+
return self.accelerator.prepare_data_loader(
|
249 |
+
DataLoader(eval_dataset, **dataloader_params)
|
250 |
)
|
251 |
return super().get_eval_dataloader(eval_dataset)
|
252 |
|
|
|
267 |
"num_workers": self.args.dataloader_num_workers,
|
268 |
"pin_memory": self.args.dataloader_pin_memory,
|
269 |
}
|
270 |
+
if self.args.dataloader_prefetch_factor:
|
271 |
+
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
272 |
|
273 |
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
|
274 |
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
|
|
533 |
"sample_packing_efficiency"
|
534 |
] = self.cfg.sample_packing_eff_est
|
535 |
|
536 |
+
if self.cfg.dataloader_pin_memory is not None:
|
537 |
+
training_arguments_kwargs[
|
538 |
+
"dataloader_pin_memory"
|
539 |
+
] = self.cfg.dataloader_pin_memory
|
540 |
+
if self.cfg.dataloader_num_workers is not None:
|
541 |
+
training_arguments_kwargs[
|
542 |
+
"dataloader_num_workers"
|
543 |
+
] = self.cfg.dataloader_num_workers
|
544 |
+
if self.cfg.dataloader_prefetch_factor is not None:
|
545 |
+
training_arguments_kwargs[
|
546 |
+
"dataloader_prefetch_factor"
|
547 |
+
] = self.cfg.dataloader_prefetch_factor
|
548 |
+
|
549 |
if self.cfg.eval_steps:
|
550 |
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
551 |
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
|
|
725 |
train_dataset=self.train_dataset,
|
726 |
eval_dataset=self.eval_dataset,
|
727 |
args=training_args,
|
728 |
+
data_collator=BatchSamplerDataCollatorForSeq2Seq(
|
729 |
self.tokenizer,
|
730 |
return_tensors="pt",
|
731 |
**data_collator_kwargs,
|
|
|
743 |
for callback in self.get_post_trainer_create_callbacks(trainer):
|
744 |
trainer.add_callback(callback)
|
745 |
|
746 |
+
if self.cfg.deepspeed and self.cfg.sample_packing:
|
747 |
+
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
|
748 |
+
"train_micro_batch_size_per_gpu"
|
749 |
+
] = self.cfg.micro_batch_size
|
750 |
+
|
751 |
return trainer
|
src/axolotl/utils/collators.py
CHANGED
@@ -119,3 +119,30 @@ class DataCollatorForSeq2Seq:
|
|
119 |
features["decoder_input_ids"] = decoder_input_ids
|
120 |
|
121 |
return features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
features["decoder_input_ids"] = decoder_input_ids
|
120 |
|
121 |
return features
|
122 |
+
|
123 |
+
|
124 |
+
@dataclass
|
125 |
+
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
126 |
+
"""
|
127 |
+
Collator for multipack specific to the using the BatchSampler
|
128 |
+
"""
|
129 |
+
|
130 |
+
def __call__(self, features, return_tensors=None):
|
131 |
+
chunked_data = {}
|
132 |
+
for feature in features[0].keys():
|
133 |
+
if feature == "length":
|
134 |
+
continue
|
135 |
+
if feature == "attention_mask":
|
136 |
+
arrays = [
|
137 |
+
(1) * np.array(item[feature])
|
138 |
+
for item in features
|
139 |
+
if feature in item
|
140 |
+
]
|
141 |
+
chunked_data[feature] = np.concatenate(arrays)
|
142 |
+
else:
|
143 |
+
arrays = [
|
144 |
+
np.array(item[feature]) for item in features if feature in item
|
145 |
+
]
|
146 |
+
chunked_data[feature] = np.concatenate(arrays)
|
147 |
+
features = [chunked_data]
|
148 |
+
return super().__call__(features, return_tensors=return_tensors)
|
src/axolotl/utils/data.py
CHANGED
@@ -80,11 +80,11 @@ def prepare_dataset(cfg, tokenizer):
|
|
80 |
)
|
81 |
if cfg.max_steps:
|
82 |
total_num_steps = min(
|
83 |
-
calculate_total_num_steps(cfg, train_dataset
|
84 |
)
|
85 |
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
86 |
else:
|
87 |
-
total_num_steps = calculate_total_num_steps(cfg, train_dataset
|
88 |
return train_dataset, eval_dataset, total_num_steps, prompters
|
89 |
|
90 |
|
|
|
80 |
)
|
81 |
if cfg.max_steps:
|
82 |
total_num_steps = min(
|
83 |
+
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
84 |
)
|
85 |
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
86 |
else:
|
87 |
+
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
88 |
return train_dataset, eval_dataset, total_num_steps, prompters
|
89 |
|
90 |
|
src/axolotl/utils/samplers/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
axolotl samplers module
|
3 |
+
"""
|
4 |
+
from .multipack import MultipackBatchSampler # noqa: F401
|
src/axolotl/utils/samplers/multipack.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pylint: skip-file
|
2 |
+
"""
|
3 |
+
Multipack Batch Sampler
|
4 |
+
"""
|
5 |
+
import logging
|
6 |
+
import math
|
7 |
+
import os
|
8 |
+
from typing import Any, Iterable, List, Union
|
9 |
+
|
10 |
+
import numba
|
11 |
+
import numpy as np
|
12 |
+
from torch.utils.data import BatchSampler, Sampler
|
13 |
+
|
14 |
+
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
|
15 |
+
|
16 |
+
|
17 |
+
@numba.njit
|
18 |
+
def ffd_check(a: np.ndarray, c: int, n: int):
|
19 |
+
# First-fit-decreasing bin packing
|
20 |
+
# Check if a[] could fit in n bins with capacity c
|
21 |
+
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
22 |
+
|
23 |
+
a = np.sort(a)[::-1]
|
24 |
+
bins = np.full((n,), c, dtype=a.dtype)
|
25 |
+
for size in a:
|
26 |
+
not_found = True
|
27 |
+
for idx in range(n):
|
28 |
+
if bins[idx] >= size:
|
29 |
+
bins[idx] -= size
|
30 |
+
not_found = False
|
31 |
+
break
|
32 |
+
|
33 |
+
if not_found:
|
34 |
+
return False
|
35 |
+
|
36 |
+
return True
|
37 |
+
|
38 |
+
|
39 |
+
@numba.njit
|
40 |
+
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
41 |
+
# First-fit-decreasing bin packing (with result return)
|
42 |
+
|
43 |
+
indices = np.argsort(a)[::-1]
|
44 |
+
a = a[indices]
|
45 |
+
|
46 |
+
bins: List[Any] = []
|
47 |
+
bins_result: List[Any] = []
|
48 |
+
for a_id, size in enumerate(a):
|
49 |
+
add_new = True
|
50 |
+
for idx in range(len(bins)):
|
51 |
+
if bins[idx] >= size:
|
52 |
+
bins[idx] -= size
|
53 |
+
bins_result[idx].append(indices[a_id] + start_index)
|
54 |
+
add_new = False
|
55 |
+
break
|
56 |
+
|
57 |
+
if add_new:
|
58 |
+
bins.append(c - size)
|
59 |
+
bins_result.append([indices[a_id] + start_index])
|
60 |
+
|
61 |
+
return bins_result
|
62 |
+
|
63 |
+
|
64 |
+
@numba.njit
|
65 |
+
def allocate(
|
66 |
+
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
67 |
+
):
|
68 |
+
# Dynamic batch allocator, similar to Multifit
|
69 |
+
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
70 |
+
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
71 |
+
|
72 |
+
s = 0
|
73 |
+
start_index = 0
|
74 |
+
result = []
|
75 |
+
|
76 |
+
while True:
|
77 |
+
# binary search [l, r)
|
78 |
+
left = 1
|
79 |
+
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
80 |
+
|
81 |
+
while right - left > 1:
|
82 |
+
mid = (left + right) // 2
|
83 |
+
if ffd_check(lengths[start_index : start_index + mid], c, n):
|
84 |
+
left = mid
|
85 |
+
else:
|
86 |
+
right = mid
|
87 |
+
|
88 |
+
# use length l
|
89 |
+
batch = ffd_with_result(
|
90 |
+
lengths[start_index : start_index + left], c, start_index
|
91 |
+
)
|
92 |
+
assert len(batch) <= n
|
93 |
+
if len(batch) < n:
|
94 |
+
break
|
95 |
+
|
96 |
+
start_index += left
|
97 |
+
s = lengths_cumsum[start_index - 1]
|
98 |
+
|
99 |
+
# add local rank
|
100 |
+
result.append(batch[rank])
|
101 |
+
|
102 |
+
return result, s, len(result) * c * n
|
103 |
+
|
104 |
+
|
105 |
+
class MultipackBatchSampler(BatchSampler):
|
106 |
+
"""
|
107 |
+
Batch Sampler class for multipack
|
108 |
+
"""
|
109 |
+
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
sampler: Union[Sampler[int], Iterable[int]],
|
113 |
+
batch_size: int,
|
114 |
+
drop_last: bool,
|
115 |
+
batch_max_len: int,
|
116 |
+
lengths: np.ndarray,
|
117 |
+
packing_efficiency_estimate: float = 1.0,
|
118 |
+
):
|
119 |
+
super().__init__(sampler, batch_size, drop_last)
|
120 |
+
self.batch_size = None
|
121 |
+
self.batch_max_len = batch_max_len
|
122 |
+
self.lengths: np.ndarray = lengths
|
123 |
+
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
124 |
+
|
125 |
+
assert isinstance(self.lengths, np.ndarray)
|
126 |
+
|
127 |
+
self.epoch = 0
|
128 |
+
|
129 |
+
# statistics
|
130 |
+
self.eff_total_used = 0
|
131 |
+
self.eff_total_slots = 0
|
132 |
+
|
133 |
+
def set_epoch(self, epoch: int):
|
134 |
+
self.epoch = epoch
|
135 |
+
|
136 |
+
def generate_batches(self, set_stats=False):
|
137 |
+
indices = [idx for idx in self.sampler]
|
138 |
+
|
139 |
+
lengths = self.lengths[indices]
|
140 |
+
lengths_cumsum = np.cumsum(lengths)
|
141 |
+
|
142 |
+
batches, total_used, total_slots = allocate(
|
143 |
+
lengths=lengths,
|
144 |
+
lengths_cumsum=lengths_cumsum,
|
145 |
+
rank=0,
|
146 |
+
c=self.batch_max_len,
|
147 |
+
n=1,
|
148 |
+
)
|
149 |
+
|
150 |
+
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
151 |
+
|
152 |
+
# statistics
|
153 |
+
if set_stats:
|
154 |
+
self.eff_total_used += total_used
|
155 |
+
self.eff_total_slots += total_slots
|
156 |
+
|
157 |
+
return batches
|
158 |
+
|
159 |
+
def __iter__(self):
|
160 |
+
batches = self.generate_batches(set_stats=True)
|
161 |
+
return iter(batches)
|
162 |
+
|
163 |
+
def num_batches(self):
|
164 |
+
batches = self.generate_batches(set_stats=True)
|
165 |
+
return len(batches)
|
166 |
+
|
167 |
+
def efficiency(self):
|
168 |
+
return self.eff_total_used / self.eff_total_slots
|
169 |
+
|
170 |
+
def __len__(self):
|
171 |
+
self.num_batches()
|
172 |
+
return self._len_est()
|
173 |
+
|
174 |
+
def _len_est(self):
|
175 |
+
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
176 |
+
lengths_sum = np.sum(self.lengths)
|
177 |
+
lengths_sum_per_device = lengths_sum // world_size
|
178 |
+
LOG.info(
|
179 |
+
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
180 |
+
f"total_num_tokens per device: {lengths_sum_per_device}"
|
181 |
+
)
|
182 |
+
|
183 |
+
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
184 |
+
return (
|
185 |
+
world_size
|
186 |
+
* math.floor(
|
187 |
+
0.99
|
188 |
+
* lengths_sum_per_device
|
189 |
+
/ self.packing_efficiency_estimate
|
190 |
+
// self.batch_max_len
|
191 |
+
)
|
192 |
+
- 1
|
193 |
+
)
|
src/axolotl/utils/trainer.py
CHANGED
@@ -8,20 +8,13 @@ from typing import List
|
|
8 |
import numpy as np
|
9 |
import torch
|
10 |
import torch.cuda
|
11 |
-
import torch.distributed as dist
|
12 |
from accelerate.logging import get_logger
|
13 |
from datasets import set_caching_enabled
|
14 |
-
from torch.utils.data import
|
15 |
|
16 |
from axolotl.core.trainer_builder import HFCausalTrainerBuilder
|
17 |
-
from axolotl.utils.
|
18 |
-
from axolotl.utils.
|
19 |
-
from axolotl.utils.distributed import (
|
20 |
-
is_distributed,
|
21 |
-
is_main_process,
|
22 |
-
reduce_and_broadcast,
|
23 |
-
zero_first,
|
24 |
-
)
|
25 |
|
26 |
LOG = get_logger("axolotl")
|
27 |
|
@@ -148,7 +141,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|
148 |
return train_dataset, eval_dataset
|
149 |
|
150 |
|
151 |
-
def calculate_total_num_steps(cfg, train_dataset
|
152 |
if cfg.sample_packing:
|
153 |
# we have to drop anything longer then sequence len otherwise
|
154 |
# flash attention with position ids fails
|
@@ -196,37 +189,36 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
|
196 |
main_process_only=True,
|
197 |
)
|
198 |
else:
|
199 |
-
|
200 |
-
sampler
|
201 |
-
train_dataset,
|
202 |
-
num_replicas=cfg.world_size,
|
203 |
-
rank=dist.get_rank(),
|
204 |
-
seed=cfg.seed or 42,
|
205 |
-
)
|
206 |
-
else:
|
207 |
-
sampler = RandomSampler(train_dataset)
|
208 |
-
|
209 |
-
data_loader = MultipackDistributedDataloader(
|
210 |
-
train_dataset,
|
211 |
batch_size=cfg.micro_batch_size,
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
|
|
|
|
|
|
217 |
),
|
218 |
-
sampler=sampler,
|
219 |
-
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
220 |
-
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
221 |
-
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
222 |
-
num_epochs=cfg.num_epochs,
|
223 |
)
|
224 |
-
|
225 |
-
|
|
|
|
|
|
|
|
|
|
|
226 |
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
227 |
# FIXME: is there a bug here somewhere? the total num steps depends
|
228 |
# on the agreed on value for sample_packing_eff_est
|
229 |
-
total_num_steps = int(
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
|
231 |
def calc_sample_packing_eff_est(estimates: List[float]):
|
232 |
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
@@ -246,7 +238,12 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
|
246 |
)
|
247 |
else:
|
248 |
total_num_steps = int(
|
249 |
-
math.ceil(
|
|
|
|
|
|
|
|
|
|
|
250 |
)
|
251 |
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
252 |
return total_num_steps
|
|
|
8 |
import numpy as np
|
9 |
import torch
|
10 |
import torch.cuda
|
|
|
11 |
from accelerate.logging import get_logger
|
12 |
from datasets import set_caching_enabled
|
13 |
+
from torch.utils.data import DataLoader, RandomSampler
|
14 |
|
15 |
from axolotl.core.trainer_builder import HFCausalTrainerBuilder
|
16 |
+
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
|
17 |
+
from axolotl.utils.samplers import MultipackBatchSampler
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
LOG = get_logger("axolotl")
|
20 |
|
|
|
141 |
return train_dataset, eval_dataset
|
142 |
|
143 |
|
144 |
+
def calculate_total_num_steps(cfg, train_dataset):
|
145 |
if cfg.sample_packing:
|
146 |
# we have to drop anything longer then sequence len otherwise
|
147 |
# flash attention with position ids fails
|
|
|
189 |
main_process_only=True,
|
190 |
)
|
191 |
else:
|
192 |
+
sampler = MultipackBatchSampler(
|
193 |
+
sampler=RandomSampler(train_dataset),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
batch_size=cfg.micro_batch_size,
|
195 |
+
drop_last=True,
|
196 |
+
batch_max_len=cfg.micro_batch_size
|
197 |
+
* (cfg.max_packed_sequence_len or cfg.sequence_len),
|
198 |
+
lengths=(
|
199 |
+
train_dataset.data.column("position_ids")
|
200 |
+
.to_pandas()
|
201 |
+
.apply(lambda x: x[-1] + 1)
|
202 |
+
.values
|
203 |
),
|
|
|
|
|
|
|
|
|
|
|
204 |
)
|
205 |
+
|
206 |
+
data_loader = DataLoader(
|
207 |
+
train_dataset.remove_columns(["length"]),
|
208 |
+
batch_sampler=sampler,
|
209 |
+
)
|
210 |
+
data_loader_len = len(data_loader)
|
211 |
+
actual_eff = sampler.efficiency()
|
212 |
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
213 |
# FIXME: is there a bug here somewhere? the total num steps depends
|
214 |
# on the agreed on value for sample_packing_eff_est
|
215 |
+
total_num_steps = int(
|
216 |
+
math.floor(
|
217 |
+
data_loader_len
|
218 |
+
* cfg.num_epochs
|
219 |
+
/ int(os.environ.get("WORLD_SIZE", 1))
|
220 |
+
)
|
221 |
+
)
|
222 |
|
223 |
def calc_sample_packing_eff_est(estimates: List[float]):
|
224 |
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
|
|
238 |
)
|
239 |
else:
|
240 |
total_num_steps = int(
|
241 |
+
math.ceil(
|
242 |
+
len(train_dataset)
|
243 |
+
* cfg.num_epochs
|
244 |
+
/ int(os.environ.get("WORLD_SIZE", 1))
|
245 |
+
/ cfg.batch_size
|
246 |
+
)
|
247 |
)
|
248 |
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
249 |
return total_num_steps
|