Nanobit commited on
Commit
d5f8589
·
unverified ·
1 Parent(s): 03e5907

chore(callback): Remove old peft saving code (#510)

Browse files
src/axolotl/utils/callbacks.py CHANGED
@@ -43,29 +43,6 @@ LOG = logging.getLogger("axolotl.callbacks")
43
  IGNORE_INDEX = -100
44
 
45
 
46
- class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
47
- """Callback to save the PEFT adapter"""
48
-
49
- def on_save(
50
- self,
51
- args: TrainingArguments,
52
- state: TrainerState,
53
- control: TrainerControl,
54
- **kwargs,
55
- ):
56
- checkpoint_folder = os.path.join(
57
- args.output_dir,
58
- f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
59
- )
60
-
61
- peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
62
- kwargs["model"].save_pretrained(
63
- peft_model_path, save_safetensors=args.save_safetensors
64
- )
65
-
66
- return control
67
-
68
-
69
  class EvalFirstStepCallback(
70
  TrainerCallback
71
  ): # pylint: disable=too-few-public-methods disable=unused-argument
 
43
  IGNORE_INDEX = -100
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  class EvalFirstStepCallback(
47
  TrainerCallback
48
  ): # pylint: disable=too-few-public-methods disable=unused-argument
src/axolotl/utils/trainer.py CHANGED
@@ -31,7 +31,6 @@ from axolotl.utils.callbacks import (
31
  EvalFirstStepCallback,
32
  GPUStatsCallback,
33
  SaveBetterTransformerModelCallback,
34
- SavePeftModelCallback,
35
  bench_eval_callback_factory,
36
  log_prediction_callback_factory,
37
  )
@@ -711,12 +710,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
711
  if cfg.relora_steps:
712
  callbacks.append(ReLoRACallback(cfg))
713
 
714
- if cfg.local_rank == 0 and cfg.adapter in [
715
- "lora",
716
- "qlora",
717
- ]: # only save in rank 0
718
- callbacks.append(SavePeftModelCallback)
719
-
720
  if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
721
  callbacks.append(SaveBetterTransformerModelCallback)
722
 
 
31
  EvalFirstStepCallback,
32
  GPUStatsCallback,
33
  SaveBetterTransformerModelCallback,
 
34
  bench_eval_callback_factory,
35
  log_prediction_callback_factory,
36
  )
 
710
  if cfg.relora_steps:
711
  callbacks.append(ReLoRACallback(cfg))
712
 
 
 
 
 
 
 
713
  if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
714
  callbacks.append(SaveBetterTransformerModelCallback)
715