File size: 10,939 Bytes
392dfd9
 
9105935
05fffb5
ce24f5e
949a27b
8d959a7
ce24f5e
 
37293dc
ce24f5e
 
 
 
37293dc
 
1edc30c
fec6bcc
ce24f5e
b1f4f7a
8cec513
2e22404
37293dc
29241cf
32e6fe9
37293dc
2e22404
6045345
ce24f5e
392dfd9
 
 
 
553a86b
 
 
d75adb9
05fffb5
a6028d3
29241cf
 
 
 
 
 
 
 
 
 
 
 
 
 
247825b
47ad389
247825b
 
82971e1
247825b
 
 
 
dc77c8e
f36e227
fec6bcc
 
 
 
 
87e073d
c4e4f81
 
 
 
 
6045345
572d114
974dc00
 
 
572d114
 
 
 
d653859
fec6bcc
9105935
247825b
d653859
 
c4e4f81
 
 
 
 
 
d653859
572d114
fec6bcc
d653859
 
988aeb9
d653859
33d4017
d653859
 
 
988aeb9
 
 
 
 
56f9ca5
d653859
 
 
 
fec6bcc
988aeb9
 
 
fec6bcc
988aeb9
fec6bcc
d653859
949a27b
 
f2a2029
82971e1
f2a2029
 
87d7825
 
 
f2a2029
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc67862
 
 
 
ce24f5e
f2a2029
2393801
ce24f5e
 
29241cf
949a27b
f2a2029
 
ce24f5e
392dfd9
a1f9850
ce24f5e
 
93acb64
392dfd9
1d5ab84
b565ecf
f2a2029
 
 
 
 
ce24f5e
3aad5f3
 
8cec513
5a631b3
ce24f5e
52765ac
32e6fe9
efb3b2c
 
32e6fe9
b565ecf
 
392dfd9
2e22404
2bb0b78
21f17cc
553a86b
21f17cc
 
a1f9850
21f17cc
 
 
 
32e6fe9
553a86b
32e6fe9
 
ce24f5e
2bb0b78
7181022
949a27b
a276c9c
 
1d5ab84
553a86b
1d5ab84
a5bf838
1d5ab84
 
553a86b
a276c9c
 
 
 
289d5c4
1d5ab84
 
bd3b537
553a86b
dc77c8e
c4e4f81
 
dc77c8e
c4e4f81
dc77c8e
 
949a27b
 
9105935
a276c9c
9105935
 
2bb0b78
 
 
8d959a7
 
 
 
553a86b
8d959a7
 
902dd0a
2255bb7
553a86b
2255bb7
902dd0a
f2a2029
d1aed4c
488a67d
 
1edc30c
 
a276c9c
1edc30c
488a67d
d1aed4c
488a67d
d1aed4c
8d959a7
553a86b
1d5ab84
553a86b
0a472e1
 
2bc1a5b
 
 
0a472e1
2bc1a5b
37293dc
 
2bc1a5b
0a472e1
553a86b
2bc1a5b
 
4ac9e25
 
 
86a91e2
8792199
488a67d
 
 
8792199
 
 
ce24f5e
553a86b
915c56c
 
bdbca8f
894cba0
2bb0b78
894cba0
1edc30c
 
a276c9c
392dfd9
a6028d3
ce24f5e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""

import importlib
import logging
import os
import random
import signal
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import fire
import torch
import yaml

# add src to the pythonpath so we don't need to pip install this
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer

from axolotl.logging_config import configure_logging
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars

# Resolve the directory one level above this script and put its ``src`` folder
# at the front of ``sys.path``.
# NOTE(review): the ``axolotl`` imports above have already executed by the time
# this insert runs, so it cannot be what makes them importable — confirm
# whether it is still needed or should move above those imports.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

# Set up logging via the project helper, then grab this script's logger.
configure_logging()
LOG = logging.getLogger("axolotl.scripts")

# Presumably switches Hugging Face Hub downloads to the ``hf_transfer``
# backend — TODO confirm against the huggingface_hub documentation.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"


def print_axolotl_text_art():
    """Print the axolotl ASCII banner when ``is_main_process()`` is true."""
    if not is_main_process():
        return

    banner = """
                           dP            dP   dP
                           88            88   88
.d8888b. dP.  .dP .d8888b. 88 .d8888b. d8888P 88
88'  `88  `8bd8'  88'  `88 88 88'  `88   88   88
88.  .88  .d88b.  88.  .88 88 88.  .88   88   88
`88888P8 dP'  `dP `88888P' dP `88888P'   dP   dP
"""
    print(banner)


def get_multi_line_input() -> Optional[str]:
    """Prompt the user, then read standard input until EOF.

    Returns:
        Every line read from ``sys.stdin`` concatenated in order (line endings
        included); the empty string when nothing was entered.
    """
    print("Give me an instruction (Ctrl + D to finish): ")
    return "".join(sys.stdin)


