winglian committed with Nanobit
Commit 861ceca · Parent: 1078d3e

refactor scripts/finetune.py into new cli modules (#550)


* refactor scripts/finetune.py into new cli modules

* continue to support scripts/finetune.py

* update README with the new CLI commands

* Update scripts/finetune.py

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>


README.md CHANGED
@@ -76,11 +76,11 @@ pip3 install -e .[flash-attn]
  pip3 install -U git+https://github.com/huggingface/peft.git
 
  # finetune lora
- accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
+ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
 
  # inference
- accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
-     --inference --lora_model_dir="./lora-out"
+ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
+     --lora_model_dir="./lora-out"
  ```
 
  ## Installation
@@ -674,14 +674,14 @@ strict:
 
  Run
  ```bash
- accelerate launch scripts/finetune.py your_config.yml
+ accelerate launch -m axolotl.cli.train your_config.yml
  ```
 
  #### Multi-GPU
 
  You can optionally pre-tokenize dataset with the following before finetuning:
  ```bash
- CUDA_VISIBLE_DEVICES="" accelerate ... --prepare_ds_only
+ CUDA_VISIBLE_DEVICES="" accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
  ```
 
  ##### Config
@@ -720,16 +720,16 @@ Pass the appropriate flag to the train command:
 
  - Pretrained LORA:
    ```bash
-   --inference --lora_model_dir="./lora-output-dir"
+   python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
    ```
  - Full weights finetune:
    ```bash
-   --inference --base_model="./completed-model"
+   python -m axolotl.cli.inference examples/your_config.yml --base_model="./completed-model"
    ```
  - Full weights finetune w/ a prompt from a text file:
    ```bash
-   cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
-     --base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
+   cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
+     --base_model="./completed-model" --prompter=None --load_in_8bit=True
    ```
 
  ### Merge LORA to base
@@ -737,13 +737,13 @@ Pass the appropriate flag to the train command:
  Add below flag to train command above
 
  ```bash
- --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
+ python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
  ```
 
  If you run out of CUDA memory, you can try to merge in system RAM with
 
  ```bash
- CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
+ CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
  ```
 
  ## Common Errors 🧰
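
Taken together, these hunks replace every `scripts/finetune.py` invocation with a dedicated CLI module. As a quick reference, a sketch of the old-to-new mapping (the `shard` entry is inferred from the new `src/axolotl/cli/shard.py` below and is not documented in the README, so treat its invocation as an assumption):

```bash
# train:      was `accelerate launch scripts/finetune.py <config>`
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml

# inference:  was `accelerate launch scripts/finetune.py <config> --inference`
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
    --lora_model_dir="./lora-out"

# merge lora: was `python scripts/finetune.py <config> --merge_lora`
python3 -m axolotl.cli.merge_lora examples/your_config.yml \
    --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False

# shard:      inferred from src/axolotl/cli/shard.py; not documented in the README
python -m axolotl.cli.shard your_config.yml
```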
scripts/finetune.py CHANGED
@@ -1,269 +1,34 @@
  """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
-
- import importlib
  import logging
- import os
- import random
- import sys
  from pathlib import Path
- from typing import Any, Dict, List, Optional, Union

  import fire
- import torch
  import transformers
- import yaml
-
- # add src to the pythonpath so we don't need to pip install this
- from accelerate.commands.config import config_args
- from art import text2art
- from transformers import GenerationConfig, TextStreamer
-
- from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
- from axolotl.logging_config import configure_logging
- from axolotl.train import TrainDatasetMeta, train
- from axolotl.utils.config import normalize_config, validate_config
- from axolotl.utils.data import prepare_dataset
- from axolotl.utils.dict import DictDefault
- from axolotl.utils.distributed import is_main_process
- from axolotl.utils.models import load_tokenizer
- from axolotl.utils.tokenization import check_dataset_labels
- from axolotl.utils.wandb_ import setup_wandb_env_vars
-
- project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
- src_dir = os.path.join(project_root, "src")
- sys.path.insert(0, src_dir)
-
- configure_logging()
- LOG = logging.getLogger("axolotl.scripts")
-
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
-
-
- def print_axolotl_text_art(suffix=None):
-     font = "nancyj"
-     ascii_text = " axolotl"
-     if suffix:
-         ascii_text += f" x {suffix}"
-     ascii_art = text2art(" axolotl", font=font)
-
-     if is_main_process():
-         print(ascii_art)
-
-
- def get_multi_line_input() -> Optional[str]:
-     print("Give me an instruction (Ctrl + D to finish): ")
-     instruction = ""
-     for line in sys.stdin:
-         instruction += line  # pylint: disable=consider-using-join
-     # instruction = pathlib.Path("/proc/self/fd/0").read_text()
-     return instruction
-
-
- def do_merge_lora(
-     *,
-     cfg: DictDefault,
-     cli_args: TrainerCliArgs,
- ):
-     model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
-     safe_serialization = cfg.save_safetensors is True
-
-     LOG.info("running merge of LoRA with base model")
-     model = model.merge_and_unload()
-     model.to(dtype=torch.float16)
-
-     if cfg.local_rank == 0:
-         LOG.info("saving merged model")
-         model.save_pretrained(
-             str(Path(cfg.output_dir) / "merged"),
-             safe_serialization=safe_serialization,
-         )
-         tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
-
-
- def shard(
-     *,
-     cfg: DictDefault,
-     cli_args: TrainerCliArgs,
- ):
-     model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
-     safe_serialization = cfg.save_safetensors is True
-     LOG.debug("Re-saving model w/ sharding")
-     model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
-
-
- def do_inference(
-     *,
-     cfg: DictDefault,
-     cli_args: TrainerCliArgs,
- ):
-     model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
-     prompter = cli_args.prompter
-     default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
-
-     for token, symbol in default_tokens.items():
-         # If the token isn't already specified in the config, add it
-         if not (cfg.special_tokens and token in cfg.special_tokens):
-             tokenizer.add_special_tokens({token: symbol})
-
-     prompter_module = None
-     if prompter:
-         prompter_module = getattr(
-             importlib.import_module("axolotl.prompters"), prompter
-         )
-
-     if cfg.landmark_attention:
-         from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
-
-         set_model_mem_id(model, tokenizer)
-         model.set_mem_cache_args(
-             max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
-         )
-
-     model = model.to(cfg.device)
-
-     while True:
-         print("=" * 80)
-         # support for multiline inputs
-         instruction = get_multi_line_input()
-         if not instruction:
-             return
-         if prompter_module:
-             prompt: str = next(
-                 prompter_module().build_prompt(instruction=instruction.strip("\n"))
-             )
-         else:
-             prompt = instruction.strip()
-         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
-
-         print("=" * 40)
-         model.eval()
-         with torch.no_grad():
-             generation_config = GenerationConfig(
-                 repetition_penalty=1.1,
-                 max_new_tokens=1024,
-                 temperature=0.9,
-                 top_p=0.95,
-                 top_k=40,
-                 bos_token_id=tokenizer.bos_token_id,
-                 eos_token_id=tokenizer.eos_token_id,
-                 pad_token_id=tokenizer.pad_token_id,
-                 do_sample=True,
-                 use_cache=True,
-                 return_dict_in_generate=True,
-                 output_attentions=False,
-                 output_hidden_states=False,
-                 output_scores=False,
-             )
-             streamer = TextStreamer(tokenizer)
-             generated = model.generate(
-                 inputs=batch["input_ids"].to(cfg.device),
-                 generation_config=generation_config,
-                 streamer=streamer,
-             )
-         print("=" * 40)
-         print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
-
-
- def choose_config(path: Path):
-     yaml_files = list(path.glob("*.yml"))
-
-     if not yaml_files:
-         raise ValueError(
-             "No YAML config files found in the specified directory. Are you using a .yml extension?"
-         )
-
-     if len(yaml_files) == 1:
-         print(f"Using default YAML file '{yaml_files[0]}'")
-         return yaml_files[0]
-
-     print("Choose a YAML file:")
-     for idx, file in enumerate(yaml_files):
-         print(f"{idx + 1}. {file}")
-
-     chosen_file = None
-     while chosen_file is None:
-         try:
-             choice = int(input("Enter the number of your choice: "))
-             if 1 <= choice <= len(yaml_files):
-                 chosen_file = yaml_files[choice - 1]
-             else:
-                 print("Invalid choice. Please choose a number from the list.")
-         except ValueError:
-             print("Invalid input. Please enter a number.")
-
-     return chosen_file
-
-
- def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
-     return not any(el in list2 for el in list1)
-
-
- def load_cfg(config: Path = Path("examples/"), **kwargs):
-     if Path(config).is_dir():
-         config = choose_config(config)
-
-     # load the config from the yaml file
-     with open(config, encoding="utf-8") as file:
-         cfg: DictDefault = DictDefault(yaml.safe_load(file))
-     # if there are any options passed in the cli, if it is something that seems valid from the yaml,
-     # then overwrite the value
-     cfg_keys = cfg.keys()
-     for k, _ in kwargs.items():
-         # if not strict, allow writing to cfg even if it's not in the yml already
-         if k in cfg_keys or not cfg.strict:
-             # handle booleans
-             if isinstance(cfg[k], bool):
-                 cfg[k] = bool(kwargs[k])
-             else:
-                 cfg[k] = kwargs[k]
-
-     validate_config(cfg)
-
-     normalize_config(cfg)
-
-     setup_wandb_env_vars(cfg)
-     return cfg
-
-
- def load_datasets(
-     *,
-     cfg: DictDefault,
-     cli_args: TrainerCliArgs,
- ) -> TrainDatasetMeta:
-     tokenizer = load_tokenizer(cfg)
-
-     train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
-
-     if cli_args.debug or cfg.debug:
-         LOG.info("check_dataset_labels...")
-         check_dataset_labels(
-             train_dataset.select(
-                 [
-                     random.randrange(0, len(train_dataset) - 1)  # nosec
-                     for _ in range(cli_args.debug_num_examples)
-                 ]
-             ),
-             tokenizer,
-             num_examples=cli_args.debug_num_examples,
-             text_only=cli_args.debug_text_only,
-         )
-
-     return TrainDatasetMeta(
-         train_dataset=train_dataset,
-         eval_dataset=eval_dataset,
-         total_num_steps=total_num_steps,
-     )
-
-
- def check_accelerate_default_config():
-     if Path(config_args.default_yaml_config_file).exists():
-         LOG.warning(
-             f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
-         )
-
-
- def do_cli(config: Path = Path("examples/"), **kwargs):
-     print_axolotl_text_art()
+
+ from axolotl.cli import (
+     check_accelerate_default_config,
+     do_inference,
+     do_merge_lora,
+     load_cfg,
+     load_datasets,
+     print_axolotl_text_art,
+ )
+ from axolotl.cli.shard import shard
+ from axolotl.common.cli import TrainerCliArgs
+ from axolotl.train import train
+
+ LOG = logging.getLogger("axolotl.scripts.finetune")
+
+
+ def do_cli(config: Path = Path("examples/"), **kwargs):
+     print_axolotl_text_art()
+     LOG.warning(
+         str(
+             PendingDeprecationWarning(
+                 "scripts/finetune.py will be replaced with calling axolotl.cli.train"
+             )
+         )
+     )
      parsed_cfg = load_cfg(config, **kwargs)
      check_accelerate_default_config()
      parser = transformers.HfArgumentParser((TrainerCliArgs))
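
The script survives as a thin deprecation shim: it re-imports the helpers from `axolotl.cli`, logs a `PendingDeprecationWarning`, and then dispatches exactly as before (per the commit message, `scripts/finetune.py` remains supported). So the legacy invocation still works:

```bash
# still supported, but now logs a PendingDeprecationWarning pointing at axolotl.cli.train
accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml
```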
src/axolotl/cli/__init__.py ADDED
@@ -0,0 +1,249 @@
+ """Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
+
+ import importlib
+ import logging
+ import os
+ import random
+ import sys
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Union
+
+ import torch
+ import yaml
+
+ # add src to the pythonpath so we don't need to pip install this
+ from accelerate.commands.config import config_args
+ from art import text2art
+ from transformers import GenerationConfig, TextStreamer
+
+ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
+ from axolotl.logging_config import configure_logging
+ from axolotl.train import TrainDatasetMeta
+ from axolotl.utils.config import normalize_config, validate_config
+ from axolotl.utils.data import prepare_dataset
+ from axolotl.utils.dict import DictDefault
+ from axolotl.utils.distributed import is_main_process
+ from axolotl.utils.models import load_tokenizer
+ from axolotl.utils.tokenization import check_dataset_labels
+ from axolotl.utils.wandb_ import setup_wandb_env_vars
+
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+ src_dir = os.path.join(project_root, "src")
+ sys.path.insert(0, src_dir)
+
+ configure_logging()
+ LOG = logging.getLogger("axolotl.scripts")
+
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+
+
+ def print_axolotl_text_art(suffix=None):
+     font = "nancyj"
+     ascii_text = " axolotl"
+     if suffix:
+         ascii_text += f" x {suffix}"
+     ascii_art = text2art(ascii_text, font=font)  # render ascii_text so the suffix is included
+
+     if is_main_process():
+         print(ascii_art)
+
+
+ def get_multi_line_input() -> Optional[str]:
+     print("Give me an instruction (Ctrl + D to finish): ")
+     instruction = ""
+     for line in sys.stdin:
+         instruction += line  # pylint: disable=consider-using-join
+     # instruction = pathlib.Path("/proc/self/fd/0").read_text()
+     return instruction
+
+
+ def do_merge_lora(
+     *,
+     cfg: DictDefault,
+     cli_args: TrainerCliArgs,
+ ):
+     model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+     safe_serialization = cfg.save_safetensors is True
+
+     LOG.info("running merge of LoRA with base model")
+     model = model.merge_and_unload()
+     model.to(dtype=torch.float16)
+
+     if cfg.local_rank == 0:
+         LOG.info("saving merged model")
+         model.save_pretrained(
+             str(Path(cfg.output_dir) / "merged"),
+             safe_serialization=safe_serialization,
+         )
+         tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
+
+
+ def do_inference(
+     *,
+     cfg: DictDefault,
+     cli_args: TrainerCliArgs,
+ ):
+     model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+     prompter = cli_args.prompter
+     default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+
+     for token, symbol in default_tokens.items():
+         # If the token isn't already specified in the config, add it
+         if not (cfg.special_tokens and token in cfg.special_tokens):
+             tokenizer.add_special_tokens({token: symbol})
+
+     prompter_module = None
+     if prompter:
+         prompter_module = getattr(
+             importlib.import_module("axolotl.prompters"), prompter
+         )
+
+     if cfg.landmark_attention:
+         from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
+
+         set_model_mem_id(model, tokenizer)
+         model.set_mem_cache_args(
+             max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
+         )
+
+     model = model.to(cfg.device)
+
+     while True:
+         print("=" * 80)
+         # support for multiline inputs
+         instruction = get_multi_line_input()
+         if not instruction:
+             return
+         if prompter_module:
+             prompt: str = next(
+                 prompter_module().build_prompt(instruction=instruction.strip("\n"))
+             )
+         else:
+             prompt = instruction.strip()
+         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
+         print("=" * 40)
+         model.eval()
+         with torch.no_grad():
+             generation_config = GenerationConfig(
+                 repetition_penalty=1.1,
+                 max_new_tokens=1024,
+                 temperature=0.9,
+                 top_p=0.95,
+                 top_k=40,
+                 bos_token_id=tokenizer.bos_token_id,
+                 eos_token_id=tokenizer.eos_token_id,
+                 pad_token_id=tokenizer.pad_token_id,
+                 do_sample=True,
+                 use_cache=True,
+                 return_dict_in_generate=True,
+                 output_attentions=False,
+                 output_hidden_states=False,
+                 output_scores=False,
+             )
+             streamer = TextStreamer(tokenizer)
+             generated = model.generate(
+                 inputs=batch["input_ids"].to(cfg.device),
+                 generation_config=generation_config,
+                 streamer=streamer,
+             )
+         print("=" * 40)
+         print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
+
+
+ def choose_config(path: Path):
+     yaml_files = list(path.glob("*.yml"))
+
+     if not yaml_files:
+         raise ValueError(
+             "No YAML config files found in the specified directory. Are you using a .yml extension?"
+         )
+
+     if len(yaml_files) == 1:
+         print(f"Using default YAML file '{yaml_files[0]}'")
+         return yaml_files[0]
+
+     print("Choose a YAML file:")
+     for idx, file in enumerate(yaml_files):
+         print(f"{idx + 1}. {file}")
+
+     chosen_file = None
+     while chosen_file is None:
+         try:
+             choice = int(input("Enter the number of your choice: "))
+             if 1 <= choice <= len(yaml_files):
+                 chosen_file = yaml_files[choice - 1]
+             else:
+                 print("Invalid choice. Please choose a number from the list.")
+         except ValueError:
+             print("Invalid input. Please enter a number.")
+
+     return chosen_file
+
+
+ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
+     return not any(el in list2 for el in list1)
+
+
+ def load_cfg(config: Path = Path("examples/"), **kwargs):
+     if Path(config).is_dir():
+         config = choose_config(config)
+
+     # load the config from the yaml file
+     with open(config, encoding="utf-8") as file:
+         cfg: DictDefault = DictDefault(yaml.safe_load(file))
+     # if there are any options passed in the cli, if it is something that seems valid from the yaml,
+     # then overwrite the value
+     cfg_keys = cfg.keys()
+     for k, _ in kwargs.items():
+         # if not strict, allow writing to cfg even if it's not in the yml already
+         if k in cfg_keys or not cfg.strict:
+             # handle booleans
+             if isinstance(cfg[k], bool):
+                 cfg[k] = bool(kwargs[k])
+             else:
+                 cfg[k] = kwargs[k]
+
+     validate_config(cfg)
+
+     normalize_config(cfg)
+
+     setup_wandb_env_vars(cfg)
+     return cfg
+
+
+ def load_datasets(
+     *,
+     cfg: DictDefault,
+     cli_args: TrainerCliArgs,
+ ) -> TrainDatasetMeta:
+     tokenizer = load_tokenizer(cfg)
+
+     train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
+
+     if cli_args.debug or cfg.debug:
+         LOG.info("check_dataset_labels...")
+         check_dataset_labels(
+             train_dataset.select(
+                 [
+                     random.randrange(0, len(train_dataset) - 1)  # nosec
+                     for _ in range(cli_args.debug_num_examples)
+                 ]
+             ),
+             tokenizer,
+             num_examples=cli_args.debug_num_examples,
+             text_only=cli_args.debug_text_only,
+         )
+
+     return TrainDatasetMeta(
+         train_dataset=train_dataset,
+         eval_dataset=eval_dataset,
+         total_num_steps=total_num_steps,
+     )
+
+
+ def check_accelerate_default_config():
+     if Path(config_args.default_yaml_config_file).exists():
+         LOG.warning(
+             f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
+         )
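
One behavior worth noting in `load_cfg`: every `--key=value` that `fire` forwards as a kwarg overwrites the matching YAML key (or any key at all when `strict` is false), with booleans coerced via `bool()`. A minimal sketch, assuming your config defines `micro_batch_size`:

```bash
# overrides micro_batch_size from the YAML for this run only;
# unknown keys are also accepted unless the config sets strict: true
python -m axolotl.cli.train examples/openllama-3b/lora.yml --micro_batch_size=1
```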
src/axolotl/cli/inference.py ADDED
@@ -0,0 +1,26 @@
+ """
+ CLI to run inference on a trained model
+ """
+ from pathlib import Path
+
+ import fire
+ import transformers
+
+ from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
+ from axolotl.common.cli import TrainerCliArgs
+
+
+ def do_cli(config: Path = Path("examples/"), **kwargs):
+     # pylint: disable=duplicate-code
+     print_axolotl_text_art()
+     parsed_cfg = load_cfg(config, **kwargs)
+     parser = transformers.HfArgumentParser((TrainerCliArgs))
+     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+         return_remaining_strings=True
+     )
+     parsed_cli_args.inference = True
+
+     do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
+
+ fire.Fire(do_cli)
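
Because the module sets `parsed_cli_args.inference = True` itself, the old `--inference` flag is no longer needed. Usage per the updated README, interactively or with a prompt piped on stdin:

```bash
# interactive: enter a prompt, finish with Ctrl+D
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
    --lora_model_dir="./lora-out"

# non-interactive: read the prompt from a file
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
    --base_model="./completed-model" --prompter=None --load_in_8bit=True
```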
src/axolotl/cli/merge_lora.py ADDED
@@ -0,0 +1,26 @@
+ """
+ CLI to merge a trained LoRA into a base model
+ """
+ from pathlib import Path
+
+ import fire
+ import transformers
+
+ from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
+ from axolotl.common.cli import TrainerCliArgs
+
+
+ def do_cli(config: Path = Path("examples/"), **kwargs):
+     # pylint: disable=duplicate-code
+     print_axolotl_text_art()
+     parsed_cfg = load_cfg(config, **kwargs)
+     parser = transformers.HfArgumentParser((TrainerCliArgs))
+     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+         return_remaining_strings=True
+     )
+     parsed_cli_args.merge_lora = True
+
+     do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
+
+ fire.Fire(do_cli)
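
Likewise, `merge_lora` is forced on here, so only the config and the adapter directory are needed. Per the updated README, including the CPU fallback when the merge does not fit in VRAM:

```bash
python3 -m axolotl.cli.merge_lora examples/your_config.yml \
    --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False

# out of CUDA memory? merge in system RAM instead
CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora examples/your_config.yml \
    --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
```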
src/axolotl/cli/shard.py ADDED
@@ -0,0 +1,41 @@
+ """
+ CLI to shard a trained model into 10GiB chunks
+ """
+ import logging
+ from pathlib import Path
+
+ import fire
+ import transformers
+
+ from axolotl.cli import load_cfg, print_axolotl_text_art
+ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
+ from axolotl.utils.dict import DictDefault
+
+ LOG = logging.getLogger("axolotl.scripts")
+
+
+ def shard(
+     *,
+     cfg: DictDefault,
+     cli_args: TrainerCliArgs,
+ ):
+     model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+     safe_serialization = cfg.save_safetensors is True
+     LOG.debug("Re-saving model w/ sharding")
+     model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
+
+
+ def do_cli(config: Path = Path("examples/"), **kwargs):
+     # pylint: disable=duplicate-code
+     print_axolotl_text_art()
+     parsed_cfg = load_cfg(config, **kwargs)
+     parser = transformers.HfArgumentParser((TrainerCliArgs))
+     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+         return_remaining_strings=True
+     )
+     parsed_cli_args.shard = True
+
+     shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
+
+
+ fire.Fire(do_cli)
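
`shard` re-saves the model from `output_dir` in sharded checkpoints (safetensors when `save_safetensors` is set). The README changes above don't cover it, so the invocation below is an assumption based on the pattern the other modules follow:

```bash
# hypothetical invocation, mirroring the other cli modules; not documented in the README
python -m axolotl.cli.shard your_config.yml
```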
src/axolotl/cli/train.py ADDED
@@ -0,0 +1,35 @@
+ """
+ CLI to run training on a model
+ """
+ from pathlib import Path
+
+ import fire
+ import transformers
+
+ from axolotl.cli import (
+     check_accelerate_default_config,
+     load_cfg,
+     load_datasets,
+     print_axolotl_text_art,
+ )
+ from axolotl.common.cli import TrainerCliArgs
+ from axolotl.train import train
+
+
+ def do_cli(config: Path = Path("examples/"), **kwargs):
+     # pylint: disable=duplicate-code
+     print_axolotl_text_art()
+     parsed_cfg = load_cfg(config, **kwargs)
+     check_accelerate_default_config()
+     parser = transformers.HfArgumentParser((TrainerCliArgs))
+     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+         return_remaining_strings=True
+     )
+
+     dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
+     if parsed_cli_args.prepare_ds_only:
+         return
+     train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
+
+
+ fire.Fire(do_cli)
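
`train` is the main path: it loads the config, warns about a stray accelerate config, tokenizes the datasets, and hands off to `axolotl.train.train`, returning early when `--prepare_ds_only` is set. The README commands above map onto it directly:

```bash
# finetune
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml

# pre-tokenize the dataset on CPU, then exit before training
CUDA_VISIBLE_DEVICES="" accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
```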