misc fixes to add gptq tests (#621)
Browse files* misc fixes to add gptq tests
* set bf16 needed for fa2
- src/axolotl/utils/bench.py +5 -1
- src/axolotl/utils/models.py +23 -12
- src/axolotl/utils/trainer.py +1 -0
- tests/e2e/test_lora_llama.py +56 -2
- tests/e2e/test_phi.py +8 -6
src/axolotl/utils/bench.py
CHANGED
@@ -19,7 +19,11 @@ def check_cuda_device(default_value):
|
|
19 |
def wrapper(*args, **kwargs):
|
20 |
device = kwargs.get("device", args[0] if args else None)
|
21 |
|
22 |
-
if
|
|
|
|
|
|
|
|
|
23 |
return default_value
|
24 |
|
25 |
return func(*args, **kwargs)
|
|
|
19 |
def wrapper(*args, **kwargs):
|
20 |
device = kwargs.get("device", args[0] if args else None)
|
21 |
|
22 |
+
if (
|
23 |
+
not torch.cuda.is_available()
|
24 |
+
or device == "auto"
|
25 |
+
or torch.device(device).type == "cpu"
|
26 |
+
):
|
27 |
return default_value
|
28 |
|
29 |
return func(*args, **kwargs)
|
src/axolotl/utils/models.py
CHANGED
@@ -10,6 +10,7 @@ import torch
|
|
10 |
import transformers
|
11 |
from optimum.bettertransformer import BetterTransformer
|
12 |
from peft import PeftConfig, prepare_model_for_kbit_training
|
|
|
13 |
from transformers import ( # noqa: F401
|
14 |
AutoConfig,
|
15 |
AutoModelForCausalLM,
|
@@ -309,16 +310,26 @@ def load_model(
|
|
309 |
):
|
310 |
config.max_sequence_length = cfg.sequence_len
|
311 |
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
322 |
except Exception as err: # pylint: disable=broad-exception-caught
|
323 |
LOG.error(
|
324 |
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
@@ -466,10 +477,10 @@ def load_llama_adapter(model, cfg):
|
|
466 |
|
467 |
|
468 |
def find_all_linear_names(model):
|
469 |
-
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
|
470 |
lora_module_names = set()
|
471 |
for name, module in model.named_modules():
|
472 |
-
if isinstance(module, cls):
|
473 |
names = name.split(".")
|
474 |
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
475 |
|
|
|
10 |
import transformers
|
11 |
from optimum.bettertransformer import BetterTransformer
|
12 |
from peft import PeftConfig, prepare_model_for_kbit_training
|
13 |
+
from peft.tuners.lora import QuantLinear
|
14 |
from transformers import ( # noqa: F401
|
15 |
AutoConfig,
|
16 |
AutoModelForCausalLM,
|
|
|
310 |
):
|
311 |
config.max_sequence_length = cfg.sequence_len
|
312 |
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
313 |
+
if cfg.gptq:
|
314 |
+
model = AutoModelForCausalLM.from_pretrained(
|
315 |
+
base_model,
|
316 |
+
config=config,
|
317 |
+
device_map=cfg.device_map,
|
318 |
+
torch_dtype=cfg.torch_dtype,
|
319 |
+
trust_remote_code=cfg.trust_remote_code or False,
|
320 |
+
**model_kwargs,
|
321 |
+
)
|
322 |
+
else:
|
323 |
+
model = AutoModelForCausalLM.from_pretrained(
|
324 |
+
base_model,
|
325 |
+
config=config,
|
326 |
+
device_map=cfg.device_map,
|
327 |
+
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
328 |
+
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
329 |
+
torch_dtype=cfg.torch_dtype,
|
330 |
+
trust_remote_code=cfg.trust_remote_code or False,
|
331 |
+
**model_kwargs,
|
332 |
+
)
|
333 |
except Exception as err: # pylint: disable=broad-exception-caught
|
334 |
LOG.error(
|
335 |
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
|
|
477 |
|
478 |
|
479 |
def find_all_linear_names(model):
|
480 |
+
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
481 |
lora_module_names = set()
|
482 |
for name, module in model.named_modules():
|
483 |
+
if isinstance(module, cls) or "Linear" in module.__class__.__name__:
|
484 |
names = name.split(".")
|
485 |
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
486 |
|
src/axolotl/utils/trainer.py
CHANGED
@@ -676,6 +676,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
676 |
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
|
677 |
and cfg.val_set_size > 0
|
678 |
and cfg.save_steps
|
|
|
679 |
and cfg.save_steps % cfg.eval_steps == 0
|
680 |
)
|
681 |
or False,
|
|
|
676 |
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
|
677 |
and cfg.val_set_size > 0
|
678 |
and cfg.save_steps
|
679 |
+
and cfg.eval_steps
|
680 |
and cfg.save_steps % cfg.eval_steps == 0
|
681 |
)
|
682 |
or False,
|
tests/e2e/test_lora_llama.py
CHANGED
@@ -6,6 +6,7 @@ import logging
|
|
6 |
import os
|
7 |
import tempfile
|
8 |
import unittest
|
|
|
9 |
|
10 |
from axolotl.cli import load_datasets
|
11 |
from axolotl.common.cli import TrainerCliArgs
|
@@ -24,6 +25,7 @@ class TestLoraLlama(unittest.TestCase):
|
|
24 |
|
25 |
def test_lora(self):
|
26 |
# pylint: disable=duplicate-code
|
|
|
27 |
cfg = DictDefault(
|
28 |
{
|
29 |
"base_model": "JackFram/llama-68m",
|
@@ -51,7 +53,7 @@ class TestLoraLlama(unittest.TestCase):
|
|
51 |
"num_epochs": 2,
|
52 |
"micro_batch_size": 8,
|
53 |
"gradient_accumulation_steps": 1,
|
54 |
-
"output_dir":
|
55 |
"learning_rate": 0.00001,
|
56 |
"optimizer": "adamw_torch",
|
57 |
"lr_scheduler": "cosine",
|
@@ -62,9 +64,11 @@ class TestLoraLlama(unittest.TestCase):
|
|
62 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
63 |
|
64 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
|
65 |
|
66 |
def test_lora_packing(self):
|
67 |
# pylint: disable=duplicate-code
|
|
|
68 |
cfg = DictDefault(
|
69 |
{
|
70 |
"base_model": "JackFram/llama-68m",
|
@@ -94,7 +98,7 @@ class TestLoraLlama(unittest.TestCase):
|
|
94 |
"num_epochs": 2,
|
95 |
"micro_batch_size": 8,
|
96 |
"gradient_accumulation_steps": 1,
|
97 |
-
"output_dir":
|
98 |
"learning_rate": 0.00001,
|
99 |
"optimizer": "adamw_torch",
|
100 |
"lr_scheduler": "cosine",
|
@@ -105,3 +109,53 @@ class TestLoraLlama(unittest.TestCase):
|
|
105 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
106 |
|
107 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import os
|
7 |
import tempfile
|
8 |
import unittest
|
9 |
+
from pathlib import Path
|
10 |
|
11 |
from axolotl.cli import load_datasets
|
12 |
from axolotl.common.cli import TrainerCliArgs
|
|
|
25 |
|
26 |
def test_lora(self):
|
27 |
# pylint: disable=duplicate-code
|
28 |
+
output_dir = tempfile.mkdtemp()
|
29 |
cfg = DictDefault(
|
30 |
{
|
31 |
"base_model": "JackFram/llama-68m",
|
|
|
53 |
"num_epochs": 2,
|
54 |
"micro_batch_size": 8,
|
55 |
"gradient_accumulation_steps": 1,
|
56 |
+
"output_dir": output_dir,
|
57 |
"learning_rate": 0.00001,
|
58 |
"optimizer": "adamw_torch",
|
59 |
"lr_scheduler": "cosine",
|
|
|
64 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
65 |
|
66 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
67 |
+
assert (Path(output_dir) / "adapter_model.bin").exists()
|
68 |
|
69 |
def test_lora_packing(self):
|
70 |
# pylint: disable=duplicate-code
|
71 |
+
output_dir = tempfile.mkdtemp()
|
72 |
cfg = DictDefault(
|
73 |
{
|
74 |
"base_model": "JackFram/llama-68m",
|
|
|
98 |
"num_epochs": 2,
|
99 |
"micro_batch_size": 8,
|
100 |
"gradient_accumulation_steps": 1,
|
101 |
+
"output_dir": output_dir,
|
102 |
"learning_rate": 0.00001,
|
103 |
"optimizer": "adamw_torch",
|
104 |
"lr_scheduler": "cosine",
|
|
|
109 |
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
110 |
|
111 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
112 |
+
assert (Path(output_dir) / "adapter_model.bin").exists()
|
113 |
+
|
114 |
+
def test_lora_gptq(self):
|
115 |
+
# pylint: disable=duplicate-code
|
116 |
+
output_dir = tempfile.mkdtemp()
|
117 |
+
cfg = DictDefault(
|
118 |
+
{
|
119 |
+
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
|
120 |
+
"base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ",
|
121 |
+
"model_type": "AutoModelForCausalLM",
|
122 |
+
"tokenizer_type": "LlamaTokenizer",
|
123 |
+
"sequence_len": 1024,
|
124 |
+
"sample_packing": True,
|
125 |
+
"flash_attention": True,
|
126 |
+
"load_in_8bit": True,
|
127 |
+
"adapter": "lora",
|
128 |
+
"gptq": True,
|
129 |
+
"gptq_disable_exllama": True,
|
130 |
+
"lora_r": 32,
|
131 |
+
"lora_alpha": 64,
|
132 |
+
"lora_dropout": 0.05,
|
133 |
+
"lora_target_linear": True,
|
134 |
+
"val_set_size": 0.1,
|
135 |
+
"special_tokens": {
|
136 |
+
"unk_token": "<unk>",
|
137 |
+
"bos_token": "<s>",
|
138 |
+
"eos_token": "</s>",
|
139 |
+
},
|
140 |
+
"datasets": [
|
141 |
+
{
|
142 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
143 |
+
"type": "alpaca",
|
144 |
+
},
|
145 |
+
],
|
146 |
+
"num_epochs": 2,
|
147 |
+
"save_steps": 0.5,
|
148 |
+
"micro_batch_size": 8,
|
149 |
+
"gradient_accumulation_steps": 1,
|
150 |
+
"output_dir": output_dir,
|
151 |
+
"learning_rate": 0.00001,
|
152 |
+
"optimizer": "adamw_torch",
|
153 |
+
"lr_scheduler": "cosine",
|
154 |
+
}
|
155 |
+
)
|
156 |
+
normalize_config(cfg)
|
157 |
+
cli_args = TrainerCliArgs()
|
158 |
+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
159 |
+
|
160 |
+
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
161 |
+
assert (Path(output_dir) / "adapter_model.bin").exists()
|
tests/e2e/test_phi.py
CHANGED
@@ -31,9 +31,9 @@ class TestPhi(unittest.TestCase):
|
|
31 |
"trust_remote_code": True,
|
32 |
"model_type": "MixFormerSequentialForCausalLM",
|
33 |
"tokenizer_type": "AutoTokenizer",
|
34 |
-
"sequence_len":
|
35 |
"sample_packing": False,
|
36 |
-
"load_in_8bit":
|
37 |
"adapter": None,
|
38 |
"val_set_size": 0.1,
|
39 |
"special_tokens": {
|
@@ -55,8 +55,9 @@ class TestPhi(unittest.TestCase):
|
|
55 |
"gradient_accumulation_steps": 1,
|
56 |
"output_dir": tempfile.mkdtemp(),
|
57 |
"learning_rate": 0.00001,
|
58 |
-
"optimizer": "
|
59 |
"lr_scheduler": "cosine",
|
|
|
60 |
}
|
61 |
)
|
62 |
normalize_config(cfg)
|
@@ -74,9 +75,9 @@ class TestPhi(unittest.TestCase):
|
|
74 |
"trust_remote_code": True,
|
75 |
"model_type": "MixFormerSequentialForCausalLM",
|
76 |
"tokenizer_type": "AutoTokenizer",
|
77 |
-
"sequence_len":
|
78 |
"sample_packing": True,
|
79 |
-
"load_in_8bit":
|
80 |
"adapter": None,
|
81 |
"val_set_size": 0.1,
|
82 |
"special_tokens": {
|
@@ -98,8 +99,9 @@ class TestPhi(unittest.TestCase):
|
|
98 |
"gradient_accumulation_steps": 1,
|
99 |
"output_dir": tempfile.mkdtemp(),
|
100 |
"learning_rate": 0.00001,
|
101 |
-
"optimizer": "
|
102 |
"lr_scheduler": "cosine",
|
|
|
103 |
}
|
104 |
)
|
105 |
normalize_config(cfg)
|
|
|
31 |
"trust_remote_code": True,
|
32 |
"model_type": "MixFormerSequentialForCausalLM",
|
33 |
"tokenizer_type": "AutoTokenizer",
|
34 |
+
"sequence_len": 512,
|
35 |
"sample_packing": False,
|
36 |
+
"load_in_8bit": False,
|
37 |
"adapter": None,
|
38 |
"val_set_size": 0.1,
|
39 |
"special_tokens": {
|
|
|
55 |
"gradient_accumulation_steps": 1,
|
56 |
"output_dir": tempfile.mkdtemp(),
|
57 |
"learning_rate": 0.00001,
|
58 |
+
"optimizer": "adamw_bnb_8bit",
|
59 |
"lr_scheduler": "cosine",
|
60 |
+
"bf16": True,
|
61 |
}
|
62 |
)
|
63 |
normalize_config(cfg)
|
|
|
75 |
"trust_remote_code": True,
|
76 |
"model_type": "MixFormerSequentialForCausalLM",
|
77 |
"tokenizer_type": "AutoTokenizer",
|
78 |
+
"sequence_len": 512,
|
79 |
"sample_packing": True,
|
80 |
+
"load_in_8bit": False,
|
81 |
"adapter": None,
|
82 |
"val_set_size": 0.1,
|
83 |
"special_tokens": {
|
|
|
99 |
"gradient_accumulation_steps": 1,
|
100 |
"output_dir": tempfile.mkdtemp(),
|
101 |
"learning_rate": 0.00001,
|
102 |
+
"optimizer": "adamw_bnb_8bit",
|
103 |
"lr_scheduler": "cosine",
|
104 |
+
"bf16": True,
|
105 |
}
|
106 |
)
|
107 |
normalize_config(cfg)
|