be able to use adam bnb 8bit and one cycle scheduler w fsdp
- src/axolotl/utils/data.py +3 -3
- src/axolotl/utils/trainer.py +25 -2
src/axolotl/utils/data.py CHANGED
@@ -7,7 +7,7 @@ from datasets import (
     load_dataset,
     IterableDataset,
     Dataset,
-    concatenate_datasets,
+    concatenate_datasets, DatasetDict,
 )
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
@@ -37,7 +37,7 @@ from axolotl.prompters import (
 )
 
 
-def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
+def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path) -> DatasetDict:
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
         md5(
@@ -196,7 +196,7 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
     return dataset
 
 
-def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path):
+def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path) -> (Dataset, Dataset):
     max_packed_sequence_len = (
         cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
     )
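A side note on the new annotations (not part of the diff): `-> (Dataset, Dataset)` evaluates to a plain tuple of classes at runtime rather than a typing construct, so static type checkers will not read it as "a pair of Dataset objects". A minimal sketch of the equivalent annotation with `typing.Tuple`, using a hypothetical stub of the same signature:

from typing import Tuple

from datasets import Dataset
from transformers import PreTrainedTokenizerBase


# hypothetical stub, shown only to illustrate the annotation style;
# the body is elided and is not part of the commit
def load_prepare_datasets(
    tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path
) -> Tuple[Dataset, Dataset]:
    # would return the (train_dataset, eval_dataset) pair
    ...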
src/axolotl/utils/trainer.py CHANGED
@@ -9,13 +9,31 @@ import torch.cuda
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from transformers import EarlyStoppingCallback
+from transformers import EarlyStoppingCallback, Trainer
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.schedulers import InterpolatingLogScheduler
 from axolotl.utils.callbacks import SavePeftModelCallback
 
 
+class OneCycleLRSchedulerTrainer(Trainer):
+    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
+        optimizer = self.optimizer if optimizer is None else optimizer
+        num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
+        num_training_steps = num_training_steps
+        pct_start = num_warmup_steps / num_training_steps
+
+        lr_scheduler = OneCycleLR(
+            optimizer,
+            max_lr=self.args.learning_rate,
+            total_steps=num_training_steps,
+            pct_start=pct_start,
+            div_factor=6,
+        )
+
+        return lr_scheduler
+
+
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     total_num_steps = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
@@ -63,6 +81,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         training_arguments_kwargs["fsdp"] = cfg.fsdp
         if cfg.fsdp_config:
             training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
+        # can't set optimizers directly on trainer when using fsdp, so set them here
+        if cfg.optimizer:
+            training_arguments_kwargs["optim"] = cfg.optimizer
 
     # deepspeed
     if (
@@ -119,6 +140,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         cfg.optimizer == "adamw_bnb_8bit"
         and not cfg.load_4bit
         and not "deepspeed" in training_arguments_kwargs
+        and not cfg.fsdp
     ):
         decay_parameters = get_parameter_names(model, [nn.LayerNorm])
         decay_parameters = [name for name in decay_parameters if "bias" not in name]
@@ -194,7 +216,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     else:
         data_collator_kwargs["pad_to_multiple_of"] = 8
 
-    trainer = transformers.Trainer(
+    trainer_cls = OneCycleLRSchedulerTrainer if cfg.lr_scheduler == "one_cycle" and cfg.fsdp else transformers.Trainer
+    trainer = trainer_cls(
         model=model,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
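For context, a minimal standalone sketch (not part of the commit) of what the OneCycleLRSchedulerTrainer override computes: the warmup step count is converted into OneCycleLR's pct_start fraction, and div_factor=6 makes the starting learning rate max_lr / 6. The model, optimizer, and step counts below are illustrative placeholders.

import torch
from torch.optim.lr_scheduler import OneCycleLR

# placeholder model and optimizer, purely for illustration
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

num_training_steps = 1000
num_warmup_steps = 100  # stands in for args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps  # fraction of the run spent ramping up

lr_scheduler = OneCycleLR(
    optimizer,
    max_lr=2e-5,  # mirrors args.learning_rate in the subclass
    total_steps=num_training_steps,
    pct_start=pct_start,
    div_factor=6,  # initial lr = max_lr / div_factor
)

for _ in range(num_training_steps):
    optimizer.step()
    lr_scheduler.step()

In the commit itself this subclass is only selected when cfg.lr_scheduler == "one_cycle" and cfg.fsdp is set: as the added comment notes, optimizers can't be set directly on the trainer when using FSDP, so the optimizer comes from the "optim" training argument and the one-cycle schedule is produced inside create_scheduler instead of being passed in.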