|
"""Callbacks for Trainer class""" |
|
|
|
import logging |
|
import os |
|
|
|
from optimum.bettertransformer import BetterTransformer |
|
from transformers import ( |
|
TrainerCallback, |
|
TrainerControl, |
|
TrainerState, |
|
TrainingArguments, |
|
) |
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy |
|
|
|
from axolotl.utils.bench import log_gpu_memory_usage |
|
|
|
# Module-level logger for the callbacks defined in this file.
LOG = logging.getLogger(name="axolotl.callbacks")
|
|
|
|
|
class SavePeftModelCallback(TrainerCallback):
    """Write the PEFT adapter into each checkpoint directory on save.

    The adapter is written to
    ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>/adapter_model``.
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Call ``save_pretrained`` on ``kwargs["model"]`` for the current step.

        Args:
            args: training arguments; only ``output_dir`` is read.
            state: trainer state; only ``global_step`` is read.
            control: returned unchanged.
            **kwargs: must contain ``"model"``. Presumably a PEFT model, so
                only the adapter weights are written; confirm against the caller.

        Returns:
            The same ``control`` object that was passed in.
        """
        # NOTE(review): nothing here restricts the write to a single process;
        # if on_save fires on every rank, all ranks write to the same path.
        step_dir_name = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        adapter_dir = os.path.join(args.output_dir, step_dir_name, "adapter_model")
        kwargs["model"].save_pretrained(adapter_dir)
        return control
|
|
|
|
|
class SaveBetterTransformerModelCallback(TrainerCallback):
    """Save a plain-transformers copy of a BetterTransformer-wrapped model.

    At the end of each step this callback decides whether a checkpoint is due.
    If it is, the callback converts the live model back with
    ``BetterTransformer.reverse`` and writes the result to
    ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>``. It then clears
    ``control.should_save`` so the Trainer does not run its own save for that
    step, presumably because the Trainer cannot serialize the wrapped model
    directly.
    """

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Save a reversed copy of ``kwargs["model"]`` when a checkpoint is due.

        Args:
            args: training arguments; ``save_strategy``, ``save_steps`` and
                ``output_dir`` are read.
            state: trainer state; ``global_step`` is read, and ``save_steps``
                is read when the installed transformers version provides it.
            control: flow-control object; ``should_save`` is read and written.
            **kwargs: must contain ``"model"``, the BetterTransformer-wrapped
                model to be saved.

        Returns:
            The (possibly mutated) ``control`` object.
        """
        # Newer transformers versions accept a float ``save_steps`` in [0, 1),
        # meaning a ratio of total training steps. They keep the resolved
        # integer interval on ``state.save_steps``. Taking a modulo by the raw
        # ratio is wrong: ``n % 0.5 == 0`` holds for every integer n, so every
        # step would trigger a full reverse-and-save. Prefer the resolved value
        # and fall back to ``args.save_steps`` on versions without it.
        save_steps = getattr(state, "save_steps", None)
        if save_steps is None:
            save_steps = args.save_steps

        if (
            args.save_strategy == IntervalStrategy.STEPS
            and save_steps > 0
            and state.global_step % save_steps == 0
        ):
            control.should_save = True

        if control.should_save:
            checkpoint_folder = os.path.join(
                args.output_dir,
                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
            )

            # reverse() yields a model in the canonical transformers format,
            # which is what save_pretrained below writes out.
            model = BetterTransformer.reverse(kwargs["model"])
            model.save_pretrained(checkpoint_folder)

            # NOTE(review): because the Trainer's own save is suppressed below,
            # its checkpoint bookkeeping (e.g. old-checkpoint rotation)
            # presumably does not run for these folders; confirm.

            # The checkpoint was written here, so stop the Trainer from
            # attempting its own save for this step.
            control.should_save = False

        return control
|
|
|
|
|
class GPUStatsCallback(TrainerCallback):
    """Report GPU memory usage exactly once, shortly after training starts.

    The report is deferred until ``state.global_step`` exceeds 1. After it has
    been emitted, the callback does nothing for the rest of the run.
    """

    def __init__(self, cfg):
        # Only ``cfg.device`` is consulted, when the report is emitted.
        self.cfg = cfg
        # Becomes True once the single memory report has been logged.
        self.logged = False

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        """Log GPU memory at the first step end past step 1, then never again.

        Returns:
            The same ``control`` object that was passed in.
        """
        if self.logged:
            return control
        if state.global_step <= 1:
            return control

        log_gpu_memory_usage(LOG, "while training", self.cfg.device)
        self.logged = True
        return control
|
|