Nanobit committed on
Commit
0d6708b
·
1 Parent(s): 7576d85

Add callback save peft_model on_save

Browse files
src/axolotl/utils/callbacks.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
4
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
5
+
6
class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that writes the model out on every save event.

    The destination is an ``adapter_model`` sub-directory inside the
    checkpoint folder for the current global step.
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save ``kwargs["model"]`` under the current checkpoint folder.

        Args:
            args: Training arguments; ``output_dir`` is the root directory.
            state: Trainer state; ``global_step`` names the checkpoint folder.
            control: Trainer control object, handed back unmodified.
            **kwargs: Must contain ``"model"``, an object exposing
                ``save_pretrained`` (presumably a PEFT model, going by the
                class name -- confirm against the caller).

        Returns:
            The ``control`` object that was passed in.
        """
        # <output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>/adapter_model
        adapter_dir = os.path.join(
            args.output_dir,
            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
            "adapter_model",
        )

        model = kwargs["model"]
        model.save_pretrained(adapter_dir)

        return control
src/axolotl/utils/trainer.py CHANGED
@@ -13,6 +13,7 @@ from transformers import EarlyStoppingCallback
13
  from transformers.trainer_pt_utils import get_parameter_names
14
 
15
  from axolotl.utils.schedulers import InterpolatingLogScheduler
 
16
 
17
 
18
  def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
@@ -188,6 +189,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
188
  data_collator_kwargs["padding"] = "longest"
189
  else:
190
  data_collator_kwargs["pad_to_multiple_of"] = 8
 
 
 
 
 
191
  trainer = transformers.Trainer(
192
  model=model,
193
  train_dataset=train_dataset,
 
13
  from transformers.trainer_pt_utils import get_parameter_names
14
 
15
  from axolotl.utils.schedulers import InterpolatingLogScheduler
16
+ from axolotl.utils.callbacks import SavePeftModelCallback
17
 
18
 
19
  def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
 
189
  data_collator_kwargs["padding"] = "longest"
190
  else:
191
  data_collator_kwargs["pad_to_multiple_of"] = 8
192
+
193
+ callbacks = []
194
+ if cfg.adapter == 'lora':
195
+ callbacks.append(SavePeftModelCallback)
196
+
197
  trainer = transformers.Trainer(
198
  model=model,
199
  train_dataset=train_dataset,