winglian committed on
Commit
da97285
·
unverified ·
1 Parent(s): 2dc4310

keep gate in fp32 for 16 bit loras (#1105)

Browse files

* keep gate in fp32 for loras

* add e2e check for lora w/o flash attention for mixtral to check gate

* add checks for gate in fp32 for mixtral, add typehints to train outputs

* mixtral doesn't support basic lora :facepalm:

add lora tests @ 16bit and fix gate layer check
fix the parameter name, was using the old disco name
don't lora over the gate so we can check that it is in fp32
fix dtype check

* ensure we're using fp16/bf16 for 16bit and qlora is always going to be in uint8

src/axolotl/train.py CHANGED
@@ -5,14 +5,16 @@ import signal
5
  import sys
6
  from dataclasses import dataclass
7
  from pathlib import Path
8
- from typing import Optional
9
 
10
  import torch
11
  import transformers.modelcard
12
  from accelerate.logging import get_logger
13
  from datasets import Dataset
14
  from optimum.bettertransformer import BetterTransformer
 
15
  from pkg_resources import get_distribution # type: ignore
 
16
  from transformers.deepspeed import is_deepspeed_zero3_enabled
17
 
18
  from axolotl.common.cli import TrainerCliArgs
@@ -43,7 +45,7 @@ class TrainDatasetMeta:
43
 
44
  def train(
45
  *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
46
- ):
47
  # load the tokenizer first
48
  LOG.debug(
49
  f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
 
5
  import sys
6
  from dataclasses import dataclass
7
  from pathlib import Path
8
+ from typing import Optional, Tuple, Union
9
 
10
  import torch
11
  import transformers.modelcard
12
  from accelerate.logging import get_logger
13
  from datasets import Dataset
14
  from optimum.bettertransformer import BetterTransformer
15
+ from peft import PeftModel
16
  from pkg_resources import get_distribution # type: ignore
17
+ from transformers import PreTrainedModel, PreTrainedTokenizer
18
  from transformers.deepspeed import is_deepspeed_zero3_enabled
19
 
20
  from axolotl.common.cli import TrainerCliArgs
 
45
 
46
  def train(
47
  *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
48
+ ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
49
  # load the tokenizer first
50
  LOG.debug(
51
  f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
src/axolotl/utils/models.py CHANGED
@@ -590,7 +590,7 @@ def load_model(
590
  # make sure these are fp32 per Ramesh et al. (2021)
591
  embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
592
  for name, module in model.named_modules():
593
- if "norm" in name:
594
  module.to(torch.float32)
595
  if model_config.model_type == "btlm":
596
  # don't upcast lm_head for btlm
 
590
  # make sure these are fp32 per Ramesh et al. (2021)
591
  embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
592
  for name, module in model.named_modules():
593
+ if any(m in name for m in ["norm", "gate"]):
594
  module.to(torch.float32)
595
  if model_config.model_type == "btlm":
596
  # don't upcast lm_head for btlm
tests/e2e/test_mixtral.py CHANGED
@@ -7,6 +7,7 @@ import os
7
  import unittest
8
  from pathlib import Path
9
 
 
10
  from transformers.utils import is_torch_bf16_gpu_available
11
 
12
  from axolotl.cli import load_datasets
@@ -27,7 +28,7 @@ class TestMixtral(unittest.TestCase):
27
  """
28
 
29
  @with_temp_dir
30
- def test_qlora(self, temp_dir):
31
  # pylint: disable=duplicate-code
32
  cfg = DictDefault(
33
  {
@@ -37,10 +38,18 @@ class TestMixtral(unittest.TestCase):
37
  "sequence_len": 1024,
38
  "load_in_4bit": True,
39
  "adapter": "qlora",
40
- "lora_r": 16,
41
- "lora_alpha": 32,
42
  "lora_dropout": 0.1,
43
- "lora_target_linear": True,
 
 
 
 
 
 
 
 
44
  "val_set_size": 0.1,
45
  "special_tokens": {},
46
  "datasets": [
@@ -65,7 +74,179 @@ class TestMixtral(unittest.TestCase):
65
  cli_args = TrainerCliArgs()
66
  dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
67
 
68
- train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  assert (Path(temp_dir) / "adapter_model.bin").exists()
70
 
71
  @with_temp_dir
 
7
  import unittest
8
  from pathlib import Path
9
 
10
+ import torch
11
  from transformers.utils import is_torch_bf16_gpu_available
12
 
13
  from axolotl.cli import load_datasets
 
28
  """
29
 
30
  @with_temp_dir
31
+ def test_qlora_w_fa2(self, temp_dir):
32
  # pylint: disable=duplicate-code
33
  cfg = DictDefault(
34
  {
 
38
  "sequence_len": 1024,
39
  "load_in_4bit": True,
40
  "adapter": "qlora",
41
+ "lora_r": 4,
42
+ "lora_alpha": 8,
43
  "lora_dropout": 0.1,
44
+ "lora_target_modules": [
45
+ "o_proj",
46
+ "w3",
47
+ "k_proj",
48
+ "v_proj",
49
+ "w1",
50
+ "q_proj",
51
+ "w2",
52
+ ],
53
  "val_set_size": 0.1,
54
  "special_tokens": {},
55
  "datasets": [
 
74
  cli_args = TrainerCliArgs()
75
  dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
76
 
77
+ model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
78
+ assert (
79
+ model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
80
+ == torch.uint8
81
+ )
82
+ assert (Path(temp_dir) / "adapter_model.bin").exists()
83
+
@with_temp_dir
def test_qlora_wo_fa2(self, temp_dir):
    """
    Train a 4-bit QLoRA adapter on the tiny Mixtral checkpoint with
    flash attention disabled, then check that the first layer's MoE gate
    weight is stored as ``torch.uint8`` and that the adapter file was
    written into ``temp_dir``.
    """
    # pylint: disable=duplicate-code
    # The gate is deliberately absent from the LoRA targets so that its
    # dtype can be inspected after training.
    lora_targets = [
        "o_proj",
        "w3",
        "k_proj",
        "v_proj",
        "w1",
        "q_proj",
        "w2",
    ]
    dataset_spec = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    settings = {
        "base_model": "hf-internal-testing/Mixtral-tiny",
        "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
        "flash_attention": False,
        "sequence_len": 1024,
        "load_in_4bit": True,
        "adapter": "qlora",
        "lora_r": 4,
        "lora_alpha": 8,
        "lora_dropout": 0.1,
        "lora_target_modules": lora_targets,
        "val_set_size": 0.1,
        "special_tokens": {},
        "datasets": [dataset_spec],
        "num_epochs": 2,
        "micro_batch_size": 2,
        "gradient_accumulation_steps": 1,
        "output_dir": temp_dir,
        "learning_rate": 0.00001,
        "optimizer": "adamw_bnb_8bit",
        "lr_scheduler": "cosine",
        "max_steps": 20,
        "save_steps": 10,
        "eval_steps": 10,
    }
    cfg = DictDefault(settings)
    normalize_config(cfg)
    cli_args = TrainerCliArgs()
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

    model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

    first_layer = model.base_model.model.model.layers[0]
    gate_dtype = first_layer.block_sparse_moe.gate.weight.dtype
    assert gate_dtype == torch.uint8

    adapter_file = Path(temp_dir) / "adapter_model.bin"
    assert adapter_file.exists()
137
+
@with_temp_dir
def test_16bit_lora_w_fa2(self, temp_dir):
    """
    Train a LoRA adapter on the tiny Mixtral checkpoint with flash
    attention enabled and a 16-bit precision flag (bf16 when the GPU
    supports it, fp16 otherwise), then check that the first layer's MoE
    gate weight is ``torch.float32`` and that the adapter file was
    written into ``temp_dir``.
    """
    # pylint: disable=duplicate-code
    # The gate is deliberately absent from the LoRA targets so that its
    # dtype can be inspected after training.
    lora_targets = [
        "o_proj",
        "w3",
        "k_proj",
        "v_proj",
        "w1",
        "q_proj",
        "w2",
    ]
    dataset_spec = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }
    settings = {
        "base_model": "hf-internal-testing/Mixtral-tiny",
        "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
        "flash_attention": True,
        "sequence_len": 1024,
        "adapter": "lora",
        "lora_r": 4,
        "lora_alpha": 8,
        "lora_dropout": 0.1,
        "lora_target_modules": lora_targets,
        "val_set_size": 0.1,
        "special_tokens": {},
        "datasets": [dataset_spec],
        "num_epochs": 2,
        "micro_batch_size": 2,
        "gradient_accumulation_steps": 1,
        "output_dir": temp_dir,
        "learning_rate": 0.00001,
        "optimizer": "adamw_bnb_8bit",
        "lr_scheduler": "cosine",
        "max_steps": 20,
        "save_steps": 10,
        "eval_steps": 10,
    }
    cfg = DictDefault(settings)

    # Choose the 16-bit precision flag before the config is normalised.
    if not is_torch_bf16_gpu_available():
        cfg.fp16 = True
    else:
        cfg.bf16 = True
    normalize_config(cfg)

    cli_args = TrainerCliArgs()
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

    model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

    first_layer = model.base_model.model.model.layers[0]
    gate_dtype = first_layer.block_sparse_moe.gate.weight.dtype
    assert gate_dtype == torch.float32

    adapter_file = Path(temp_dir) / "adapter_model.bin"
    assert adapter_file.exists()
194
+
@with_temp_dir
def test_16bit_lora_wo_fa2(self, temp_dir):
    """
    Train a LoRA adapter on the tiny Mixtral checkpoint with flash
    attention disabled and a 16-bit precision flag (bf16 when the GPU
    supports it, fp16 otherwise), then check that the first layer's MoE
    gate weight is ``torch.float32`` and that the adapter file was
    written into ``temp_dir``.
    """
    # pylint: disable=duplicate-code
    cfg = DictDefault(
        {
            "base_model": "hf-internal-testing/Mixtral-tiny",
            "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
            "flash_attention": False,
            "sequence_len": 1024,
            "adapter": "lora",
            "lora_r": 4,
            "lora_alpha": 8,
            "lora_dropout": 0.1,
            # The gate is deliberately absent from the LoRA targets so
            # that its dtype can be inspected after training.
            "lora_target_modules": [
                "o_proj",
                "w3",
                "k_proj",
                "v_proj",
                "w1",
                "q_proj",
                "w2",
            ],
            "val_set_size": 0.1,
            "special_tokens": {},
            "datasets": [
                {
                    "path": "mhenrichsen/alpaca_2k_test",
                    "type": "alpaca",
                },
            ],
            "num_epochs": 2,
            "micro_batch_size": 2,
            "gradient_accumulation_steps": 1,
            "output_dir": temp_dir,
            "learning_rate": 0.00001,
            "optimizer": "adamw_bnb_8bit",
            "lr_scheduler": "cosine",
            "max_steps": 20,
            "save_steps": 10,
            "eval_steps": 10,
        }
    )
    # Set the precision flag BEFORE normalising, matching
    # test_16bit_lora_w_fa2. NOTE(review): normalize_config presumably
    # derives the model load dtype from bf16/fp16; if so, setting the flag
    # afterwards would leave the model in its default precision and make
    # the fp32 gate assertion below pass trivially - confirm.
    if is_torch_bf16_gpu_available():
        cfg.bf16 = True
    else:
        cfg.fp16 = True
    normalize_config(cfg)
    cli_args = TrainerCliArgs()
    dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

    model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
    # The router gate must be fp32 even though the model is trained in
    # 16-bit.
    assert (
        model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
        == torch.float32
    )
    assert (Path(temp_dir) / "adapter_model.bin").exists()
251
 
252
  @with_temp_dir