def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
    """Run an interactive read-generate-print loop until an empty input is given.

    Args:
        cfg: Run configuration; ``special_tokens``, ``landmark_attention`` and
            ``device`` are read here.
        model: Model whose ``generate`` method produces the completion.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        prompter: Name of a class in ``axolotl.prompters`` used to build the
            prompt from the instruction, or a falsy value to send the stripped
            instruction as-is.
    """
    fallback_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}

    for token_name, token_text in fallback_tokens.items():
        # Tokens already listed in the config's special_tokens are left alone.
        already_configured = cfg.special_tokens and token_name in cfg.special_tokens
        if already_configured:
            continue
        tokenizer.add_special_tokens({token_name: token_text})

    # Look the prompter class up by name only when one was requested.
    prompter_cls = (
        getattr(importlib.import_module("axolotl.prompters"), prompter)
        if prompter
        else None
    )

    if cfg.landmark_attention:
        from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id

        set_model_mem_id(model, tokenizer)
        model.set_mem_cache_args(
            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
        )

    thin_rule = "=" * 40
    while True:
        print("=" * 80)
        # The whole of stdin up to EOF is taken as one (multi-line) instruction.
        instruction = get_multi_line_input()
        if not instruction:
            return

        if prompter_cls:
            prompt_stream = prompter_cls().build_prompt(
                instruction=instruction.strip("\n")
            )
            prompt: str = next(prompt_stream)
        else:
            prompt = instruction.strip()
        encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

        print(thin_rule)
        model.eval()
        with torch.no_grad():
            generation_config = GenerationConfig(
                repetition_penalty=1.1,
                max_new_tokens=1024,
                temperature=0.9,
                top_p=0.95,
                top_k=40,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )
            streamer = TextStreamer(tokenizer)
            generated = model.generate(
                inputs=encoded["input_ids"].to(cfg.device),
                generation_config=generation_config,
                streamer=streamer,
            )
        print(thin_rule)
        # Decode and print the first returned sequence in full.
        sequences = generated["sequences"].cpu().tolist()
        print(tokenizer.decode(sequences[0]))


def choose_config(path: Path):
    """Interactively pick one ``*.yml`` file from a directory.

    The candidates are listed in sorted order so the numbering is the same on
    every run and every filesystem (``Path.glob`` itself yields entries in
    arbitrary order). The user is re-prompted until a valid number is entered.

    Args:
        path: Directory to search (non-recursively) for ``*.yml`` files. A
            plain string is accepted as well and converted to ``Path``.

    Returns:
        The ``Path`` of the chosen YAML file.

    Raises:
        ValueError: If the directory contains no ``*.yml`` files.
    """
    # Callers may hand over a raw CLI string; normalize so ``.glob`` exists.
    yaml_files = sorted(Path(path).glob("*.yml"))

    if not yaml_files:
        raise ValueError(
            "No YAML config files found in the specified directory. Are you using a .yml extension?"
        )

    print("Choose a YAML file:")
    for idx, file in enumerate(yaml_files, start=1):
        print(f"{idx}. {file}")

    while True:
        # Only the int() conversion can raise ValueError, so keep the try narrow.
        try:
            choice = int(input("Enter the number of your choice: "))
        except ValueError:
            print("Invalid input. Please enter a number.")
            continue

        if 1 <= choice <= len(yaml_files):
            return yaml_files[choice - 1]
        print("Invalid choice. Please choose a number from the list.")


def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
    """Return True when no element of ``list1`` is contained in ``list2``.

    For a dict ``list2`` the containment test is against its keys.
    """
    for candidate in list1:
        if candidate in list2:
            return False
    return True


