winglian committed on
Commit
5783839
1 Parent(s): cbbf039

download model weights on preprocess step (#1693)

Browse files
src/axolotl/cli/preprocess.py CHANGED
@@ -7,7 +7,9 @@ from typing import Union
7
 
8
  import fire
9
  import transformers
 
10
  from colorama import Fore
 
11
 
12
  from axolotl.cli import (
13
  check_accelerate_default_config,
@@ -71,6 +73,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
71
  else:
72
  load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
73
 
 
 
 
 
 
74
  LOG.info(
75
  Fore.GREEN
76
  + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
 
7
 
8
  import fire
9
  import transformers
10
+ from accelerate import init_empty_weights
11
  from colorama import Fore
12
+ from transformers import AutoModelForCausalLM
13
 
14
  from axolotl.cli import (
15
  check_accelerate_default_config,
 
73
  else:
74
  load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
75
 
76
+ if parsed_cli_args.download:
77
+ model_name = parsed_cfg.base_model
78
+ with init_empty_weights():
79
+ AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
80
+
81
  LOG.info(
82
  Fore.GREEN
83
  + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
src/axolotl/common/cli.py CHANGED
@@ -40,6 +40,7 @@ class PreprocessCliArgs:
40
  debug_text_only: bool = field(default=False)
41
  debug_num_examples: int = field(default=1)
42
  prompter: Optional[str] = field(default=None)
 
43
 
44
 
45
  def load_model_and_tokenizer(
 
40
  debug_text_only: bool = field(default=False)
41
  debug_num_examples: int = field(default=1)
42
  prompter: Optional[str] = field(default=None)
43
+ download: Optional[bool] = field(default=True)
44
 
45
 
46
  def load_model_and_tokenizer(