"""Prepare and train a model on a dataset. Can also run inference from a model or merge a LoRA."""

import logging
import os
import signal
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer

from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer

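# add the repository's local "src" directory to the import path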
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

configure_logging()
LOG = logging.getLogger("axolotl.train")


@dataclass
class TrainDatasetMeta:
    """
    dataclass to capture the dataset specific options for training
    """

    train_dataset: Dataset
    eval_dataset: Optional[Dataset] = None
    total_num_steps: Optional[int] = None

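# Example usage (sketch): `cfg`, the datasets, and the CLI args are normally built by
# axolotl's own CLI tooling before calling train(); assuming TrainerCliArgs() can be
# constructed with its defaults, a call looks roughly like:
#
#   model, tokenizer = train(
#       cfg=cfg,
#       cli_args=TrainerCliArgs(),
#       dataset_meta=TrainDatasetMeta(train_dataset=train_ds, eval_dataset=eval_ds),
#   )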
def train(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
    dataset_meta: TrainDatasetMeta,
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") |
|
tokenizer = load_tokenizer(cfg) |
|
|
|
train_dataset = dataset_meta.train_dataset |
|
eval_dataset = dataset_meta.eval_dataset |
|
total_num_steps = dataset_meta.total_num_steps |
|
|
|
|
|
LOG.info("loading model and (optionally) peft_config...") |
|
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) |
|
|
|
safe_serialization = cfg.save_safetensors is True |
|
|
|
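    # auto-resume: pick the highest-numbered checkpoint-* directory in output_dir, if any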
    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
        possible_checkpoints = [
            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
        ]
        if len(possible_checkpoints) > 0:
            sorted_paths = sorted(
                possible_checkpoints,
                key=lambda path: int(path.split("-")[-1]),
            )
            cfg.resume_from_checkpoint = sorted_paths[-1]
            LOG.info(
                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
            )
    resume_from_checkpoint = cfg.resume_from_checkpoint

    trainer = setup_trainer(
        cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
    )

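    # the KV cache only helps at generation time; disable it while training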
    model.config.use_cache = False

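    # compile the model with torch.compile on PyTorch >= 2 (not attempted on Windows)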
    if torch.__version__ >= "2" and sys.platform != "win32":
        LOG.info("Compiling torch model")
        model = torch.compile(model)

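    # pre-save the adapter config (if any), tokenizer, and model config to the output dir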
    if peft_config:
        LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
        peft_config.save_pretrained(cfg.output_dir)

    if not Path(cfg.output_dir).is_dir():
        os.makedirs(cfg.output_dir, exist_ok=True)
    tokenizer.save_pretrained(str(Path(cfg.output_dir)))
    model.config.save_pretrained(str(Path(cfg.output_dir)))

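    # on local rank 0, install a SIGINT handler that saves the model before exiting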
    if cfg.local_rank == 0:

        def terminate_handler(_, __, model):
            if cfg.flash_optimum:
                model = BetterTransformer.reverse(model)
            model.save_pretrained(
                cfg.output_dir, safe_serialization=safe_serialization
            )
            sys.exit(0)

        signal.signal(
            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
        )

LOG.info("Starting trainer...") |
|
if cfg.group_by_length: |
|
LOG.info("hang tight... sorting dataset for group_by_length") |
|
|
|
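    # with flash_optimum, run training inside torch's SDPA kernel-selection context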
    if cfg.flash_optimum:
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=True, enable_mem_efficient=True
        ):
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") |
|
|
|
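    # for ReLoRA runs, merge the LoRA weights into the base model unless it is quantized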
    if cfg.relora_steps:
        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
            model = model.merge_and_unload()
        else:
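            # quantized adapters are not merged; the final weights are expected to have
            # been saved by the ReLoRA callbacks during training, so return early here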
            return model, tokenizer

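    # under FSDP, save through the trainer; otherwise local rank 0 saves the model directly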
    if cfg.fsdp:
        trainer.save_model(cfg.output_dir)
    elif cfg.local_rank == 0:
        if cfg.flash_optimum:
            model = BetterTransformer.reverse(model)

        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

    return model, tokenizer