|
""" |
|
CLI to shard a trained model into 10GiB chunks |
|
""" |
|
import logging |
|
from pathlib import Path |
|
from typing import Union |
|
|
|
import fire |
|
import transformers |
|
|
|
from axolotl.cli import load_cfg, print_axolotl_text_art |
|
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer |
|
from axolotl.utils.dict import DictDefault |
|
|
|
LOG = logging.getLogger("axolotl.scripts") |
|
|
|
|
|
def shard(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
):
    """Load the configured model and write it back out to ``cfg.output_dir``.

    The model comes from ``load_model_and_tokenizer``; the tokenizer it also
    returns is not used here. No shard size is passed explicitly, so any
    splitting into chunks is left to ``save_pretrained``'s own defaults.

    Args:
        cfg: Run configuration. This function reads ``save_safetensors`` and
            ``output_dir`` from it and forwards it to the loader.
        cli_args: Parsed CLI arguments, forwarded unchanged to the loader.
    """
    model, _tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)

    # Strict identity check: only a literal ``True`` enables safetensors;
    # ``None`` or any other truthy value falls back to ``False``.
    use_safetensors = cfg.save_safetensors is True

    LOG.debug("Re-saving model w/ sharding")
    model.save_pretrained(
        cfg.output_dir,
        safe_serialization=use_safetensors,
    )
|
|
|
|
|
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
    """Command-line entry point: load the config, parse CLI args, run ``shard``.

    Args:
        config: Config location handed to ``load_cfg``; defaults to
            ``examples/``.
        **kwargs: Extra keyword arguments forwarded as-is to ``load_cfg``.
    """
    print_axolotl_text_art()

    run_cfg = load_cfg(config, **kwargs)

    # ``return_remaining_strings=True`` makes the parser return a second item
    # alongside the dataclass; it is deliberately discarded here.
    trainer_args, _leftover = transformers.HfArgumentParser(
        TrainerCliArgs
    ).parse_args_into_dataclasses(return_remaining_strings=True)

    # This script always runs with the shard flag set, whatever was parsed.
    trainer_args.shard = True

    shard(cfg=run_cfg, cli_args=trainer_args)
|
|
|
|
|
if __name__ == "__main__":
    # Hand do_cli to python-fire only when this file is run as a script, so
    # importing the module has no command-line side effects.
    fire.Fire(do_cli)
|
|