winglian committed
Commit 0b7ba57
1 Parent(s): 71bd062

fix types w lora (#478)

Files changed (1):
  1. src/axolotl/utils/models.py +18 -17
src/axolotl/utils/models.py CHANGED
@@ -11,7 +11,6 @@ import bitsandbytes as bnb
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
-from peft.tuners.lora import LoraLayer
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
@@ -348,6 +347,14 @@ def load_model(
     if model.device.type == "cuda":
         log_gpu_memory_usage(LOG, "after model load", model.device)
 
+    # make sure these are fp32 per Ramesh et al. (2021)
+    for name, module in model.named_modules():
+        if "norm" in name:
+            module.to(torch.float32)
+        if "lm_head" in name or "embed_tokens" in name:
+            if hasattr(module, "weight"):
+                module.to(torch.float32)
+
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -357,6 +364,16 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
+    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
+    # convert them back to fp16/bf16 for flash-attn compatibility.
+    if cfg.flash_attention and cfg.is_llama_derived_model:
+        for name, module in model.named_modules():
+            if "norm" in name:
+                module.to(cfg.torch_dtype)
+            if "lm_head" in name or "embed_tokens" in name:
+                if hasattr(module, "weight"):
+                    module.to(cfg.torch_dtype)
+
     model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
     if cfg.ddp and not load_in_8bit:
@@ -500,22 +517,6 @@ def load_lora(model, cfg):
     else:
         model = get_peft_model(model, lora_config)
 
-    for name, module in model.named_modules():
-        if isinstance(module, LoraLayer):
-            module = module.to(cfg.torch_dtype)
-        if "norm" in name:
-            module = module.to(torch.float32)
-        if "lm_head" in name or "embed_tokens" in name:
-            if hasattr(module, "weight"):
-                module = module.to(cfg.torch_dtype)
-
-    # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to
-    # convert them back to fp16/bf16 for flash-attn compatibility.
-    if cfg.flash_attention and cfg.is_llama_derived_model:
-        for name, module in model.named_modules():
-            if "norm" in name:
-                module = module.to(cfg.torch_dtype)
-
     model.print_trainable_parameters()
 
     return model, lora_config
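For readers who want to exercise the dtype-handling pattern in isolation, the sketch below is a minimal stand-in: TinyLM, train_dtype, and the toy sizes are illustrative only and are not part of the axolotl codebase. The first pass mirrors the fp32 cast added after model load; the second pass mirrors the flash-attention cast back to the training dtype added after prepare_model_for_kbit_training.

import torch
import torch.nn as nn

class TinyLM(nn.Module):
    # hypothetical toy module; attribute names chosen so the name matching
    # below ("norm", "embed_tokens", "lm_head") mirrors the diff above
    def __init__(self, vocab=32, dim=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.norm = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab)

model = TinyLM().to(torch.bfloat16)
train_dtype = torch.bfloat16  # illustrative; axolotl reads this from cfg.torch_dtype

# step 1: right after model load, keep norm / embedding / head weights in fp32
for name, module in model.named_modules():
    if "norm" in name:
        module.to(torch.float32)  # nn.Module.to() casts the module in place
    if "lm_head" in name or "embed_tokens" in name:
        if hasattr(module, "weight"):
            module.to(torch.float32)

# step 2: after k-bit training prep, with flash attention enabled, cast the
# same modules back to the training dtype so tensors entering the attention
# kernel have matching dtypes
for name, module in model.named_modules():
    if "norm" in name:
        module.to(train_dtype)
    if "lm_head" in name or "embed_tokens" in name:
        if hasattr(module, "weight"):
            module.to(train_dtype)

print({name: p.dtype for name, p in model.named_parameters()})

Note that nn.Module.to() casts a module's parameters in place and returns the same object, which is why the bare module.to(...) calls in the new hunks have the same effect as the module = module.to(...) reassignments in the block removed from load_lora.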