winglian committed
Commit 3e3229e
Parent: 1d21aa6

fix for qwen w lora (#906)

Files changed (1):
  src/axolotl/utils/models.py +10 -3
src/axolotl/utils/models.py CHANGED
@@ -412,15 +412,22 @@ def load_model(
                 module.to(torch.float32)
 
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
+    skip_prepare_model_for_kbit_training = False
+
+    if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
+        # Qwen doesn't play nicely with LoRA if this is enabled
+        skip_prepare_model_for_kbit_training = True
+
     if (cfg.adapter == "lora" and load_in_8bit) or (
         cfg.adapter == "qlora" and cfg.load_in_4bit
     ):
         LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
         if cfg.gradient_checkpointing:
             model.gradient_checkpointing_enable()
-        model = prepare_model_for_kbit_training(
-            model, use_gradient_checkpointing=cfg.gradient_checkpointing
-        )
+        if not skip_prepare_model_for_kbit_training:
+            model = prepare_model_for_kbit_training(
+                model, use_gradient_checkpointing=cfg.gradient_checkpointing
+            )
         needs_fa2_dtype = True
 
     # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
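
For context, a minimal standalone sketch of the control flow this patch produces. The helper name maybe_prepare_for_kbit is hypothetical, and the cfg attributes (model_config_type, adapter, load_in_4bit, gradient_checkpointing) simply mirror axolotl's config object; prepare_model_for_kbit_training is peft's real utility, which freezes the base weights, upcasts norm layers to fp32, and wires up input gradients so k-bit training works with gradient checkpointing.

# Hypothetical sketch of the patched flow; not axolotl's actual function.
from peft import prepare_model_for_kbit_training

def maybe_prepare_for_kbit(model, cfg, load_in_8bit):
    # Per the commit comment, Qwen doesn't play nicely with LoRA when
    # prepare_model_for_kbit_training is applied, so it is skipped there.
    skip_prepare = cfg.model_config_type == "qwen" and cfg.adapter == "lora"

    if (cfg.adapter == "lora" and load_in_8bit) or (
        cfg.adapter == "qlora" and cfg.load_in_4bit
    ):
        if cfg.gradient_checkpointing:
            model.gradient_checkpointing_enable()
        if not skip_prepare:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=cfg.gradient_checkpointing
            )
    return model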