winglian commited on
Commit
ffd1043
1 Parent(s): 3369c4d

attempt to find linear modules for qlora

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/models.py +24 -2
src/axolotl/utils/models.py CHANGED
@@ -4,6 +4,7 @@ import os
4
  from pathlib import Path
5
  from typing import Optional, Tuple, TYPE_CHECKING
6
 
 
7
  import torch
8
  import transformers
9
  from torch import nn
@@ -334,6 +335,24 @@ def load_llama_adapter(model, cfg):
334
  return model, peft_config
335
 
336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  def load_lora(model, cfg):
338
  # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
339
 
@@ -343,12 +362,15 @@ def load_lora(model, cfg):
343
  PeftModel,
344
  )
345
 
346
- lora_config = None
 
 
 
347
 
348
  lora_config = LoraConfig(
349
  r=cfg.lora_r,
350
  lora_alpha=cfg.lora_alpha,
351
- target_modules=cfg.lora_target_modules,
352
  lora_dropout=cfg.lora_dropout,
353
  fan_in_fan_out=cfg.lora_fan_in_fan_out,
354
  modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
 
4
  from pathlib import Path
5
  from typing import Optional, Tuple, TYPE_CHECKING
6
 
7
+ import bitsandbytes as bnb
8
  import torch
9
  import transformers
10
  from torch import nn
 
335
  return model, peft_config
336
 
337
 
338
+ def find_all_linear_names(bits, model):
339
+ cls = (
340
+ bnb.nn.Linear4bit
341
+ if bits == 4
342
+ else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
343
+ )
344
+ lora_module_names = set()
345
+ for name, module in model.named_modules():
346
+ if isinstance(module, cls):
347
+ names = name.split(".")
348
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
349
+
350
+ if "lm_head" in lora_module_names: # needed for 16-bit
351
+ lora_module_names.remove("lm_head")
352
+
353
+ return list(lora_module_names)
354
+
355
+
356
  def load_lora(model, cfg):
357
  # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
358
 
 
362
  PeftModel,
363
  )
364
 
365
+ bits = 4 if cfg.load_in_4bits else 8 if cfg.load_in_8bits else None
366
+ linear_names = find_all_linear_names(bits, model)
367
+ logging.info(f"found linear modules: {repr(linear_names)}")
368
+ lora_target_modules = cfg.lora_target_modules + linear_names
369
 
370
  lora_config = LoraConfig(
371
  r=cfg.lora_r,
372
  lora_alpha=cfg.lora_alpha,
373
+ target_modules=lora_target_modules,
374
  lora_dropout=cfg.lora_dropout,
375
  fan_in_fan_out=cfg.lora_fan_in_fan_out,
376
  modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,