Fix(model): Linear detected and added to target module with rope linear (#738)
* Fix(model): Linear detected and added to target module with rope linear
* fix: exclude layer instead
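
Context for the fix: `find_all_linear_names` collects LoRA target modules, and its `"Linear" in module.__class__.__name__` substring check also matches Llama's linear RoPE scaling module, `LlamaLinearScalingRotaryEmbedding`, which is not a linear layer. A minimal illustration of the misfire (plain Python, no axolotl imports needed):

```python
# The substring check used to select candidate LoRA target modules
# also matches the RoPE scaling class name, which is not a linear layer.
candidates = ["Linear4bit", "LlamaLinearScalingRotaryEmbedding", "Embedding"]
matched = [name for name in candidates if "Linear" in name]
print(matched)  # ['Linear4bit', 'LlamaLinearScalingRotaryEmbedding']
```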
src/axolotl/utils/models.py CHANGED
```diff
@@ -507,7 +507,11 @@ def find_all_linear_names(model):
     cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
     lora_module_names = set()
     for name, module in model.named_modules():
-        if isinstance(module, cls) or "Linear" in module.__class__.__name__:
+        if (
+            isinstance(module, cls)
+            or "Linear" in module.__class__.__name__
+            and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
+        ):
             names = name.split(".")
             lora_module_names.add(names[0] if len(names) == 1 else names[-1])
 
```
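
Note on the condition: Python's `and` binds tighter than `or`, so the new predicate reads as `isinstance(...) or ("Linear" in name and name not in excluded)`; the exclusion only gates the name-based branch, and real linear classes still match via `isinstance`. A self-contained sketch of the resulting selection logic (the `is_lora_target` helper and the stub class below are illustrative stand-ins, not axolotl code):

```python
import torch.nn as nn

EXCLUDED = ("LlamaLinearScalingRotaryEmbedding",)

def is_lora_target(module: nn.Module, cls=(nn.Linear,)) -> bool:
    """Mirror of the patched condition: known linear classes always match;
    name-based matches are kept only if the class name is not excluded."""
    name = module.__class__.__name__
    return isinstance(module, cls) or ("Linear" in name and name not in EXCLUDED)

# Hypothetical stand-in for the RoPE scaling module whose class name
# contains "Linear" but which is not a linear layer.
class LlamaLinearScalingRotaryEmbedding(nn.Module):
    pass

print(is_lora_target(nn.Linear(4, 4)))                      # True: isinstance branch
print(is_lora_target(LlamaLinearScalingRotaryEmbedding()))  # False: excluded by name
```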