attempt to find linear modules for qlora
Browse files- src/axolotl/utils/models.py +24 -2
src/axolotl/utils/models.py
CHANGED
@@ -4,6 +4,7 @@ import os
|
|
4 |
from pathlib import Path
|
5 |
from typing import Optional, Tuple, TYPE_CHECKING
|
6 |
|
|
|
7 |
import torch
|
8 |
import transformers
|
9 |
from torch import nn
|
@@ -334,6 +335,24 @@ def load_llama_adapter(model, cfg):
|
|
334 |
return model, peft_config
|
335 |
|
336 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
337 |
def load_lora(model, cfg):
|
338 |
# type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
339 |
|
@@ -343,12 +362,15 @@ def load_lora(model, cfg):
|
|
343 |
PeftModel,
|
344 |
)
|
345 |
|
346 |
-
|
|
|
|
|
|
|
347 |
|
348 |
lora_config = LoraConfig(
|
349 |
r=cfg.lora_r,
|
350 |
lora_alpha=cfg.lora_alpha,
|
351 |
-
target_modules=
|
352 |
lora_dropout=cfg.lora_dropout,
|
353 |
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
354 |
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
|
|
4 |
from pathlib import Path
|
5 |
from typing import Optional, Tuple, TYPE_CHECKING
|
6 |
|
7 |
+
import bitsandbytes as bnb
|
8 |
import torch
|
9 |
import transformers
|
10 |
from torch import nn
|
|
|
335 |
return model, peft_config
|
336 |
|
337 |
|
338 |
+
def find_all_linear_names(bits, model):
|
339 |
+
cls = (
|
340 |
+
bnb.nn.Linear4bit
|
341 |
+
if bits == 4
|
342 |
+
else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
|
343 |
+
)
|
344 |
+
lora_module_names = set()
|
345 |
+
for name, module in model.named_modules():
|
346 |
+
if isinstance(module, cls):
|
347 |
+
names = name.split(".")
|
348 |
+
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
349 |
+
|
350 |
+
if "lm_head" in lora_module_names: # needed for 16-bit
|
351 |
+
lora_module_names.remove("lm_head")
|
352 |
+
|
353 |
+
return list(lora_module_names)
|
354 |
+
|
355 |
+
|
356 |
def load_lora(model, cfg):
|
357 |
# type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
358 |
|
|
|
362 |
PeftModel,
|
363 |
)
|
364 |
|
365 |
+
bits = 4 if cfg.load_in_4bits else 8 if cfg.load_in_8bits else None
|
366 |
+
linear_names = find_all_linear_names(bits, model)
|
367 |
+
logging.info(f"found linear modules: {repr(linear_names)}")
|
368 |
+
lora_target_modules = cfg.lora_target_modules + linear_names
|
369 |
|
370 |
lora_config = LoraConfig(
|
371 |
r=cfg.lora_r,
|
372 |
lora_alpha=cfg.lora_alpha,
|
373 |
+
target_modules=lora_target_modules,
|
374 |
lora_dropout=cfg.lora_dropout,
|
375 |
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
376 |
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|