Commit 03e5907 by winglian
1 Parent(s): 97d3776

misc fixes to add gptq tests (#621)

* misc fixes to add gptq tests

* set bf16 needed for fa2

src/axolotl/utils/bench.py CHANGED
@@ -19,7 +19,11 @@ def check_cuda_device(default_value):
         def wrapper(*args, **kwargs):
             device = kwargs.get("device", args[0] if args else None)
 
-            if not torch.cuda.is_available() or device == "auto" or device == "cpu":
+            if (
+                not torch.cuda.is_available()
+                or device == "auto"
+                or torch.device(device).type == "cpu"
+            ):
                 return default_value
 
             return func(*args, **kwargs)
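The stricter check matters because callers can pass a torch.device object or an indexed string such as "cpu:0", which the old equality against the literal "cpu" missed. A minimal sketch of the same guard outside the decorator (the helper name is illustrative, not part of axolotl):

import torch

def should_skip_gpu_benchmark(device) -> bool:
    # Mirrors the guard in the diff: skip when CUDA is unavailable, when the
    # device is the "auto" placeholder, or when it parses to a CPU device.
    # "auto" is compared as a string first because torch.device("auto") raises.
    if not torch.cuda.is_available() or device == "auto":
        return True
    return torch.device(device).type == "cpu"

# "cpu", "cpu:0", and torch.device("cpu") all report type "cpu", while the
# previous string comparison only matched the literal "cpu".
for dev in ("cpu", "cpu:0", torch.device("cpu"), "cuda:0"):
    print(dev, should_skip_gpu_benchmark(dev))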
src/axolotl/utils/models.py CHANGED
@@ -10,6 +10,7 @@ import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
 from peft import PeftConfig, prepare_model_for_kbit_training
+from peft.tuners.lora import QuantLinear
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
@@ -309,16 +310,26 @@ def load_model(
         ):
             config.max_sequence_length = cfg.sequence_len
             LOG.warning(f"increasing context length to {cfg.sequence_len}")
-        model = AutoModelForCausalLM.from_pretrained(
-            base_model,
-            config=config,
-            device_map=cfg.device_map,
-            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-            torch_dtype=cfg.torch_dtype,
-            trust_remote_code=cfg.trust_remote_code or False,
-            **model_kwargs,
-        )
+        if cfg.gptq:
+            model = AutoModelForCausalLM.from_pretrained(
+                base_model,
+                config=config,
+                device_map=cfg.device_map,
+                torch_dtype=cfg.torch_dtype,
+                trust_remote_code=cfg.trust_remote_code or False,
+                **model_kwargs,
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                base_model,
+                config=config,
+                device_map=cfg.device_map,
+                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+                torch_dtype=cfg.torch_dtype,
+                trust_remote_code=cfg.trust_remote_code or False,
+                **model_kwargs,
+            )
     except Exception as err:  # pylint: disable=broad-exception-caught
         LOG.error(
             "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
@@ -466,10 +477,10 @@ def load_llama_adapter(model, cfg):
 
 
 def find_all_linear_names(model):
-    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
+    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
     lora_module_names = set()
    for name, module in model.named_modules():
-        if isinstance(module, cls):
+        if isinstance(module, cls) or "Linear" in module.__class__.__name__:
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
 
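GPTQ checkpoints already ship quantized weights, so the new branch drops the bitsandbytes load_in_8bit/load_in_4bit flags, which would otherwise conflict with a model that is already quantized; the widened find_all_linear_names lets lora_target_linear pick up GPTQ QuantLinear layers by class name. A small self-contained sketch of that name-based scan on a toy module (FakeQuantLinear and the helper name are illustrative stand-ins, not axolotl's classes):

import torch
from torch import nn

class FakeQuantLinear(nn.Module):
    """Stand-in for a GPTQ quantized linear layer (illustrative only)."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.register_buffer(
            "qweight", torch.zeros(in_features, out_features, dtype=torch.int32)
        )

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = FakeQuantLinear(16, 16)
        self.v_proj = FakeQuantLinear(16, 16)
        self.out_proj = nn.Linear(16, 16)

def linear_module_names(model):
    # Same idea as the updated find_all_linear_names: match known linear
    # classes, plus anything whose class name contains "Linear", which is
    # what catches quantized replacements such as QuantLinear.
    known = (nn.Linear,)
    names = set()
    for name, module in model.named_modules():
        if isinstance(module, known) or "Linear" in module.__class__.__name__:
            parts = name.split(".")
            names.add(parts[0] if len(parts) == 1 else parts[-1])
    return names

print(linear_module_names(ToyAttention()))  # {'q_proj', 'v_proj', 'out_proj'}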
src/axolotl/utils/trainer.py CHANGED
@@ -676,6 +676,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
             (cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
             and cfg.val_set_size > 0
             and cfg.save_steps
+            and cfg.eval_steps
             and cfg.save_steps % cfg.eval_steps == 0
         )
         or False,
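The extra `and cfg.eval_steps` short-circuits before the modulo, so an unset eval interval no longer produces `save_steps % None` (a TypeError) or a division by zero. A toy illustration of the guarded expression in isolation (the helper name is hypothetical and the other cfg conditions are omitted):

def save_aligns_with_eval(save_steps, eval_steps) -> bool:
    # Checking eval_steps before the modulo avoids "% None" (TypeError) and
    # "% 0" (ZeroDivisionError) when no evaluation interval is configured.
    return bool(save_steps and eval_steps and save_steps % eval_steps == 0)

print(save_aligns_with_eval(100, 50))    # True: every save lands on an eval step
print(save_aligns_with_eval(100, None))  # False instead of raising
print(save_aligns_with_eval(100, 30))    # False: the intervals never line up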
tests/e2e/test_lora_llama.py CHANGED
@@ -6,6 +6,7 @@ import logging
 import os
 import tempfile
 import unittest
+from pathlib import Path
 
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
@@ -24,6 +25,7 @@ class TestLoraLlama(unittest.TestCase):
 
     def test_lora(self):
         # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
@@ -51,7 +53,7 @@ class TestLoraLlama(unittest.TestCase):
                 "num_epochs": 2,
                 "micro_batch_size": 8,
                 "gradient_accumulation_steps": 1,
-                "output_dir": tempfile.mkdtemp(),
+                "output_dir": output_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch",
                 "lr_scheduler": "cosine",
@@ -62,9 +64,11 @@ class TestLoraLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "adapter_model.bin").exists()
 
     def test_lora_packing(self):
         # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
         cfg = DictDefault(
             {
                 "base_model": "JackFram/llama-68m",
@@ -94,7 +98,7 @@ class TestLoraLlama(unittest.TestCase):
                 "num_epochs": 2,
                 "micro_batch_size": 8,
                 "gradient_accumulation_steps": 1,
-                "output_dir": tempfile.mkdtemp(),
+                "output_dir": output_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch",
                 "lr_scheduler": "cosine",
@@ -105,3 +109,53 @@ class TestLoraLlama(unittest.TestCase):
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "adapter_model.bin").exists()
+
+    def test_lora_gptq(self):
+        # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
+        cfg = DictDefault(
+            {
+                "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
+                "base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ",
+                "model_type": "AutoModelForCausalLM",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "sample_packing": True,
+                "flash_attention": True,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "gptq": True,
+                "gptq_disable_exllama": True,
+                "lora_r": 32,
+                "lora_alpha": 64,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "save_steps": 0.5,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": output_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "adapter_model.bin").exists()
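The new test_lora_gptq config sets gptq_disable_exllama because the exllama kernels are inference-only and cannot be used while training LoRA adapters on top of a GPTQ model. At the transformers level those flags correspond roughly to the sketch below (the model id is reused from the test; GPTQConfig parameter names are as of transformers ~4.33, where newer releases expose use_exllama instead of disable_exllama, so treat this as an approximation rather than axolotl's exact loading path):

from transformers import AutoModelForCausalLM, GPTQConfig

# Re-load an already-quantized GPTQ checkpoint while turning off the exllama
# kernels, which do not support the backward pass needed for LoRA training.
quant_config = GPTQConfig(bits=4, disable_exllama=True)
model = AutoModelForCausalLM.from_pretrained(
    "TheBlokeAI/jackfram_llama-68m-GPTQ",
    device_map="auto",
    quantization_config=quant_config,
)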
tests/e2e/test_phi.py CHANGED
@@ -31,9 +31,9 @@ class TestPhi(unittest.TestCase):
                 "trust_remote_code": True,
                 "model_type": "MixFormerSequentialForCausalLM",
                 "tokenizer_type": "AutoTokenizer",
-                "sequence_len": 2048,
+                "sequence_len": 512,
                 "sample_packing": False,
-                "load_in_8bit": True,
+                "load_in_8bit": False,
                 "adapter": None,
                 "val_set_size": 0.1,
                 "special_tokens": {
@@ -55,8 +55,9 @@ class TestPhi(unittest.TestCase):
                 "gradient_accumulation_steps": 1,
                 "output_dir": tempfile.mkdtemp(),
                 "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
+                "optimizer": "adamw_bnb_8bit",
                 "lr_scheduler": "cosine",
+                "bf16": True,
             }
         )
         normalize_config(cfg)
@@ -74,9 +75,9 @@ class TestPhi(unittest.TestCase):
                 "trust_remote_code": True,
                 "model_type": "MixFormerSequentialForCausalLM",
                 "tokenizer_type": "AutoTokenizer",
-                "sequence_len": 2048,
+                "sequence_len": 512,
                 "sample_packing": True,
-                "load_in_8bit": True,
+                "load_in_8bit": False,
                 "adapter": None,
                 "val_set_size": 0.1,
                 "special_tokens": {
@@ -98,8 +99,9 @@ class TestPhi(unittest.TestCase):
                 "gradient_accumulation_steps": 1,
                 "output_dir": tempfile.mkdtemp(),
                 "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
+                "optimizer": "adamw_bnb_8bit",
                 "lr_scheduler": "cosine",
+                "bf16": True,
             }
         )
         normalize_config(cfg)
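The added "bf16": True matches the commit note "set bf16 needed for fa2": flash-attention kernels only operate on fp16/bf16 inputs, so runs that exercise them need a half-precision dtype enabled. A hedged helper for choosing a supported flag (illustrative only, not axolotl's own validation logic):

import torch

def pick_half_precision() -> dict:
    # bf16 needs an Ampere-or-newer GPU; otherwise fall back to fp16, which
    # flash-attention kernels also accept.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return {"bf16": True}
    return {"fp16": True}

print(pick_half_precision())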