fix types w lora (#478)
src/axolotl/utils/models.py (changed: +18 -17)
@@ -11,7 +11,6 @@ import bitsandbytes as bnb
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
-from peft.tuners.lora import LoraLayer
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
@@ -348,6 +347,14 @@ def load_model(
     if model.device.type == "cuda":
         log_gpu_memory_usage(LOG, "after model load", model.device)
 
+    # make sure these are fp32 per Ramesh et al. (2021)
+    for name, module in model.named_modules():
+        if "norm" in name:
+            module.to(torch.float32)
+        if "lm_head" in name or "embed_tokens" in name:
+            if hasattr(module, "weight"):
+                module.to(torch.float32)
+
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -357,6 +364,16 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
+    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(cfg.torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(cfg.torch_dtype)
+
     model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
     if cfg.ddp and not load_in_8bit:
@@ -500,22 +517,6 @@ def load_lora(model, cfg):
     else:
         model = get_peft_model(model, lora_config)
 
-    for name, module in model.named_modules():
-        if isinstance(module, LoraLayer):
-            module = module.to(cfg.torch_dtype)
-        if "norm" in name:
-            module = module.to(torch.float32)
-        if "lm_head" in name or "embed_tokens" in name:
-            if hasattr(module, "weight"):
-                module = module.to(cfg.torch_dtype)
-
-    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
-    # convert them back to fp16/bf16 for flash-attn compatibility.
-    if cfg.flash_attention and cfg.is_llama_derived_model:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module = module.to(cfg.torch_dtype)
-
     model.print_trainable_parameters()
 
     return model, lora_config
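Net effect of the patch: the dtype fixups move out of load_lora and into load_model, so norm, lm_head, and embed_tokens weights are cast to fp32 before the k-bit/LoRA preparation, then cast back to cfg.torch_dtype when flash attention is enabled on a Llama-derived model. Note that torch.nn.Module.to() converts a module in place and returns it, so the new bare module.to(...) calls and the deleted module = module.to(...) pattern behave the same. The sketch below is a hypothetical helper (not part of this PR) for sanity-checking the resulting dtypes after load_model returns; the name summarize_module_dtypes and the role grouping are assumptions for illustration.

import torch


def summarize_module_dtypes(model):
    """Group weight dtypes by rough module role to sanity-check the casts above."""
    summary = {}
    for name, module in model.named_modules():
        weight = getattr(module, "weight", None)
        if weight is None:
            continue
        if "norm" in name:
            role = "norm"
        elif "lm_head" in name or "embed_tokens" in name:
            role = "embed/head"
        else:
            role = "other"
        # Collect the set of dtypes seen for each role; norms should be fp32
        # before k-bit prep, and half precision again when flash attention is
        # enabled for a Llama-derived model.
        summary.setdefault(role, set()).add(weight.dtype)
    return summary


# Usage (however load_model is invoked in your setup):
# model, lora_config = load_model(...)
# print(summarize_module_dtypes(model))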