relora: magnitude pruning of the optimizer (#1245)
* magnitude pruning of the optimizer
* add alpaca chat template and fix relora patch
* fix handling of lora adapter for relora
* fix merge and save call
* fixes for 8-bit lora merge
* save intermediate checkpoint adapters
* auto merge
* fix eval check
* handle relora annealing
* fix anneal step logic
* chore: lint
* misc fix
* fix types
* Update tests/e2e/test_relora_llama.py
* check for safetensors saved from relora
- src/axolotl/core/trainer_builder.py +9 -0
- src/axolotl/monkeypatch/relora.py +97 -23
- src/axolotl/prompt_strategies/instruct.py +33 -0
- src/axolotl/utils/chat_templates.py +1 -0
- src/axolotl/utils/config.py +5 -1
- src/axolotl/utils/data.py +1 -1
- src/axolotl/utils/models.py +12 -3
- tests/e2e/patched/test_mistral_samplepack.py +2 -6
- tests/e2e/test_relora_llama.py +68 -0
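The idea behind the headline change: instead of zeroing the Adam moments of the LoRA parameters at every ReLoRA reset, the optimizer state is magnitude-pruned, so only the smallest 90% of entries are set to zero and some momentum survives the restart. A minimal, self-contained sketch of that pruning step (toy tensors only, not the trainer wiring shown in the diffs below):

import torch

@torch.no_grad()
def magnitude_prune_(tensor: torch.Tensor, prune_ratio: float = 0.9) -> None:
    """Zero the smallest `prune_ratio` fraction of entries, in place."""
    magnitude = tensor.abs()
    threshold = torch.quantile(magnitude.flatten().float(), prune_ratio).to(tensor.dtype)
    tensor.mul_((magnitude > threshold).to(tensor.dtype))

# toy stand-ins for Adam's exp_avg / exp_avg_sq buffers of a single LoRA parameter
state = {"exp_avg": torch.randn(64, 16), "exp_avg_sq": torch.rand(64, 16)}
for key in ("exp_avg", "exp_avg_sq"):
    magnitude_prune_(state[key])
    print(key, "fraction zeroed:", (state[key] == 0).float().mean().item())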
src/axolotl/core/trainer_builder.py
CHANGED
@@ -126,6 +126,10 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=None,
         metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
     )
+    relora_anneal_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
+    )
     bench_split: Optional[str] = field(
         default="eval", metadata={"help": "The benchmark split to run on"}
     )
@@ -478,10 +482,14 @@ class ReLoRATrainer(AxolotlTrainer):
            warmup_steps = (
                self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
            )
+           anneal_steps = (
+               self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
+           )
            self.lr_scheduler = ReLoRAScheduler(
                optimizer,
                lr_scheduler,
                self.args.relora_steps,
+               anneal_steps,
                warmup_steps,
            )
        else:
@@ -893,6 +901,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
        ] = self.cfg.micro_batch_size
        training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
        training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
+       training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
        training_arguments_kwargs = self.hook_pre_create_training_args(
            training_arguments_kwargs
        )
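The new relora_anneal_steps option flows from the config through AxolotlTrainingArguments into ReLoRAScheduler (defaulting to 1 when unset). Within each relora_steps-long cycle the learning-rate scale now warms up over relora_warmup_steps, holds at 1, and anneals back down over the final relora_anneal_steps before the next reset. A small sketch of that per-cycle scale, mirroring the scheduler logic patched in relora.py below (the default values here are only illustrative):

def relora_scale(step, relora_steps=25, warmup_steps=5, anneal_steps=5, min_lr_scale=0.001):
    """Multiplier applied on top of the inner LR schedule at a given global step."""
    if step < relora_steps:  # the first cycle runs on the inner schedule untouched
        return 1.0
    progress = step % relora_steps
    if progress < warmup_steps:
        cycle_t = min(1.0, progress / warmup_steps)  # ramp back up after a reset
    elif progress > relora_steps - anneal_steps:
        cycle_t = min(1.0, (relora_steps - progress) / anneal_steps)  # decay before the next reset
    else:
        cycle_t = 1.0
    return cycle_t * (1 - min_lr_scale) + min_lr_scale

print([round(relora_scale(s), 2) for s in range(25, 50)])  # one full warmup/hold/anneal cycle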
src/axolotl/monkeypatch/relora.py
CHANGED
@@ -4,14 +4,16 @@ import json
 import logging
 import os.path
 import shutil
+from functools import partial
 from pathlib import Path
-from typing import Dict, List, Sequence
+from typing import Dict, List, Sequence, Union

 import bitsandbytes as bnb
 import peft
 import safetensors.torch as st
 import torch
 from huggingface_hub import snapshot_download
+from torch.distributed.optim import ZeroRedundancyOptimizer
 from torch.optim.lr_scheduler import LRScheduler
 from torch.optim.optimizer import Optimizer
 from transformers import (
@@ -23,23 +25,50 @@ from transformers import (
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import is_main_process
+from axolotl.utils.distributed import barrier, is_main_process

 LOG = logging.getLogger("axolotl.relora")


+@torch.no_grad()
+def magnitude_pruning_(tensor, prune_ratio):
+    tensor_magnitude = torch.abs(tensor)
+    threshold = torch.quantile(
+        tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
+    ).to(dtype=tensor.dtype)
+
+    mask = tensor_magnitude > threshold
+    tensor.mul_(mask.to(dtype=tensor.dtype))
+
+
+def reset_optimizer(
+    optimizer: torch.optim.Optimizer,
+    *,
+    reset_params: list[str],  # where str is the key to a torch.nn.Parameter
+    optimizer_state_keys: list[str],
+):
+    pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
+    n_zeros = 0
+    n_total = 0
+
+    optimizer_state = optimizer.state
+    if isinstance(optimizer, ZeroRedundancyOptimizer):
+        optimizer_state = optimizer.optim.state
+
+    for param in reset_params:
+        param_state = optimizer_state[param]
+        if len(param_state) == 0:  # no state for this param, happens for ZeRo optimizer
+            continue
+        for key in optimizer_state_keys:
+            pruning_fn(
+                param_state[key]
+            )  # pruning fn has to be inplace to keep the same keys in the dict
+            n_total += param_state[key].numel()
+            n_zeros += torch.sum(param_state[key] == 0).item()
+
+    _zeroed = n_zeros / (1e-7 + n_total) * 100
+    LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
+    LOG.info(f"absolute n of optimizer states zeroed: {n_zeros}")


 class ReLoRACallback(TrainerCallback):
@@ -97,6 +126,25 @@ class ReLoRACallback(TrainerCallback):
                "relora",
            )

+           if "adam" in args.optim.lower():
+               optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
+           else:
+               raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
+
+           lora_params = [
+               n
+               for n, p in model.named_parameters()
+               if p.requires_grad and "lora_" in n
+           ]
+
+           model.save_pretrained(
+               os.path.join(
+                   args.output_dir,
+                   f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+                   "adapter",
+               ),
+               safe_serialization=True,
+           )
            with torch.no_grad():
                merge_and_save(
                    model,
@@ -107,7 +155,11 @@ class ReLoRACallback(TrainerCallback):
                    actually_save=is_main_process(),
                    cpu_offload=self.cpu_offload,
                )
+               reset_optimizer(
+                   optimizer,
+                   reset_params=lora_params,
+                   optimizer_state_keys=optimizer_state_keys,
+               )

        if self.quantized:
            self.last_full_model = checkpoint_folder
@@ -197,11 +249,13 @@ class ReLoRAScheduler(LRScheduler):
        inner_schedule: LRScheduler,
        relora_steps: int,
        warmup_steps: int,
+       anneal_steps: int = 1,
        min_lr_scale: float = 0.001,
    ) -> None:
        self.inner_schedule = inner_schedule
        self.relora_steps = relora_steps
        self.warmup_steps = warmup_steps
+       self.anneal_steps = anneal_steps
        self.min_lr_scale = min_lr_scale
        super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)

@@ -210,10 +264,20 @@ class ReLoRAScheduler(LRScheduler):
        original = self.inner_schedule.get_lr()
        step = self.last_epoch
+
        if step < self.relora_steps:
            scale = 1
        else:
+           per_relora_progress = step % self.relora_steps
+           if per_relora_progress < self.warmup_steps:
+               cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
+           elif per_relora_progress > (self.relora_steps - self.anneal_steps):
+               cycle_t = min(
+                   1.0,
+                   (self.relora_steps - per_relora_progress) / self.anneal_steps,
+               )
+           else:
+               cycle_t = 1
            scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale

        if isinstance(original, Sequence):
@@ -238,7 +302,11 @@ def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:

 def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
     if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
-        adapter = layer.active_adapter
+        adapter: Union[List[str], str] = layer.active_adapter
+        if isinstance(adapter, list):
+            if len(adapter) > 1:
+                raise ValueError("unhandled relora for multiple adapters")
+            adapter = adapter[0]
         return (
             peft.utils.transpose(
                 layer.lora_B[adapter].weight.detach().to(device)
@@ -248,7 +316,7 @@ def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor
             * layer.scaling[adapter]
         )

+    raise ValueError("unhandled lora layer type")


 def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
@@ -273,9 +341,9 @@ def update_weights(
 ):
    if reinit:
        for adapter_name in target.lora_A:
-           target.reset_lora_parameters(adapter_name)
+           target.reset_lora_parameters(adapter_name, True)
        for adapter_name in target.lora_embedding_A:
-           target.reset_lora_parameters(adapter_name)
+           target.reset_lora_parameters(adapter_name, True)

    if isinstance(target, peft.tuners.lora.Linear4bit):
        # This could be faster, but the quantization of Linear4bit weights occurs
@@ -286,7 +354,9 @@ def update_weights(
        target.weight.data = new_weight.cpu()
        target.to(device)
    elif isinstance(target, peft.tuners.lora.Linear8bitLt):
+       target.weight.data = (
+           bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
+       )
    else:
        target.weight.data = new_weight.to(device)

@@ -304,14 +374,17 @@ def merge_and_save(
    if not quantized:
        for module_name, target in modules.items():
+           active_adapter = target.active_adapter
+           if isinstance(active_adapter, list):
+               active_adapter = active_adapter[0]
+           update = target.get_delta_weight(active_adapter).detach()
            target.weight.data += update

            if reinit:
                for adapter_name in target.lora_A:
-                   target.reset_lora_parameters(adapter_name)
+                   target.reset_lora_parameters(adapter_name, True)
                for adapter_name in target.lora_embedding_A:
-                   target.reset_lora_parameters(adapter_name)
+                   target.reset_lora_parameters(adapter_name, True)
        return

    os.makedirs(model_dst, exist_ok=True)
@@ -363,6 +436,7 @@ def merge_and_save(
        LOG.info(f"saving tensors to {shard_fn}")
        st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})

+   barrier()
    del in_tensors
    del out_tensors
    torch.cuda.empty_cache()
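For orientation, what merge_and_save does for each non-quantized LoRA layer at a reset boils down to folding the adapter's delta into the frozen weight and then reinitializing the adapter. A rough sketch with assumed shapes, using plain tensors as stand-ins for a peft LoraLayer:

import torch

d_out, d_in, r, lora_alpha = 64, 64, 32, 16
weight = torch.randn(d_out, d_in)        # frozen base weight
lora_A = torch.randn(r, d_in) * 0.01     # trained adapter factors
lora_B = torch.randn(d_out, r) * 0.01
scaling = lora_alpha / r

delta = (lora_B @ lora_A) * scaling      # roughly what lora_delta_weight / get_delta_weight returns
weight += delta                          # merge into the base weight

lora_A = torch.randn(r, d_in) * 0.01     # reset_lora_parameters: restart the adapter
lora_B = torch.zeros(d_out, r)           # B back to zero so the merged model is unchanged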
src/axolotl/prompt_strategies/instruct.py
ADDED
@@ -0,0 +1,33 @@
+"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
+from typing import Any, Dict, Optional
+
+from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
+from axolotl.prompters import ShareGPTPrompterV2
+
+
+def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    conversation = (
+        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
+    )
+    strategy = InstructShareGPTPromptTokenizingStrategy(
+        # pylint: disable=duplicate-code
+        ShareGPTPrompterV2(
+            conversation=conversation,
+        ),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+    return strategy
+
+
+class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+    """
+    basic sharegpt strategy to grab conversations from the sample row
+    """
+
+    def get_conversation_thread(self, prompt):
+        return [
+            {"from": "human", "value": prompt["instruction"]},
+            {"from": "gpt", "value": prompt["output"]},
+        ]
src/axolotl/utils/chat_templates.py
CHANGED
@@ -19,6 +19,7 @@ def chat_templates(user_choice: str):
     """

     templates = {
+        "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
         "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # I don't know what this one is called. Used by Mistral/Mixtral.
         "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
     }
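The added "alpaca" entry is a Jinja chat template that renders user turns as "### Instruction:" blocks and assistant turns as "### Response:" blocks terminated by the EOS token. A quick usage sketch, assuming chat_templates() returns the selected template string and using an arbitrary example tokenizer:

from transformers import AutoTokenizer
from axolotl.utils.chat_templates import chat_templates

tokenizer = AutoTokenizer.from_pretrained("JackFram/llama-68m")  # any tokenizer with an eos_token
tokenizer.chat_template = chat_templates("alpaca")
text = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Name three primary colors."},
        {"role": "assistant", "content": "Red, yellow, and blue."},
    ],
    tokenize=False,
)
# expected roughly: "### Instruction: Name three primary colors.\n\n### Response: Red, yellow, and blue.</s>"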
src/axolotl/utils/config.py
CHANGED
@@ -447,7 +447,11 @@ def validate_config(cfg):
            "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
        )

+   if (
+       cfg.val_set_size == 0
+       and (cfg.eval_steps or cfg.evaluation_strategy)
+       and not cfg.test_datasets
+   ):
        raise ValueError(
            "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
        )
src/axolotl/utils/data.py
CHANGED
@@ -140,7 +140,7 @@ def load_tokenized_prepared_datasets(
        + "|".join(
            sorted(
                [
-                   f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
+                   f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
                    for d in cfg_datasets
                ]
            )
src/axolotl/utils/models.py
CHANGED
@@ -8,7 +8,13 @@ import addict
 import bitsandbytes as bnb
 import torch
 import transformers
+from peft import (
+    LoftQConfig,
+    PeftConfig,
+    PeftModel,
+    PeftModelForCausalLM,
+    prepare_model_for_kbit_training,
+)
 from peft.tuners.lora import QuantLinear
 from transformers import (  # noqa: F401
     AddedToken,
@@ -628,6 +634,9 @@ def load_model(
        LOG.exception(err)
        raise err

+   if isinstance(model, (PeftModel, PeftModelForCausalLM)):
+       model = model.merge_and_unload()
+
    embeddings_len = (
        math.ceil(len(tokenizer) / 32) * 32
        if cfg.resize_token_embeddings_to_32x
@@ -782,7 +791,7 @@ def load_adapter(model, cfg, adapter, inference=False):

 def load_llama_adapter(model, cfg):
     # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+    from peft import AdaptionPromptConfig, get_peft_model

     peft_config = AdaptionPromptConfig(
         adapter_layers=cfg.peft_adapter.layers,  # layers (L)
@@ -828,7 +837,7 @@ def find_all_linear_names(model):
 def load_lora(model, cfg, inference=False, config_only=False):
     # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]

+    from peft import LoraConfig, get_peft_model

     lora_target_modules = list(cfg.lora_target_modules or [])
tests/e2e/patched/test_mistral_samplepack.py
CHANGED
@@ -7,8 +7,6 @@ import os
 import unittest
 from pathlib import Path

-from transformers.utils import is_torch_bf16_gpu_available
-
 from axolotl.cli import load_datasets
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
@@ -63,6 +61,7 @@ class TestMistral(unittest.TestCase):
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
+               "bf16": "auto",
            }
        )
        normalize_config(cfg)
@@ -103,12 +102,9 @@ class TestMistral(unittest.TestCase):
                "max_steps": 20,
                "save_steps": 10,
                "eval_steps": 10,
+               "bf16": "auto",
            }
        )
-       if is_torch_bf16_gpu_available():
-           cfg.bf16 = True
-       else:
-           cfg.fp16 = True
        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
tests/e2e/test_relora_llama.py
ADDED
@@ -0,0 +1,68 @@
+"""
+E2E tests for relora llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestReLoraLlama(unittest.TestCase):
+    """
+    Test case for Llama models using LoRA
+    """
+
+    @with_temp_dir
+    def test_relora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 32,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_modules": ["q_proj", "v_proj"],
+                "relora_steps": 25,
+                "relora_warmup_steps": 5,
+                "relora_anneal_steps": 5,
+                "relora_cpu_offload": True,
+                "val_set_size": 0.0,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "warmup_steps": 15,
+                "num_epochs": 2,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()