winglian committed on
Commit 8c2e05a
Parent: 2d65f47

relora: magnitude pruning of the optimizer (#1245)

* magnitude pruning of the optimizer

* add alpaca chat template and fix relora patch

* fix handling of lora adapter for relora

* fix merge and save call

* fixes for 8-bit lora merge

* save intermediate checkpoint adapters

* auto merge

* fix eval check

* handle relora annealing

* fix anneal step logic

* chore: lint

* misc fix

* fix types

* Update tests/e2e/test_relora_llama.py

* check for safetensors saved from relora
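
A quick illustration of the idea behind the first bullet: instead of zeroing the entire optimizer state at each ReLoRA restart, the new code magnitude-prunes the Adam moments so only the smallest entries are dropped. The sketch below is standalone and is not the axolotl implementation; the toy parameter, the 0.9 prune ratio, and the printout are illustrative only.

import torch


@torch.no_grad()
def magnitude_prune_(tensor: torch.Tensor, prune_ratio: float) -> None:
    """Zero the `prune_ratio` fraction of entries with the smallest magnitude, in place."""
    threshold = torch.quantile(tensor.abs().flatten().float(), prune_ratio).to(tensor.dtype)
    tensor.mul_((tensor.abs() > threshold).to(tensor.dtype))


# Toy usage: prune the Adam moments of a single parameter, as done between ReLoRA cycles.
param = torch.nn.Parameter(torch.randn(16, 16))
opt = torch.optim.AdamW([param], lr=1e-3)
(param ** 2).sum().backward()
opt.step()

for key in ("exp_avg", "exp_avg_sq"):  # Adam's first and second moments
    magnitude_prune_(opt.state[param][key], prune_ratio=0.9)
    zeroed = (opt.state[param][key] == 0).float().mean().item()
    print(f"{key}: {zeroed:.0%} of entries zeroed")
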

src/axolotl/core/trainer_builder.py CHANGED
@@ -126,6 +126,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
     )
+    relora_anneal_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
+    )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
     )
@@ -478,10 +482,14 @@ class ReLoRATrainer(AxolotlTrainer):
             warmup_steps = (
                 self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
             )
+            anneal_steps = (
+                self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
+            )
             self.lr_scheduler = ReLoRAScheduler(
                 optimizer,
                 lr_scheduler,
                 self.args.relora_steps,
+                anneal_steps,
                 warmup_steps,
             )
         else:
@@ -893,6 +901,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         ] = self.cfg.micro_batch_size
         training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
         training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
+        training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
         training_arguments_kwargs = self.hook_pre_create_training_args(
             training_arguments_kwargs
         )
src/axolotl/monkeypatch/relora.py CHANGED
@@ -4,14 +4,16 @@ import json
 import logging
 import os.path
 import shutil
+from functools import partial
 from pathlib import Path
-from typing import Dict, List, Sequence
+from typing import Dict, List, Sequence, Union

 import bitsandbytes as bnb
 import peft
 import safetensors.torch as st
 import torch
 from huggingface_hub import snapshot_download
+from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.optim.lr_scheduler import LRScheduler
 from torch.optim.optimizer import Optimizer
 from transformers import (
@@ -23,23 +25,50 @@ from transformers import (
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_main_process
+from axolotl.utils.distributed import barrier, is_main_process

 LOG = logging.getLogger("axolotl.relora")


-def reset_optimizer(optimizer: torch.optim.Optimizer):
-    for group in optimizer.param_groups:
-        for param in group["params"]:
-            param_state = optimizer.state[param]
-            for key in param_state:
-                if "qmap" in key:
-                    continue
-
-                if key == "step" and isinstance(param_state[key], int):
-                    param_state[key] = 0
-                else:
-                    param_state[key] = torch.zeros_like(param_state[key])
+@torch.no_grad()
+def magnitude_pruning_(tensor, prune_ratio):
+    tensor_magnitude = torch.abs(tensor)
+    threshold = torch.quantile(
+        tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
+    ).to(dtype=tensor.dtype)
+
+    mask = tensor_magnitude > threshold
+    tensor.mul_(mask.to(dtype=tensor.dtype))
+
+
+def reset_optimizer(
+    optimizer: torch.optim.Optimizer,
+    *,
+    reset_params: list[str],  # where str is the key to a torch.nn.Parameter
+    optimizer_state_keys: list[str],
+):
+    pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
+    n_zeros = 0
+    n_total = 0
+
+    optimizer_state = optimizer.state
+    if isinstance(optimizer, ZeroRedundancyOptimizer):
+        optimizer_state = optimizer.optim.state
+
+    for param in reset_params:
+        param_state = optimizer_state[param]
+        if len(param_state) == 0:  # no state for this param, happens for ZeRo optimizer
+            continue
+        for key in optimizer_state_keys:
+            pruning_fn(
+                param_state[key]
+            )  # pruning fn has to be inplace to keep the same keys in the dict
+            n_total += param_state[key].numel()
+            n_zeros += torch.sum(param_state[key] == 0).item()
+
+    _zeroed = n_zeros / (1e-7 + n_total) * 100
+    LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
+    LOG.info(f"absolute n of optimizer states zeroed: {n_zeros}")


 class ReLoRACallback(TrainerCallback):
@@ -97,6 +126,25 @@ class ReLoRACallback(TrainerCallback):
                 "relora",
             )

+            if "adam" in args.optim.lower():
+                optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
+            else:
+                raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
+
+            lora_params = [
+                n
+                for n, p in model.named_parameters()
+                if p.requires_grad and "lora_" in n
+            ]
+
+            model.save_pretrained(
+                os.path.join(
+                    args.output_dir,
+                    f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+                    "adapter",
+                ),
+                safe_serialization=True,
+            )
             with torch.no_grad():
                 merge_and_save(
                     model,
@@ -107,7 +155,11 @@ class ReLoRACallback(TrainerCallback):
                     actually_save=is_main_process(),
                     cpu_offload=self.cpu_offload,
                 )
-            reset_optimizer(optimizer)
+            reset_optimizer(
+                optimizer,
+                reset_params=lora_params,
+                optimizer_state_keys=optimizer_state_keys,
+            )

             if self.quantized:
                 self.last_full_model = checkpoint_folder
@@ -197,11 +249,13 @@ class ReLoRAScheduler(LRScheduler):
         inner_schedule: LRScheduler,
         relora_steps: int,
         warmup_steps: int,
+        anneal_steps: int = 1,
         min_lr_scale: float = 0.001,
     ) -> None:
         self.inner_schedule = inner_schedule
         self.relora_steps = relora_steps
         self.warmup_steps = warmup_steps
+        self.anneal_steps = anneal_steps
         self.min_lr_scale = min_lr_scale
         super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)

@@ -210,10 +264,20 @@ class ReLoRAScheduler(LRScheduler):

         original = self.inner_schedule.get_lr()
         step = self.last_epoch
+
         if step < self.relora_steps:
             scale = 1
         else:
-            cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
+            per_relora_progress = step % self.relora_steps
+            if per_relora_progress < self.warmup_steps:
+                cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
+            elif per_relora_progress > (self.relora_steps - self.anneal_steps):
+                cycle_t = min(
+                    1.0,
+                    (self.relora_steps - per_relora_progress) / self.anneal_steps,
+                )
+            else:
+                cycle_t = 1
             scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale

         if isinstance(original, Sequence):
@@ -238,7 +302,11 @@ def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:

 def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
     if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
-        adapter = layer.active_adapter
+        adapter: Union[List[str], str] = layer.active_adapter
+        if isinstance(adapter, list):
+            if len(adapter) > 1:
+                raise ValueError("unhandled relora for multiple adapters")
+            adapter = adapter[0]
         return (
             peft.utils.transpose(
                 layer.lora_B[adapter].weight.detach().to(device)
@@ -248,7 +316,7 @@ def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor
             * layer.scaling[adapter]
         )

-    return layer.get_delta_weight().to(device)
+    raise ValueError("unhandled lora layer type")


 def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
@@ -273,9 +341,9 @@ def update_weights(
 ):
     if reinit:
         for adapter_name in target.lora_A:
-            target.reset_lora_parameters(adapter_name)
+            target.reset_lora_parameters(adapter_name, True)
         for adapter_name in target.lora_embedding_A:
-            target.reset_lora_parameters(adapter_name)
+            target.reset_lora_parameters(adapter_name, True)

     if isinstance(target, peft.tuners.lora.Linear4bit):
         # This could be faster, but the quantization of Linear4bit weights occurs
@@ -286,7 +354,9 @@ def update_weights(
         target.weight.data = new_weight.cpu()
         target.to(device)
     elif isinstance(target, peft.tuners.lora.Linear8bitLt):
-        target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
+        target.weight.data = (
+            bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
+        )
     else:
         target.weight.data = new_weight.to(device)

@@ -304,14 +374,17 @@ def merge_and_save(

     if not quantized:
         for module_name, target in modules.items():
-            update = target.get_delta_weight(target.active_adapter).detach()
+            active_adapter = target.active_adapter
+            if isinstance(active_adapter, list):
+                active_adapter = active_adapter[0]
+            update = target.get_delta_weight(active_adapter).detach()
             target.weight.data += update

             if reinit:
                 for adapter_name in target.lora_A:
-                    target.reset_lora_parameters(adapter_name)
+                    target.reset_lora_parameters(adapter_name, True)
                 for adapter_name in target.lora_embedding_A:
-                    target.reset_lora_parameters(adapter_name)
+                    target.reset_lora_parameters(adapter_name, True)
         return

     os.makedirs(model_dst, exist_ok=True)
@@ -363,6 +436,7 @@ def merge_and_save(
         LOG.info(f"saving tensors to {shard_fn}")
         st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})

+    barrier()
    del in_tensors
    del out_tensors
    torch.cuda.empty_cache()
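
Taken together, the scheduler changes above give each ReLoRA cycle three phases: a linear warmup after the optimizer reset, a flat plateau, and a linear anneal toward min_lr_scale just before the next merge. A small self-contained sketch of that per-cycle scale (the function mirrors the get_lr logic in the diff; the step values fed to it are arbitrary examples):

def relora_lr_scale(step, relora_steps, warmup_steps, anneal_steps, min_lr_scale=0.001):
    """Multiplier applied to the inner schedule's learning rate at a given global step."""
    if step < relora_steps:  # before the first restart the inner schedule runs untouched
        return 1.0
    progress = step % relora_steps  # position within the current ReLoRA cycle
    if progress < warmup_steps:  # ramp back up after the reset
        cycle_t = progress / warmup_steps
    elif progress > relora_steps - anneal_steps:  # ramp down ahead of the next merge
        cycle_t = (relora_steps - progress) / anneal_steps
    else:  # plateau in between
        cycle_t = 1.0
    return cycle_t * (1 - min_lr_scale) + min_lr_scale


# e.g. relora_steps=25, warmup_steps=5, anneal_steps=5 (the values used in the new e2e test)
for step in (25, 27, 30, 40, 46, 49):
    print(step, round(relora_lr_scale(step, 25, 5, 5), 3))
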
src/axolotl/prompt_strategies/instruct.py ADDED
@@ -0,0 +1,33 @@
+"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
+from typing import Any, Dict, Optional
+
+from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
+from axolotl.prompters import ShareGPTPrompterV2
+
+
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    conversation = (
+        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
+    )
+    strategy = InstructShareGPTPromptTokenizingStrategy(
+        # pylint: disable=duplicate-code
+        ShareGPTPrompterV2(
+            conversation=conversation,
+        ),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+    return strategy
+
+
+class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+    """
+    basic sharegpt strategy to grab conversations from the sample row
+    """
+
+    def get_conversation_thread(self, prompt):
+        return [
+            {"from": "human", "value": prompt["instruction"]},
+            {"from": "gpt", "value": prompt["output"]},
+        ]
src/axolotl/utils/chat_templates.py CHANGED
@@ -19,6 +19,7 @@ def chat_templates(user_choice: str):
     """

     templates = {
+        "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
         "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # I don't know what this one is called. Used by Mistral/Mixtral.
         "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
     }
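
The new "alpaca" entry is a Jinja2 chat template, so its output can be inspected outside the tokenizer. A hedged sketch of rendering it directly with jinja2 follows; the example messages and the "</s>" eos_token are made up for illustration, and in axolotl the string is normally selected via the chat_template config option rather than rendered by hand.

from jinja2 import Template

ALPACA_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}"
    "{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}"
    "{% endif %}{% endfor %}"
)

rendered = Template(ALPACA_TEMPLATE).render(
    messages=[
        {"role": "user", "content": "Name three primary colors."},
        {"role": "assistant", "content": "Red, yellow, and blue."},
    ],
    eos_token="</s>",
)
print(rendered)
# ### Instruction: Name three primary colors.
#
# ### Response: Red, yellow, and blue.</s>
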
src/axolotl/utils/config.py CHANGED
@@ -447,7 +447,11 @@ def validate_config(cfg):
             "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
         )

-    if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
+    if (
+        cfg.val_set_size == 0
+        and (cfg.eval_steps or cfg.evaluation_strategy)
+        and not cfg.test_datasets
+    ):
         raise ValueError(
             "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
         )
src/axolotl/utils/data.py CHANGED
@@ -140,7 +140,7 @@ def load_tokenized_prepared_datasets(
                 + "|".join(
                     sorted(
                         [
-                            f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
+                            f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
                             for d in cfg_datasets
                         ]
                     )
src/axolotl/utils/models.py CHANGED
@@ -8,7 +8,13 @@ import addict
 import bitsandbytes as bnb
 import torch
 import transformers
-from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training
+from peft import (
+    LoftQConfig,
+    PeftConfig,
+    PeftModel,
+    PeftModelForCausalLM,
+    prepare_model_for_kbit_training,
+)
 from peft.tuners.lora import QuantLinear
 from transformers import (  # noqa: F401
     AddedToken,
@@ -628,6 +634,9 @@ def load_model(
         LOG.exception(err)
         raise err

+    if isinstance(model, (PeftModel, PeftModelForCausalLM)):
+        model = model.merge_and_unload()
+
     embeddings_len = (
         math.ceil(len(tokenizer) / 32) * 32
         if cfg.resize_token_embeddings_to_32x
@@ -782,7 +791,7 @@ def load_adapter(model, cfg, adapter, inference=False):

 def load_llama_adapter(model, cfg):
     # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
-    from peft import AdaptionPromptConfig, PeftModel, get_peft_model
+    from peft import AdaptionPromptConfig, get_peft_model

     peft_config = AdaptionPromptConfig(
         adapter_layers=cfg.peft_adapter.layers,  # layers (L)
@@ -828,7 +837,7 @@ def find_all_linear_names(model):
 def load_lora(model, cfg, inference=False, config_only=False):
     # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]

-    from peft import LoraConfig, PeftModel, get_peft_model
+    from peft import LoraConfig, get_peft_model

     lora_target_modules = list(cfg.lora_target_modules or [])

tests/e2e/patched/test_mistral_samplepack.py CHANGED
@@ -7,8 +7,6 @@ import os
 import unittest
 from pathlib import Path

-from transformers.utils import is_torch_bf16_gpu_available
-
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
@@ -63,6 +61,7 @@ class TestMistral(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
+                "bf16": "auto",
             }
         )
         normalize_config(cfg)
@@ -103,12 +102,9 @@ class TestMistral(unittest.TestCase):
                 "max_steps": 20,
                 "save_steps": 10,
                 "eval_steps": 10,
+                "bf16": "auto",
             }
         )
-        if is_torch_bf16_gpu_available():
-            cfg.bf16 = True
-        else:
-            cfg.fp16 = True
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
tests/e2e/test_relora_llama.py ADDED
@@ -0,0 +1,68 @@
+"""
+E2E tests for relora llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestReLoraLlama(unittest.TestCase):
+    """
+    Test case for Llama models using LoRA
+    """
+
+    @with_temp_dir
+    def test_relora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 32,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_modules": ["q_proj", "v_proj"],
+                "relora_steps": 25,
+                "relora_warmup_steps": 5,
+                "relora_anneal_steps": 5,
+                "relora_cpu_offload": True,
+                "val_set_size": 0.0,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "warmup_steps": 15,
+                "num_epochs": 2,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()