winglian commited on
Commit
62a7741
1 Parent(s): 31b9e0c

Fix for check with cfg and merge_lora (#600)

Browse files
.github/workflows/tests.yml CHANGED
@@ -61,7 +61,7 @@ jobs:
61
  uses: actions/setup-python@v4
62
  with:
63
  python-version: "3.10"
64
- cache: 'pip' # caching pip dependencies
65
 
66
  - name: Install dependencies
67
  run: |
 
61
  uses: actions/setup-python@v4
62
  with:
63
  python-version: "3.10"
64
+ # cache: 'pip' # caching pip dependencies
65
 
66
  - name: Install dependencies
67
  run: |
src/axolotl/cli/__init__.py CHANGED
@@ -70,7 +70,7 @@ def do_merge_lora(
70
  model.to(dtype=torch.float16)
71
 
72
  if cfg.local_rank == 0:
73
- LOG.info("saving merged model")
74
  model.save_pretrained(
75
  str(Path(cfg.output_dir) / "merged"),
76
  safe_serialization=safe_serialization,
 
70
  model.to(dtype=torch.float16)
71
 
72
  if cfg.local_rank == 0:
73
+ LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
74
  model.save_pretrained(
75
  str(Path(cfg.output_dir) / "merged"),
76
  safe_serialization=safe_serialization,
src/axolotl/cli/merge_lora.py CHANGED
@@ -13,12 +13,12 @@ from axolotl.common.cli import TrainerCliArgs
13
  def do_cli(config: Path = Path("examples/"), **kwargs):
14
  # pylint: disable=duplicate-code
15
  print_axolotl_text_art()
16
- parsed_cfg = load_cfg(config, **kwargs)
17
  parser = transformers.HfArgumentParser((TrainerCliArgs))
18
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
19
  return_remaining_strings=True
20
  )
21
  parsed_cli_args.merge_lora = True
 
22
 
23
  do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
24
 
 
13
  def do_cli(config: Path = Path("examples/"), **kwargs):
14
  # pylint: disable=duplicate-code
15
  print_axolotl_text_art()
 
16
  parser = transformers.HfArgumentParser((TrainerCliArgs))
17
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
18
  return_remaining_strings=True
19
  )
20
  parsed_cli_args.merge_lora = True
21
+ parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)
22
 
23
  do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
24