add support for opimum bettertransformers
Browse files- configs/gpt_neox_20b.yml +18 -12
- requirements.txt +1 -0
- scripts/finetune.py +11 -4
- src/axolotl/utils/models.py +6 -2
- src/axolotl/utils/validation.py +8 -0
configs/gpt_neox_20b.yml
CHANGED
@@ -1,24 +1,25 @@
|
|
1 |
base_model: EleutherAI/gpt-neox-20b
|
|
|
2 |
base_model_ignore_patterns: pytorch* # prefer safetensors
|
3 |
model_type: GPTNeoXForCausalLM
|
4 |
tokenizer_type: AutoTokenizer
|
5 |
-
load_in_8bit:
|
|
|
|
|
6 |
datasets:
|
7 |
-
- path:
|
8 |
type: alpaca
|
9 |
-
shards: 4
|
10 |
-
shards_index: 0
|
11 |
dataset_prepared_path: last_run_prepared
|
12 |
val_set_size: 0.05
|
13 |
-
adapter:
|
14 |
lora_model_dir:
|
15 |
sequence_len: 2048
|
16 |
max_packed_sequence_len: 2048
|
17 |
-
lora_r:
|
18 |
lora_alpha: 32
|
19 |
-
lora_dropout: 0.
|
20 |
lora_target_modules:
|
21 |
-
|
22 |
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
23 |
wandb_project: gpt4all-neox-20b
|
24 |
wandb_watch:
|
@@ -26,14 +27,19 @@ wandb_run_id:
|
|
26 |
wandb_log_model:
|
27 |
output_dir: ./gpt4all-neox-20b
|
28 |
gradient_accumulation_steps: 1
|
29 |
-
micro_batch_size:
|
30 |
num_epochs: 5
|
31 |
learning_rate: 0.00003
|
32 |
-
|
|
|
33 |
train_on_inputs: false
|
34 |
group_by_length: false
|
35 |
-
bf16:
|
36 |
-
|
|
|
|
|
|
|
37 |
early_stopping_patience:
|
38 |
resume_from_checkpoint:
|
39 |
local_rank:
|
|
|
|
1 |
base_model: EleutherAI/gpt-neox-20b
|
2 |
+
base_model_config: EleutherAI/gpt-neox-20b
|
3 |
base_model_ignore_patterns: pytorch* # prefer safetensors
|
4 |
model_type: GPTNeoXForCausalLM
|
5 |
tokenizer_type: AutoTokenizer
|
6 |
+
load_in_8bit: false
|
7 |
+
load_in_4bit: true
|
8 |
+
load_4bit: false
|
9 |
datasets:
|
10 |
+
- path: vicgalle/alpaca-gpt4
|
11 |
type: alpaca
|
|
|
|
|
12 |
dataset_prepared_path: last_run_prepared
|
13 |
val_set_size: 0.05
|
14 |
+
adapter:
|
15 |
lora_model_dir:
|
16 |
sequence_len: 2048
|
17 |
max_packed_sequence_len: 2048
|
18 |
+
lora_r: 64
|
19 |
lora_alpha: 32
|
20 |
+
lora_dropout: 0.0
|
21 |
lora_target_modules:
|
22 |
+
lora_target_linear: true
|
23 |
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
24 |
wandb_project: gpt4all-neox-20b
|
25 |
wandb_watch:
|
|
|
27 |
wandb_log_model:
|
28 |
output_dir: ./gpt4all-neox-20b
|
29 |
gradient_accumulation_steps: 1
|
30 |
+
micro_batch_size: 2
|
31 |
num_epochs: 5
|
32 |
learning_rate: 0.00003
|
33 |
+
optimizer: paged_adamw_32bit
|
34 |
+
lr_scheduler: cosine
|
35 |
train_on_inputs: false
|
36 |
group_by_length: false
|
37 |
+
bf16: false
|
38 |
+
fp16: false
|
39 |
+
float16: true
|
40 |
+
tf32: true
|
41 |
+
flash_optimum: true
|
42 |
early_stopping_patience:
|
43 |
resume_from_checkpoint:
|
44 |
local_rank:
|
45 |
+
gradient_checkpointing: true
|
requirements.txt
CHANGED
@@ -11,6 +11,7 @@ sentencepiece
|
|
11 |
wandb
|
12 |
einops
|
13 |
xformers
|
|
|
14 |
# qlora things
|
15 |
bert-score==0.3.13
|
16 |
evaluate==0.4.0
|
|
|
11 |
wandb
|
12 |
einops
|
13 |
xformers
|
14 |
+
optimum
|
15 |
# qlora things
|
16 |
bert-score==0.3.13
|
17 |
evaluate==0.4.0
|
scripts/finetune.py
CHANGED
@@ -6,6 +6,7 @@ import os
|
|
6 |
import random
|
7 |
import signal
|
8 |
import sys
|
|
|
9 |
from pathlib import Path
|
10 |
from typing import Any, Dict, List, Optional, Union
|
11 |
|
@@ -19,6 +20,8 @@ from axolotl.utils.dict import DictDefault
|
|
19 |
from axolotl.utils.models import load_model, load_tokenizer
|
20 |
|
21 |
# add src to the pythonpath so we don't need to pip install this
|
|
|
|
|
22 |
from axolotl.utils.tokenization import check_dataset_labels
|
23 |
from axolotl.utils.trainer import setup_trainer
|
24 |
from axolotl.utils.validation import validate_config
|
@@ -264,12 +267,14 @@ def train(
|
|
264 |
|
265 |
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
266 |
if cfg.local_rank == 0:
|
|
|
|
|
|
|
|
|
|
|
267 |
signal.signal(
|
268 |
signal.SIGINT,
|
269 |
-
lambda
|
270 |
-
model.save_pretrained(cfg.output_dir),
|
271 |
-
sys.exit(0),
|
272 |
-
),
|
273 |
)
|
274 |
|
275 |
logging.info("Starting trainer...")
|
@@ -299,6 +304,8 @@ def train(
|
|
299 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
300 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
301 |
if cfg.local_rank == 0:
|
|
|
|
|
302 |
model.save_pretrained(cfg.output_dir)
|
303 |
|
304 |
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
|
|
6 |
import random
|
7 |
import signal
|
8 |
import sys
|
9 |
+
from functools import partial
|
10 |
from pathlib import Path
|
11 |
from typing import Any, Dict, List, Optional, Union
|
12 |
|
|
|
20 |
from axolotl.utils.models import load_model, load_tokenizer
|
21 |
|
22 |
# add src to the pythonpath so we don't need to pip install this
|
23 |
+
from optimum.bettertransformer import BetterTransformer
|
24 |
+
|
25 |
from axolotl.utils.tokenization import check_dataset_labels
|
26 |
from axolotl.utils.trainer import setup_trainer
|
27 |
from axolotl.utils.validation import validate_config
|
|
|
267 |
|
268 |
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
269 |
if cfg.local_rank == 0:
|
270 |
+
def terminate_handler(signum, frame, model):
|
271 |
+
if cfg.flash_optimum:
|
272 |
+
model = BetterTransformer.reverse(model)
|
273 |
+
model.save_pretrained(cfg.output_dir)
|
274 |
+
sys.exit(0)
|
275 |
signal.signal(
|
276 |
signal.SIGINT,
|
277 |
+
lambda signum, frame: terminate_handler(signum, frame, model)
|
|
|
|
|
|
|
278 |
)
|
279 |
|
280 |
logging.info("Starting trainer...")
|
|
|
304 |
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
|
305 |
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
306 |
if cfg.local_rank == 0:
|
307 |
+
if cfg.flash_optimum:
|
308 |
+
model = BetterTransformer.reverse(model)
|
309 |
model.save_pretrained(cfg.output_dir)
|
310 |
|
311 |
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
|
src/axolotl/utils/models.py
CHANGED
@@ -11,7 +11,8 @@ import bitsandbytes as bnb
|
|
11 |
import torch
|
12 |
import transformers
|
13 |
from transformers import PreTrainedModel # noqa: F401
|
14 |
-
from
|
|
|
15 |
AutoConfig,
|
16 |
AutoModelForCausalLM,
|
17 |
AutoTokenizer,
|
@@ -137,7 +138,7 @@ def load_model(
|
|
137 |
|
138 |
if cfg.bf16:
|
139 |
torch_dtype = torch.bfloat16
|
140 |
-
elif cfg.load_in_8bit or cfg.fp16:
|
141 |
torch_dtype = torch.float16
|
142 |
else:
|
143 |
torch_dtype = torch.float32
|
@@ -342,6 +343,9 @@ def load_model(
|
|
342 |
logging.warning("there are no parameters that require gradient updates")
|
343 |
model.config.use_cache = False
|
344 |
|
|
|
|
|
|
|
345 |
# TODO resume_from_checkpoint handling
|
346 |
return model, lora_config
|
347 |
|
|
|
11 |
import torch
|
12 |
import transformers
|
13 |
from transformers import PreTrainedModel # noqa: F401
|
14 |
+
from optimum.bettertransformer import BetterTransformer
|
15 |
+
from transformers import (
|
16 |
AutoConfig,
|
17 |
AutoModelForCausalLM,
|
18 |
AutoTokenizer,
|
|
|
138 |
|
139 |
if cfg.bf16:
|
140 |
torch_dtype = torch.bfloat16
|
141 |
+
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
|
142 |
torch_dtype = torch.float16
|
143 |
else:
|
144 |
torch_dtype = torch.float32
|
|
|
343 |
logging.warning("there are no parameters that require gradient updates")
|
344 |
model.config.use_cache = False
|
345 |
|
346 |
+
if cfg.flash_optimum:
|
347 |
+
model = BetterTransformer.transform(model)
|
348 |
+
|
349 |
# TODO resume_from_checkpoint handling
|
350 |
return model, lora_config
|
351 |
|
src/axolotl/utils/validation.py
CHANGED
@@ -57,6 +57,14 @@ def validate_config(cfg):
|
|
57 |
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
58 |
raise ValueError("FSDP is not supported for falcon models")
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
# TODO
|
61 |
# MPT 7b
|
62 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
|
|
57 |
if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
|
58 |
raise ValueError("FSDP is not supported for falcon models")
|
59 |
|
60 |
+
if cfg.flash_optimum is True:
|
61 |
+
if cfg.adapter:
|
62 |
+
logging.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
63 |
+
if cfg.fp16 or cfg.bf16:
|
64 |
+
raise ValueError("AMP is not supported with BetterTransformer")
|
65 |
+
if cfg.float16 is not True:
|
66 |
+
logging.warning("You should probably set float16 to true")
|
67 |
+
|
68 |
# TODO
|
69 |
# MPT 7b
|
70 |
# https://github.com/facebookresearch/bitsandbytes/issues/25
|