def train(
    config: Path = Path("configs/"),
    prepare_ds_only: bool = False,
    **kwargs,
):
    """Load a YAML config and train, or run one of the auxiliary modes.

    Depending on the config and the extra CLI options this either prepares the
    dataset only, merges a LoRA adapter into the base model, runs interactive
    inference, re-saves ("shards") the model, or runs a full training job.

    Args:
        config: Path to a YAML config file, or to a directory from which one
            is chosen interactively. A plain string is accepted.
        prepare_ds_only: When true, return right after dataset preparation.
        **kwargs: Overrides for config values. A key is applied when it is
            already present in the YAML, or unconditionally when ``cfg.strict``
            is falsy. The mere presence of ``merge_lora``, ``shard``, ``debug``
            and ``prompter`` also selects the corresponding mode/behaviour.

    Raises:
        ValueError: If ``merge_lora`` was passed but the config has no
            ``adapter``, so there is nothing to merge and no dataset to train on.
    """
    print_axolotl_text_art()
    # CLI values arrive as plain strings; normalize so the directory chooser
    # (which calls ``.glob``) always receives a Path.
    config = Path(config)
    if config.is_dir():
        config = choose_config(config)

    # load the config from the yaml file
    with open(config, encoding="utf-8") as file:
        cfg: DictDefault = DictDefault(yaml.safe_load(file))
    # if there are any options passed in the cli, if it is something that seems valid from the yaml,
    # then overwrite the value
    cfg_keys = cfg.keys()
    for k, value in kwargs.items():
        # if not strict, allow writing to cfg even if it's not in the yml already
        if k in cfg_keys or not cfg.strict:
            # handle booleans
            if isinstance(cfg[k], bool):
                cfg[k] = bool(value)
            else:
                cfg[k] = value

    validate_config(cfg)

    normalize_config(cfg)

    setup_wandb_env_vars(cfg)

    # load the tokenizer first
    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
    tokenizer = load_tokenizer(cfg)

    # shard / merge_lora / inference runs don't need the dataset
    needs_dataset = (
        check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
    )
    train_dataset = eval_dataset = total_num_steps = None
    if needs_dataset:
        train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)

    # Only inspect labels when a dataset was actually prepared; otherwise there
    # is nothing to sample from.
    if needs_dataset and (cfg.debug or "debug" in kwargs):
        LOG.info("check_dataset_labels...")
        num_rows = len(train_dataset)
        check_dataset_labels(
            train_dataset.select(
                # randrange(n) covers every row index, including the last one
                [random.randrange(num_rows) for _ in range(5)]  # nosec
            ),
            tokenizer,
        )

    if prepare_ds_only:
        LOG.info("Finished preparing dataset. Exiting...")
        return

    # Load the model and tokenizer
    LOG.info("loading model and (optionally) peft_config...")
    model, peft_config = load_model(cfg, tokenizer)

    safe_serialization = cfg.save_safetensors is True

    if "merge_lora" in kwargs and cfg.adapter is not None:
        LOG.info("running merge of LoRA with base model")
        model = model.merge_and_unload()
        model.to(dtype=torch.float16)

        if cfg.local_rank == 0:
            LOG.info("saving merged model")
            model.save_pretrained(
                str(Path(cfg.output_dir) / "merged"),
                safe_serialization=safe_serialization,
            )
            tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
        return

    if cfg.inference:
        LOG.info("calling do_inference function")
        prompter: Optional[str] = "AlpacaPrompter"
        if "prompter" in kwargs:
            # the literal string "None" on the CLI disables the prompter
            if kwargs["prompter"] == "None":
                prompter = None
            else:
                prompter = kwargs["prompter"]
        do_inference(cfg, model, tokenizer, prompter=prompter)
        return

    if "shard" in kwargs:
        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
        return

    # The only way to get here without a dataset is ``merge_lora`` with no
    # adapter configured; fail with a clear message instead of training on
    # nothing.
    if not needs_dataset:
        raise ValueError(
            "Cannot train: the dataset was not prepared. "
            "`merge_lora` only works when `adapter` is set in the config."
        )

    trainer = setup_trainer(
        cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
    )

    model.config.use_cache = False

    if torch.__version__ >= "2" and sys.platform != "win32":
        LOG.info("Compiling torch model")
        model = torch.compile(model)

    # go ahead and presave, so we have the adapter config available to inspect
    if peft_config:
        LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
        peft_config.save_pretrained(cfg.output_dir)

    # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
    if cfg.local_rank == 0:

        def terminate_handler(_, __, model):
            if cfg.flash_optimum:
                model = BetterTransformer.reverse(model)
            model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
            sys.exit(0)

        signal.signal(
            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
        )

    LOG.info("Starting trainer...")
    if cfg.group_by_length:
        LOG.info("hang tight... sorting dataset for group_by_length")
    resume_from_checkpoint = cfg.resume_from_checkpoint
    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
        possible_checkpoints = [
            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
        ]
        if len(possible_checkpoints) > 0:
            # pick the checkpoint with the highest trailing step number
            sorted_paths = sorted(
                possible_checkpoints,
                key=lambda path: int(path.split("-")[-1]),
            )
            resume_from_checkpoint = sorted_paths[-1]
            LOG.info(
                f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
            )

    if not Path(cfg.output_dir).is_dir():
        os.makedirs(cfg.output_dir, exist_ok=True)
    tokenizer.save_pretrained(cfg.output_dir)
    if cfg.flash_optimum:
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=True, enable_mem_efficient=True
        ):
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
    # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
    if cfg.fsdp:
        trainer.save_model(cfg.output_dir)
    elif cfg.local_rank == 0:
        if cfg.flash_optimum:
            model = BetterTransformer.reverse(model)
        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)


if __name__ == "__main__":
    # Hand ``train`` to python-fire so its parameters (and any extra
    # ``**kwargs`` overrides) can be supplied from the command line.
    fire.Fire(train)