winglian commited on
Commit
c4cf567
2 Parent(s): c49729d 13ac4d8

Merge branch 'main' into quadratic-warmup

.github/workflows/base.yml CHANGED
@@ -12,6 +12,7 @@ jobs:
12
  # this job needs to be run on self-hosted GPU runners...
13
  runs-on: self-hosted
14
  strategy:
 
15
  matrix:
16
  include:
17
  - cuda: "118"
@@ -25,7 +26,7 @@ jobs:
25
  pytorch: 2.0.0
26
  axolotl_extras:
27
  - cuda: "117"
28
- cuda_version: 11.7.0
29
  python_version: "3.9"
30
  pytorch: 1.13.1
31
  axolotl_extras:
 
12
  # this job needs to be run on self-hosted GPU runners...
13
  runs-on: self-hosted
14
  strategy:
15
+ fail-fast: false
16
  matrix:
17
  include:
18
  - cuda: "118"
 
26
  pytorch: 2.0.0
27
  axolotl_extras:
28
  - cuda: "117"
29
+ cuda_version: 11.7.1
30
  python_version: "3.9"
31
  pytorch: 1.13.1
32
  axolotl_extras:
.github/workflows/main.yml CHANGED
@@ -11,6 +11,7 @@ jobs:
11
  if: github.repository_owner == 'OpenAccess-AI-Collective'
12
  # this job needs to be run on self-hosted GPU runners...
13
  strategy:
 
14
  matrix:
15
  include:
16
  - cuda: cu118
@@ -29,7 +30,7 @@ jobs:
29
  pytorch: 2.0.0
30
  axolotl_extras: gptq
31
  - cuda: cu117
32
- cuda_version: 11.7.0
33
  python_version: "3.9"
34
  pytorch: 1.13.1
35
  axolotl_extras:
@@ -84,7 +85,7 @@ jobs:
84
  pytorch: 2.0.0
85
  axolotl_extras: gptq
86
  - cuda: cu117
87
- cuda_version: 11.7.0
88
  python_version: "3.9"
89
  pytorch: 1.13.1
90
  axolotl_extras:
 
11
  if: github.repository_owner == 'OpenAccess-AI-Collective'
12
  # this job needs to be run on self-hosted GPU runners...
13
  strategy:
14
+ fail-fast: false
15
  matrix:
16
  include:
17
  - cuda: cu118
 
30
  pytorch: 2.0.0
31
  axolotl_extras: gptq
32
  - cuda: cu117
33
+ cuda_version: 11.7.1
34
  python_version: "3.9"
35
  pytorch: 1.13.1
36
  axolotl_extras:
 
85
  pytorch: 2.0.0
86
  axolotl_extras: gptq
87
  - cuda: cu117
88
+ cuda_version: 11.7.1
89
  python_version: "3.9"
90
  pytorch: 1.13.1
91
  axolotl_extras:
.github/workflows/tests.yml CHANGED
@@ -7,6 +7,7 @@ jobs:
7
  test:
8
  runs-on: ubuntu-latest
9
  strategy:
 
10
  matrix:
11
  python_version: ["3.9", "3.10"]
12
  timeout-minutes: 10
 
7
  test:
8
  runs-on: ubuntu-latest
9
  strategy:
10
+ fail-fast: false
11
  matrix:
12
  python_version: ["3.9", "3.10"]
13
  timeout-minutes: 10
.pre-commit-config.yaml CHANGED
@@ -1,5 +1,5 @@
1
  default_language_version:
2
- python: python3.9
3
 
4
  repos:
5
  - repo: https://github.com/pre-commit/pre-commit-hooks
 
1
  default_language_version:
2
+ python: python3
3
 
4
  repos:
5
  - repo: https://github.com/pre-commit/pre-commit-hooks
README.md CHANGED
@@ -138,7 +138,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
138
  ```json
139
  {"instruction": "...", "input": "...", "output": "..."}
140
  ```
141
- - `sharegpt`: conversations
142
  ```json
143
  {"conversations": [{"from": "...", "value": "..."}]}
144
  ```
@@ -195,6 +195,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
195
  ```json
196
  {"message_1": "...", "message_2": "..."}
197
  ```
 
 
 
 
198
  - `context_qa`: in context question answering from an article
199
  ```json
200
  {"article": "...", "question": "...", "answer": "..."}
@@ -233,7 +237,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
233
  #### How to add custom prompts
234
 
235
  1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
236
- 2. Use your custom file name as the dataset type.
237
 
238
  Optionally, download some datasets, see [data/README.md](data/README.md)
239
 
@@ -251,10 +255,18 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
251
 
252
  - dataset
253
  ```yaml
 
 
 
254
  datasets:
255
- - path: vicgalle/alpaca-gpt4 # local or huggingface repo
 
 
 
 
 
 
256
  type: alpaca # format from earlier
257
- sequence_len: 2048 # max token length / prompt
258
  ```
259
 
260
  - loading
@@ -264,6 +276,8 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
264
  bf16: true # require >=ampere
265
  fp16: true
266
  tf32: true # require >=ampere
 
 
267
  ```
268
  Note: Repo does not do 4-bit quantization.
269
 
@@ -300,6 +314,8 @@ model_type: AutoModelForCausalLM
300
  tokenizer_type: AutoTokenizer
301
  # Trust remote code for untrusted source
302
  trust_remote_code:
 
 
303
 
304
  # whether you are training a 4-bit GPTQ quantized model
305
  gptq: true
@@ -320,10 +336,10 @@ tf32: true # require >=ampere
320
 
321
  # a list of one or more datasets to finetune the model with
322
  datasets:
323
- # this can be either a hf dataset, or relative path
324
  - path: vicgalle/alpaca-gpt4
325
  # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
326
- type: alpaca # format OR format:prompt_style (chat/instruct)
327
  data_files: # path to source data files
328
  shards: # number of shards to split data into
329
 
@@ -332,6 +348,8 @@ datasets:
332
  dataset_prepared_path: data/last_run_prepared
333
  # push prepared dataset to hub
334
  push_dataset_to_hub: # repo path
 
 
335
  # whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
336
  # required to be true when used in combination with `push_dataset_to_hub`
337
  hf_use_auth_token: # boolean
@@ -420,7 +438,15 @@ log_sweep_max_lr:
420
  optimizer:
421
  # specify weight decay
422
  weight_decay:
423
-
 
 
 
 
 
 
 
 
424
  # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
425
  xformers_attention:
426
  # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
@@ -500,16 +526,16 @@ Pass the appropriate flag to the train command:
500
 
501
  - Pretrained LORA:
502
  ```bash
503
- --inference --lora_model_dir ./completed-model
504
  ```
505
  - Full weights finetune:
506
  ```bash
507
- --inference --base_model ./completed-model
508
  ```
509
  - Full weights finetune w/ a prompt from a text file:
510
  ```bash
511
  cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
512
- --base_model ./completed-model --inference --prompter=None --load_in_8bit=True
513
  ```
514
 
515
  ### Merge LORA to base
@@ -520,6 +546,12 @@ Add below flag to train command above
520
  --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
521
  ```
522
 
 
 
 
 
 
 
523
  ## Common Errors 🧰
524
 
525
  > Cuda out of memory
@@ -552,6 +584,16 @@ Building something cool with Axolotl? Consider adding a badge to your model card
552
 
553
  [<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
554
 
 
 
 
 
 
 
 
 
 
 
555
  ## Contributing 🤝
556
 
557
  Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
 
138
  ```json
139
  {"instruction": "...", "input": "...", "output": "..."}
140
  ```
141
+ - `sharegpt:chat`: conversations
142
  ```json
143
  {"conversations": [{"from": "...", "value": "..."}]}
144
  ```
 
195
  ```json
196
  {"message_1": "...", "message_2": "..."}
197
  ```
198
+ - `alpaca_w_system.load_open_orca`: support for Open Orca datasets with included system prompts, instruct style
199
+ ```json
200
+ {"system_prompt": "...", "question": "...", "response": "..."}
201
+ ```
202
  - `context_qa`: in context question answering from an article
203
  ```json
204
  {"article": "...", "question": "...", "answer": "..."}
 
237
  #### How to add custom prompts
238
 
239
  1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see the other files there for examples.
240
+ 2. Use your custom file name as the dataset type, in the form `<prompt_strategies_file>.load_<load_fn>` (see the sketch below).
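For illustration only, a minimal strategy module might look like the sketch below. The file name `my_strategy.py` and the function name `load_custom` are hypothetical; the body mirrors the existing `alpaca_instruct` strategy, so the corresponding dataset type would be `my_strategy.load_custom`.

```python
# src/axolotl/prompt_strategies/my_strategy.py  (hypothetical example)
"""Minimal custom prompt strategy, modeled on alpaca_instruct.py"""

from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle


def load_custom(tokenizer, cfg):
    # reuse the stock Alpaca prompter/tokenizing strategy; swap in your own
    # prompter or strategy classes here as needed
    return AlpacaPromptTokenizingStrategy(
        AlpacaPrompter(PromptStyle.INSTRUCT.value),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
```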
241
 
242
  Optionally, download some datasets, see [data/README.md](data/README.md)
243
 
 
255
 
256
  - dataset
257
  ```yaml
258
+ sequence_len: 2048 # max token length for prompt
259
+
260
+ # huggingface repo
261
  datasets:
262
+ - path: vicgalle/alpaca-gpt4
263
+ type: alpaca # format from earlier
264
+
265
+ # local
266
+ datasets:
267
+ - path: json
268
+ data_files: data.jsonl # or json
269
  type: alpaca # format from earlier
 
270
  ```
271
 
272
  - loading
 
276
  bf16: true # require >=ampere
277
  fp16: true
278
  tf32: true # require >=ampere
279
+ bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
280
+ float16: true # use instead of fp16 when you don't want AMP
281
  ```
282
  Note: Repo does not do 4-bit quantization.
283
 
 
314
  tokenizer_type: AutoTokenizer
315
  # Trust remote code for untrusted source
316
  trust_remote_code:
317
+ # use_fast option for tokenizer loading from_pretrained, defaults to True
318
+ tokenizer_use_fast:
319
 
320
  # whether you are training a 4-bit GPTQ quantized model
321
  gptq: true
 
336
 
337
  # a list of one or more datasets to finetune the model with
338
  datasets:
339
+ # hf dataset repo | "json" for local dataset, make sure to fill data_files
340
  - path: vicgalle/alpaca-gpt4
341
  # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
342
+ type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
343
  data_files: # path to source data files
344
  shards: # number of shards to split data into
345
 
 
348
  dataset_prepared_path: data/last_run_prepared
349
  # push prepared dataset to hub
350
  push_dataset_to_hub: # repo path
351
+ # push checkpoints to hub
352
+ hub_model_id: # repo path
353
  # whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
354
  # required to be true when used in combination with `push_dataset_to_hub`
355
  hf_use_auth_token: # boolean
 
438
  optimizer:
439
  # specify weight decay
440
  weight_decay:
441
+ # adamw hyperparams
442
+ adam_beta1:
443
+ adam_beta2:
444
+ adam_epsilon:
445
+ # Gradient clipping max norm
446
+ max_grad_norm:
447
+
448
+ # whether to use BetterTransformers (via optimum)
449
+ flash_optimum:
450
  # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
451
  xformers_attention:
452
  # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
 
526
 
527
  - Pretrained LORA:
528
  ```bash
529
+ --inference --lora_model_dir="./lora-output-dir"
530
  ```
531
  - Full weights finetune:
532
  ```bash
533
+ --inference --base_model="./completed-model"
534
  ```
535
  - Full weights finetune w/ a prompt from a text file:
536
  ```bash
537
  cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
538
+ --base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
539
  ```
540
 
541
  ### Merge LORA to base
 
546
  --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
547
  ```
548
 
549
+ If you run out of CUDA memory, you can try to merge in system RAM with
550
+
551
+ ```bash
552
+ CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
553
+ ```
554
+
555
  ## Common Errors 🧰
556
 
557
  > Cuda out of memory
 
584
 
585
  [<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
586
 
587
+ ## Community Showcase
588
+
589
+ Open Access AI Collective
590
+ - [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b)
591
+ - [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
592
+ - [Hippogriff 30b](https://huggingface.co/openaccess-ai-collective/hippogriff-30b-chat)
593
+
594
+ PocketDoc Labs
595
+ - [Dan's PersonalityEngine 13b LoRA](https://huggingface.co/PocketDoc/Dans-PersonalityEngine-13b-LoRA)
596
+
597
  ## Contributing 🤝
598
 
599
  Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
data/README.md CHANGED
@@ -10,10 +10,10 @@ curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarit
10
  ## Convert the JSON data files to JSONL.
11
 
12
  ```shell
13
- python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
14
- python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
15
- python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
16
- python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
17
  ```
18
  ---
19
 
 
10
  ## Convert the JSON data files to JSONL.
11
 
12
  ```shell
13
+ python3 ./scripts/alpaca_json_to_jsonl.py --file data/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
14
+ python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/vicuna_cleaned.json --output data/vicuna_cleaned.jsonl
15
+ python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/roleplay-similarity_0.6-instruct-dataset.json --output data/roleplay-similarity_0.6-instruct-dataset.jsonl
16
+ python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/gpt4-instruct-similarity-0.6-dataset.json --output data/gpt4-instruct-similarity-0.6-dataset.jsonl
17
  ```
18
  ---
19
 
docker/Dockerfile-base CHANGED
@@ -77,7 +77,7 @@ FROM base-builder
77
  RUN python3 -m pip uninstall -y apex
78
  RUN git clone https://github.com/NVIDIA/apex
79
  # `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
80
- RUN cd apex && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --global-option="--cuda_ext" --no-cache -v --disable-pip-version-check .
81
 
82
  RUN mkdir -p /workspace/builds
83
  COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
@@ -97,4 +97,4 @@ RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
97
  RUN git lfs install --skip-repo
98
  RUN pip3 install awscli && \
99
  # The base image ships with `pydantic==1.8.2` which is not working
100
- pip3 install -U --no-cache-dir pydantic
 
77
  RUN python3 -m pip uninstall -y apex
78
  RUN git clone https://github.com/NVIDIA/apex
79
  # `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
80
+ RUN cd apex && MAX_JOBS=1 python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
81
 
82
  RUN mkdir -p /workspace/builds
83
  COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
 
97
  RUN git lfs install --skip-repo
98
  RUN pip3 install awscli && \
99
  # The base image ships with `pydantic==1.8.2` which is not working
100
+ pip3 install -U --no-cache-dir pydantic==1.10.10
examples/openllama-3b/config.yml CHANGED
@@ -26,17 +26,18 @@ wandb_watch:
26
  wandb_run_id:
27
  wandb_log_model:
28
  output_dir: ./openllama-out
29
- batch_size: 16
30
- micro_batch_size: 4
31
  num_epochs: 3
32
  optimizer: adamw_bnb_8bit
33
  torchdistx_path:
34
  lr_scheduler: cosine
35
- learning_rate: 0.0002
36
  train_on_inputs: false
37
  group_by_length: false
 
38
  bf16: false
39
- fp16: true
40
  tf32: false
41
  gradient_checkpointing: true
42
  early_stopping_patience:
@@ -52,7 +53,7 @@ eval_steps: 50
52
  save_steps:
53
  debug:
54
  deepspeed:
55
- weight_decay: 0.0
56
  fsdp:
57
  fsdp_config:
58
  special_tokens:
 
26
  wandb_run_id:
27
  wandb_log_model:
28
  output_dir: ./openllama-out
29
+ gradient_accumulation_steps: 1
30
+ micro_batch_size: 1
31
  num_epochs: 3
32
  optimizer: adamw_bnb_8bit
33
  torchdistx_path:
34
  lr_scheduler: cosine
35
+ learning_rate: 0.00001
36
  train_on_inputs: false
37
  group_by_length: false
38
+ float16: true
39
  bf16: false
40
+ fp16: false
41
  tf32: false
42
  gradient_checkpointing: true
43
  early_stopping_patience:
 
53
  save_steps:
54
  debug:
55
  deepspeed:
56
+ weight_decay: 0.1
57
  fsdp:
58
  fsdp_config:
59
  special_tokens:
examples/pythia-12b/README.md ADDED
@@ -0,0 +1,9 @@
1
+ # Pythia 12B
2
+
3
+ - Single-GPU A100 only (?)
4
+
5
+ ```shell
6
+ python scripts/finetune.py examples/pythia-12b/config.yml
7
+ ```
8
+
9
+ ⚠️ Multi-GPU A100 - doesn't seem to work without causing OOM! ⚠️
examples/pythia-12b/config.yml ADDED
@@ -0,0 +1,49 @@
1
+ base_model: EleutherAI/pythia-12b-deduped
2
+ base_model_config: EleutherAI/pythia-12b-deduped
3
+ base_model_ignore_patterns: pytorch* # prefer safetensors
4
+ model_type: GPTNeoXForCausalLM
5
+ tokenizer_type: AutoTokenizer
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ gptq: false
9
+ device_map: auto
10
+ datasets:
11
+ - path: vicgalle/alpaca-gpt4
12
+ type: alpaca
13
+ dataset_prepared_path: last_run_prepared
14
+ val_set_size: 0.05
15
+ adapter:
16
+ lora_model_dir:
17
+ sequence_len: 2048
18
+ max_packed_sequence_len: 2048
19
+ lora_r: 64
20
+ lora_alpha: 32
21
+ lora_dropout: 0.0
22
+ lora_target_modules:
23
+ lora_target_linear: true
24
+ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
25
+ wandb_project:
26
+ wandb_watch:
27
+ wandb_run_id:
28
+ wandb_log_model:
29
+ output_dir: ./pythia-12b
30
+ gradient_accumulation_steps: 1
31
+ micro_batch_size: 1
32
+ num_epochs: 5
33
+ learning_rate: 0.00003
34
+ optimizer: adamw_bnb_8bit
35
+ lr_scheduler: cosine
36
+ train_on_inputs: false
37
+ group_by_length: false
38
+ bf16: false
39
+ fp16: false
40
+ float16: true
41
+ tf32: true
42
+ flash_optimum: true
43
+ early_stopping_patience:
44
+ resume_from_checkpoint:
45
+ local_rank:
46
+ gradient_checkpointing: true
47
+ fsdp:
48
+ fsdp_config:
49
+ collator_pad_to_longest: true
examples/redpajama/config-3b.yml CHANGED
@@ -1,7 +1,7 @@
1
  base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
2
  base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
3
  model_type: GPTNeoXForCausalLM
4
- tokenizer_type: GPTNeoXTokenizer
5
  trust_remote_code:
6
  load_in_8bit: false
7
  datasets:
 
1
  base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
2
  base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
3
  model_type: GPTNeoXForCausalLM
4
+ tokenizer_type: AutoTokenizer
5
  trust_remote_code:
6
  load_in_8bit: false
7
  datasets:
requirements.txt CHANGED
@@ -11,6 +11,7 @@ sentencepiece
11
  wandb
12
  einops
13
  xformers
 
14
  # qlora things
15
  bert-score==0.3.13
16
  evaluate==0.4.0
 
11
  wandb
12
  einops
13
  xformers
14
+ optimum
15
  # qlora things
16
  bert-score==0.3.13
17
  evaluate==0.4.0
scripts/finetune.py CHANGED
@@ -12,13 +12,14 @@ from typing import Any, Dict, List, Optional, Union
12
  import fire
13
  import torch
14
  import yaml
 
 
 
15
  from transformers import GenerationConfig, TextStreamer
16
 
17
- from axolotl.utils.data import load_prepare_datasets
18
  from axolotl.utils.dict import DictDefault
19
  from axolotl.utils.models import load_model, load_tokenizer
20
-
21
- # add src to the pythonpath so we don't need to pip install this
22
  from axolotl.utils.tokenization import check_dataset_labels
23
  from axolotl.utils.trainer import setup_trainer
24
  from axolotl.utils.validation import validate_config
@@ -63,7 +64,7 @@ def get_multi_line_input() -> Optional[str]:
63
  return instruction
64
 
65
 
66
- def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
67
  default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
68
 
69
  for token, symbol in default_tokens.items():
@@ -217,9 +218,20 @@ def train(
217
  if (
218
  check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
219
  ): # don't need to load dataset for these
220
- train_dataset, eval_dataset = load_prepare_datasets(
221
- tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
222
- )
223
 
224
  if cfg.debug or "debug" in kwargs:
225
  logging.info("check_dataset_labels...")
@@ -257,13 +269,13 @@ def train(
257
 
258
  if cfg.inference:
259
  logging.info("calling do_inference function")
260
- inf_kwargs: Dict[str, Any] = {}
261
  if "prompter" in kwargs:
262
  if kwargs["prompter"] == "None":
263
- inf_kwargs["prompter"] = None
264
  else:
265
- inf_kwargs["prompter"] = kwargs["prompter"]
266
- do_inference(cfg, model, tokenizer, **inf_kwargs)
267
  return
268
 
269
  if "shard" in kwargs:
@@ -285,12 +297,15 @@ def train(
285
 
286
  # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
287
  if cfg.local_rank == 0:
 
 
 
 
 
 
 
288
  signal.signal(
289
- signal.SIGINT,
290
- lambda signal, frame: (
291
- model.save_pretrained(cfg.output_dir),
292
- sys.exit(0),
293
- ),
294
  )
295
 
296
  logging.info("Starting trainer...")
@@ -313,13 +328,21 @@ def train(
313
 
314
  if not Path(cfg.output_dir).is_dir():
315
  os.makedirs(cfg.output_dir, exist_ok=True)
316
- trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
 
 
 
 
 
317
 
318
  logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
319
 
320
  # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
321
  # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
322
  if cfg.local_rank == 0:
 
 
323
  model.save_pretrained(cfg.output_dir)
324
 
325
  # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
 
12
  import fire
13
  import torch
14
  import yaml
15
+
16
+ # add src to the pythonpath so we don't need to pip install this
17
+ from optimum.bettertransformer import BetterTransformer
18
  from transformers import GenerationConfig, TextStreamer
19
 
20
+ from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
21
  from axolotl.utils.dict import DictDefault
22
  from axolotl.utils.models import load_model, load_tokenizer
 
 
23
  from axolotl.utils.tokenization import check_dataset_labels
24
  from axolotl.utils.trainer import setup_trainer
25
  from axolotl.utils.validation import validate_config
 
64
  return instruction
65
 
66
 
67
+ def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
68
  default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
69
 
70
  for token, symbol in default_tokens.items():
 
218
  if (
219
  check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
220
  ): # don't need to load dataset for these
221
+ if not cfg.pretraining_dataset:
222
+ train_dataset, eval_dataset = load_prepare_datasets(
223
+ tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
224
+ )
225
+ else:
226
+ train_dataset = load_pretraining_dataset(
227
+ cfg.pretraining_dataset,
228
+ tokenizer,
229
+ max_tokens=cfg.sequence_len,
230
+ seed=cfg.seed,
231
+ )
232
+ # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
233
+ train_dataset = train_dataset.with_format("torch")
234
+ eval_dataset = None
235
 
236
  if cfg.debug or "debug" in kwargs:
237
  logging.info("check_dataset_labels...")
 
269
 
270
  if cfg.inference:
271
  logging.info("calling do_inference function")
272
+ prompter: Optional[str] = "AlpacaPrompter"
273
  if "prompter" in kwargs:
274
  if kwargs["prompter"] == "None":
275
+ prompter = None
276
  else:
277
+ prompter = kwargs["prompter"]
278
+ do_inference(cfg, model, tokenizer, prompter=prompter)
279
  return
280
 
281
  if "shard" in kwargs:
 
297
 
298
  # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
299
  if cfg.local_rank == 0:
300
+
301
+ def terminate_handler(_, __, model):
302
+ if cfg.flash_optimum:
303
+ model = BetterTransformer.reverse(model)
304
+ model.save_pretrained(cfg.output_dir)
305
+ sys.exit(0)
306
+
307
  signal.signal(
308
+ signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
 
 
 
 
309
  )
310
 
311
  logging.info("Starting trainer...")
 
328
 
329
  if not Path(cfg.output_dir).is_dir():
330
  os.makedirs(cfg.output_dir, exist_ok=True)
331
+ if cfg.flash_optimum:
332
+ with torch.backends.cuda.sdp_kernel(
333
+ enable_flash=True, enable_math=True, enable_mem_efficient=True
334
+ ):
335
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
336
+ else:
337
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
338
 
339
  logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
340
 
341
  # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
342
  # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
343
  if cfg.local_rank == 0:
344
+ if cfg.flash_optimum:
345
+ model = BetterTransformer.reverse(model)
346
  model.save_pretrained(cfg.output_dir)
347
 
348
  # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
src/axolotl/datasets.py CHANGED
@@ -126,6 +126,7 @@ class ConstantLengthDataset(IterableDataset):
126
  buffer_len = 0
127
 
128
  if example:
 
129
  # just going to drop data points that are too long
130
  if len(example["input_ids"]) <= self.seq_length:
131
  input_ids = example["input_ids"]
 
126
  buffer_len = 0
127
 
128
  if example:
129
+ # FIXME
130
  # just going to drop data points that are too long
131
  if len(example["input_ids"]) <= self.seq_length:
132
  input_ids = example["input_ids"]
src/axolotl/prompt_strategies/alpaca_chat.py CHANGED
@@ -6,7 +6,7 @@ from axolotl.prompt_tokenizers import (
6
  AlpacaPromptTokenizingStrategy,
7
  InstructionPromptTokenizingStrategy,
8
  )
9
- from axolotl.prompters import AlpacaPrompter, PromptStyle
10
 
11
 
12
  def load(tokenizer, cfg):
@@ -20,11 +20,38 @@ def load(tokenizer, cfg):
20
 
21
  class AlpacaConcisePrompter(AlpacaPrompter):
22
  """
23
- Alpaca Prompter extending the system prompt to ask for concise answers
24
  """
25
 
26
- system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that concisely and appropriately completes the request.\n\n"
27
- system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately and concisely completes the request.\n\n"
28
 
29
 
30
  class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
@@ -64,7 +91,7 @@ def load_concise(tokenizer, cfg):
64
 
65
  def load_qa(tokenizer, cfg):
66
  return AlpacaQAPromptTokenizingStrategy(
67
- AlpacaPrompter(PromptStyle.CHAT.value),
68
  tokenizer,
69
  cfg.train_on_inputs,
70
  cfg.sequence_len,
@@ -73,7 +100,16 @@ def load_qa(tokenizer, cfg):
73
 
74
  def load_camel_ai(tokenizer, cfg):
75
  return CamelAIPromptTokenizingStrategy(
76
- AlpacaPrompter(PromptStyle.CHAT.value),
 
 
 
 
 
 
 
 
 
77
  tokenizer,
78
  cfg.train_on_inputs,
79
  cfg.sequence_len,
 
6
  AlpacaPromptTokenizingStrategy,
7
  InstructionPromptTokenizingStrategy,
8
  )
9
+ from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
10
 
11
 
12
  def load(tokenizer, cfg):
 
20
 
21
  class AlpacaConcisePrompter(AlpacaPrompter):
22
  """
23
+ Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
24
  """
25
 
26
+ system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
27
+ system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
28
+
29
+
30
+ class AlpacaChatPrompter(AlpacaPrompter):
31
+ """
32
+ Alpaca Chat Prompter extending the system prompt for chat-instruct answers
33
+ """
34
+
35
+ system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
36
+ system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
37
+
38
+ def __init__(self): # pylint: disable=super-init-not-called
39
+ self.prompt_style = PromptStyle.CHAT.value
40
+ self.match_prompt_style()
41
+
42
+
43
+ class NoSystemPrompter(AlpacaPrompter):
44
+ """
45
+ Null Prompter with no system prompts
46
+ """
47
+
48
+ system_prompt = ""
49
+ system_no_input_prompt = ""
50
+ turn_format = "{instruction} {input} "
51
+ turn_no_input_format = "{instruction} "
52
+
53
+ def __init__(self): # pylint: disable=super-init-not-called
54
+ pass
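As a quick illustration (not part of this commit's code), the new `NoSystemPrompter` emits bare prompts straight from its turn templates; a rough usage sketch, assuming the class is imported from `axolotl.prompt_strategies.alpaca_chat`:

```python
# hypothetical usage sketch for NoSystemPrompter
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter

prompter = NoSystemPrompter()
# build_prompt (inherited from AlpacaPrompter) yields the formatted turn;
# with empty system prompts the result is just "{instruction} {input} "
prompt = next(prompter.build_prompt("Summarize this text.", "Axolotl trains LLMs."))
print(prompt)  # "Summarize this text. Axolotl trains LLMs. "
```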
55
 
56
 
57
  class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
 
91
 
92
  def load_qa(tokenizer, cfg):
93
  return AlpacaQAPromptTokenizingStrategy(
94
+ AlpacaChatPrompter(),
95
  tokenizer,
96
  cfg.train_on_inputs,
97
  cfg.sequence_len,
 
100
 
101
  def load_camel_ai(tokenizer, cfg):
102
  return CamelAIPromptTokenizingStrategy(
103
+ AlpacaChatPrompter(),
104
+ tokenizer,
105
+ cfg.train_on_inputs,
106
+ cfg.sequence_len,
107
+ )
108
+
109
+
110
+ def load_no_prompt(tokenizer, cfg):
111
+ return AlpacaPromptTokenizingStrategy(
112
+ UnpromptedPrompter(PromptStyle.CHAT.value),
113
  tokenizer,
114
  cfg.train_on_inputs,
115
  cfg.sequence_len,
src/axolotl/prompt_strategies/alpaca_instruct.py CHANGED
@@ -1,7 +1,7 @@
1
  """Module loading the AlpacaInstructPromptTokenizingStrategy class"""
2
 
3
  from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
4
- from axolotl.prompters import AlpacaPrompter, PromptStyle
5
 
6
 
7
  def load(tokenizer, cfg):
@@ -11,3 +11,12 @@ def load(tokenizer, cfg):
11
  cfg.train_on_inputs,
12
  cfg.sequence_len,
13
  )
 
 
 
 
 
 
 
 
 
 
1
  """Module loading the AlpacaInstructPromptTokenizingStrategy class"""
2
 
3
  from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
4
+ from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
5
 
6
 
7
  def load(tokenizer, cfg):
 
11
  cfg.train_on_inputs,
12
  cfg.sequence_len,
13
  )
14
+
15
+
16
+ def load_no_prompt(tokenizer, cfg):
17
+ return AlpacaPromptTokenizingStrategy(
18
+ UnpromptedPrompter(PromptStyle.INSTRUCT.value),
19
+ tokenizer,
20
+ cfg.train_on_inputs,
21
+ cfg.sequence_len,
22
+ )
src/axolotl/prompt_strategies/alpaca_w_system.py ADDED
@@ -0,0 +1,120 @@
1
+ """
2
+ Prompt strategies loader for alpaca instruction datasets with system prompts
3
+ """
4
+ from typing import Generator, Tuple, Union
5
+
6
+ from axolotl.prompt_tokenizers import PromptTokenizingStrategy
7
+ from axolotl.prompters import AlpacaPrompter, PromptStyle
8
+
9
+
10
+ class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
11
+ """
12
+ Tokenizing strategy for instruction-based prompts that also carry a system prompt.
13
+ """
14
+
15
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
16
+ return (
17
+ prompt["instruction"],
18
+ prompt["input"] if "input" in prompt else "",
19
+ prompt["output"],
20
+ prompt["system"],
21
+ )
22
+
23
+ def tokenize_prompt(self, prompt):
24
+ # pylint: disable=duplicate-code
25
+ (
26
+ instruction,
27
+ input, # pylint: disable=redefined-builtin
28
+ response,
29
+ system,
30
+ ) = self.parse_instruction_fields(prompt)
31
+ user_prompt = next(
32
+ iter(
33
+ self.prompter.build_prompt_w_system(
34
+ system,
35
+ instruction,
36
+ input,
37
+ )
38
+ )
39
+ )
40
+ tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
41
+ if not self.train_on_inputs:
42
+ user_prompt_len = len(tokenized_prompt["input_ids"])
43
+ # TODO this could be sped up using numpy array slicing
44
+ tokenized_prompt["labels"] = [-100] * user_prompt_len
45
+ tokenized_res_prompt = self._tokenize(
46
+ response, strip_bos_token=True, add_eos_token=True
47
+ )
48
+ tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
49
+ tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
50
+ tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
51
+
52
+ return tokenized_prompt
53
+
54
+
55
+ class SystemDataPrompter(AlpacaPrompter):
56
+ """
57
+ Alpaca Style Prompter that uses system prompts from the dataset
58
+ """
59
+
60
+ def build_prompt_w_system(
61
+ self,
62
+ system: str,
63
+ instruction: str,
64
+ input: Union[None, str] = None, # pylint: disable=redefined-builtin
65
+ output: Union[None, str] = None,
66
+ ) -> Generator[str, None, None]:
67
+ # returns the full prompt from instruction and optional input
68
+ # if a label (=response, =output) is provided, it's also appended.
69
+ if input:
70
+ res = system + self.turn_format.format(instruction=instruction, input=input)
71
+ else:
72
+ res = system + self.turn_no_input_format.format(instruction=instruction)
73
+ if output:
74
+ res = f"{res}{output}"
75
+ yield res
76
+
77
+
78
+ class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
79
+ """
80
+ Tokenizing strategy for OpenOrca datasets
81
+ """
82
+
83
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
84
+ return (
85
+ prompt["question"],
86
+ "",
87
+ prompt["response"],
88
+ prompt["system_prompt"],
89
+ )
90
+
91
+
92
+ def load(tokenizer, cfg):
93
+ return load_chat(tokenizer, cfg)
94
+
95
+
96
+ def load_instruct(tokenizer, cfg):
97
+ return InstructionWSystemPromptTokenizingStrategy(
98
+ SystemDataPrompter(PromptStyle.INSTRUCT.value),
99
+ tokenizer,
100
+ cfg.train_on_inputs,
101
+ cfg.sequence_len,
102
+ )
103
+
104
+
105
+ def load_chat(tokenizer, cfg):
106
+ return InstructionWSystemPromptTokenizingStrategy(
107
+ SystemDataPrompter(PromptStyle.CHAT.value),
108
+ tokenizer,
109
+ cfg.train_on_inputs,
110
+ cfg.sequence_len,
111
+ )
112
+
113
+
114
+ def load_open_orca(tokenizer, cfg):
115
+ return OpenOrcaPromptTokenizingStrategy(
116
+ SystemDataPrompter(PromptStyle.INSTRUCT.value),
117
+ tokenizer,
118
+ cfg.train_on_inputs,
119
+ cfg.sequence_len,
120
+ )
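A small usage sketch for the new Open Orca strategy (illustrative only; it assumes a `tokenizer` and a `cfg` with `train_on_inputs` and `sequence_len` already exist, which is how `scripts/finetune.py` wires strategies up):

```python
# hypothetical wiring sketch for alpaca_w_system.load_open_orca
from axolotl.prompt_strategies.alpaca_w_system import load_open_orca

strategy = load_open_orca(tokenizer, cfg)  # tokenizer/cfg assumed to exist
row = {
    "system_prompt": "You are a helpful assistant.",
    "question": "What is 2 + 2?",
    "response": "4",
}
# parse_instruction_fields maps question -> instruction, response -> output,
# and system_prompt -> system; when cfg.train_on_inputs is False the prompt
# tokens are masked with -100 in the labels
tokenized = strategy.tokenize_prompt(row)  # dict of input_ids / attention_mask / labels
```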
src/axolotl/prompt_tokenizers.py CHANGED
@@ -87,7 +87,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
87
  Tokenizing strategy for instruction-based prompts.
88
  """
89
 
90
- def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
 
 
91
  raise NotImplementedError
92
 
93
  def tokenize_prompt(self, prompt):
@@ -96,25 +98,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
96
  input, # pylint: disable=redefined-builtin
97
  response,
98
  ) = self.parse_instruction_fields(prompt)
99
- full_prompt = self._build_full_prompt(instruction, input, response)
100
- tokenized_full_prompt = self._tokenize(full_prompt)
101
- if not self.train_on_inputs:
102
- user_prompt = next(
103
- iter(
104
- self.prompter.build_prompt(
105
- instruction,
106
- input,
107
- )
108
  )
109
  )
110
- tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
111
- user_prompt_len = len(tokenized_user_prompt["input_ids"])
 
 
112
  # TODO this could be sped up using numpy array slicing
113
- tokenized_full_prompt["labels"] = [
114
- -100
115
- ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
 
 
 
 
116
 
117
- return tokenized_full_prompt
118
 
119
  def _build_full_prompt(
120
  self, instruction, input, response # pylint: disable=redefined-builtin
@@ -436,7 +440,7 @@ def parse_tokenized_to_result(
436
  result: Dict[str, List[int]],
437
  current_len: int,
438
  res: Dict[str, List[int]],
439
- labels: list[int],
440
  pad_token_id: Union[int, None] = None,
441
  ) -> Tuple[Dict[str, List[int]], int]:
442
  """
 
87
  Tokenizing strategy for instruction-based prompts.
88
  """
89
 
90
+ def parse_instruction_fields(
91
+ self, prompt
92
+ ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
93
  raise NotImplementedError
94
 
95
  def tokenize_prompt(self, prompt):
 
98
  input, # pylint: disable=redefined-builtin
99
  response,
100
  ) = self.parse_instruction_fields(prompt)
101
+ user_prompt = next(
102
+ iter(
103
+ self.prompter.build_prompt(
104
+ instruction,
105
+ input,
 
 
 
 
106
  )
107
  )
108
+ )
109
+ tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
110
+ if not self.train_on_inputs:
111
+ user_prompt_len = len(tokenized_prompt["input_ids"])
112
  # TODO this could be sped up using numpy array slicing
113
+ tokenized_prompt["labels"] = [-100] * user_prompt_len
114
+ tokenized_res_prompt = self._tokenize(
115
+ response, strip_bos_token=True, add_eos_token=True
116
+ )
117
+ tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
118
+ tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
119
+ tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
120
 
121
+ return tokenized_prompt
122
 
123
  def _build_full_prompt(
124
  self, instruction, input, response # pylint: disable=redefined-builtin
 
440
  result: Dict[str, List[int]],
441
  current_len: int,
442
  res: Dict[str, List[int]],
443
+ labels: List[int],
444
  pad_token_id: Union[int, None] = None,
445
  ) -> Tuple[Dict[str, List[int]], int]:
446
  """
src/axolotl/prompters.py CHANGED
@@ -24,6 +24,8 @@ class AlpacaPrompter:
24
 
25
  system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
26
  system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
 
 
27
  prompt_style: Optional[PromptStyle] = None
28
 
29
  def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
@@ -32,23 +34,13 @@ class AlpacaPrompter:
32
 
33
  def match_prompt_style(self):
34
  if self.prompt_style == PromptStyle.INSTRUCT.value:
35
- self.prompt_input = (
36
- self.system_prompt
37
- + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
38
- )
39
- self.prompt_no_input = (
40
- self.system_no_input_prompt
41
- + "### Instruction:\n{instruction}\n\n### Response:\n"
42
  )
43
- self.response_split = "### Response:"
44
  if self.prompt_style == PromptStyle.CHAT.value:
45
- self.prompt_input = (
46
- self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
47
- )
48
- self.prompt_no_input = (
49
- self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
50
- )
51
- self.response_split = "ASSISTANT:"
52
 
53
  def build_prompt(
54
  self,
@@ -59,16 +51,17 @@ class AlpacaPrompter:
59
  # returns the full prompt from instruction and optional input
60
  # if a label (=response, =output) is provided, it's also appended.
61
  if input:
62
- res = self.prompt_input.format(instruction=instruction, input=input)
 
 
63
  else:
64
- res = self.prompt_no_input.format(instruction=instruction)
 
 
65
  if output:
66
  res = f"{res}{output}"
67
  yield res
68
 
69
- def get_response(self, output: str) -> str:
70
- return output.split(self.response_split)[1].strip()
71
-
72
 
73
  class UnpromptedPrompter(AlpacaPrompter):
74
  """
@@ -93,7 +86,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
93
  """
94
 
95
  system_prompt = (
96
- "Choose the answer that best answers the question. Explain your reasoning."
 
 
 
97
  )
98
 
99
 
@@ -102,7 +98,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
102
  Prompter for multiple choice concise
103
  """
104
 
105
- prompt_input = "Choose the answer that best answers the question. Be concise in your response.\n\nUSER: {instruction}\n{input}\nASSISTANT:\n"
 
 
 
 
 
106
 
107
 
108
  class SummarizeTLDRPrompter(AlpacaPrompter):
@@ -110,9 +111,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
110
  Prompter for summarize TLDR
111
  """
112
 
113
- prompt_no_input = (
114
- "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
115
- )
 
 
 
116
 
117
 
118
  class CompletionPrompter:
@@ -128,9 +132,6 @@ class CompletionPrompter:
128
  ) -> Generator[str, None, None]:
129
  yield instruction
130
 
131
- def get_response(self, output: str) -> str:
132
- return output.strip()
133
-
134
 
135
  class GPTeacherPrompter(AlpacaPrompter):
136
  """
@@ -210,9 +211,6 @@ class ReflectAlpacaPrompter:
210
  res = f"{res}{label}"
211
  yield res
212
 
213
- def get_response(self, output: str) -> str:
214
- return output.split(self.response_split)[1].strip()
215
-
216
 
217
  class SeparatorStyle(Enum):
218
  """Different separator style."""
@@ -289,12 +287,6 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
289
  sep2=" ",
290
  )
291
 
292
- # def match_prompt_style(self):
293
- # if self.prompt_style == PromptStyle.chat.value:
294
- # self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
295
- # self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
296
- # self.response_split = "ASSISTANT:"
297
-
298
  def build_prompt(self, source) -> Generator[str, None, None]:
299
  # ignore the system prompt if provided
300
  if source[0]["from"] == "system":
 
24
 
25
  system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
26
  system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
27
+ turn_format: str
28
+ turn_no_input_format: str
29
  prompt_style: Optional[PromptStyle] = None
30
 
31
  def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
 
34
 
35
  def match_prompt_style(self):
36
  if self.prompt_style == PromptStyle.INSTRUCT.value:
37
+ self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
38
+ self.turn_no_input_format = (
39
+ "### Instruction:\n{instruction}\n\n### Response:\n"
 
 
 
 
40
  )
 
41
  if self.prompt_style == PromptStyle.CHAT.value:
42
+ self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
43
+ self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
 
 
 
 
 
44
 
45
  def build_prompt(
46
  self,
 
51
  # returns the full prompt from instruction and optional input
52
  # if a label (=response, =output) is provided, it's also appended.
53
  if input:
54
+ res = self.system_prompt + self.turn_format.format(
55
+ instruction=instruction, input=input
56
+ )
57
  else:
58
+ res = self.system_no_input_prompt + self.turn_no_input_format.format(
59
+ instruction=instruction
60
+ )
61
  if output:
62
  res = f"{res}{output}"
63
  yield res
64
 
 
 
 
65
 
66
  class UnpromptedPrompter(AlpacaPrompter):
67
  """
 
86
  """
87
 
88
  system_prompt = (
89
+ "Choose the answer that best answers the question. Explain your reasoning.\n"
90
+ )
91
+ system_no_input_prompt = (
92
+ "Choose the answer that best answers the question. Explain your reasoning.\n"
93
  )
94
 
95
 
 
98
  Prompter for multiple choice concise
99
  """
100
 
101
+ system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
102
+ system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
103
+
104
+ def match_prompt_style(self):
105
+ self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
106
+ self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
107
 
108
 
109
  class SummarizeTLDRPrompter(AlpacaPrompter):
 
111
  Prompter for summarize TLDR
112
  """
113
 
114
+ system_prompt = ""
115
+ system_no_input_prompt = ""
116
+
117
+ def match_prompt_style(self):
118
+ self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
119
+ self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
120
 
121
 
122
  class CompletionPrompter:
 
132
  ) -> Generator[str, None, None]:
133
  yield instruction
134
 
 
 
 
135
 
136
  class GPTeacherPrompter(AlpacaPrompter):
137
  """
 
211
  res = f"{res}{label}"
212
  yield res
213
 
 
 
 
214
 
215
  class SeparatorStyle(Enum):
216
  """Different separator style."""
 
287
  sep2=" ",
288
  )
289
 
 
 
 
 
 
 
290
  def build_prompt(self, source) -> Generator[str, None, None]:
291
  # ignore the system prompt if provided
292
  if source[0]["from"] == "system":
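For reference, a quick sketch of what the refactored `AlpacaPrompter` turn templates produce (the instruction and input strings are invented):

```python
# hypothetical check of the refactored turn templates
from axolotl.prompters import AlpacaPrompter, PromptStyle

chat = AlpacaPrompter(PromptStyle.CHAT.value)
# system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
print(next(chat.build_prompt("Name three colors.")))

instruct = AlpacaPrompter(PromptStyle.INSTRUCT.value)
# system_prompt + "### Instruction:\n...\n\n### Input:\n...\n\n### Response:\n"
print(next(instruct.build_prompt("Translate to French", "good morning")))
```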
src/axolotl/utils/callbacks.py CHANGED
@@ -2,13 +2,14 @@
2
 
3
  import os
4
 
 
5
  from transformers import (
6
  TrainerCallback,
7
  TrainerControl,
8
  TrainerState,
9
  TrainingArguments,
10
  )
11
- from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
12
 
13
 
14
  class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -30,3 +31,39 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
30
  kwargs["model"].save_pretrained(peft_model_path)
31
 
32
  return control
2
 
3
  import os
4
 
5
+ from optimum.bettertransformer import BetterTransformer
6
  from transformers import (
7
  TrainerCallback,
8
  TrainerControl,
9
  TrainerState,
10
  TrainingArguments,
11
  )
12
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
13
 
14
 
15
  class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
 
31
  kwargs["model"].save_pretrained(peft_model_path)
32
 
33
  return control
34
+
35
+
36
+ class SaveBetterTransformerModelCallback(
37
+ TrainerCallback
38
+ ): # pylint: disable=too-few-public-methods
39
+ """Callback to save the BetterTransformer wrapped model"""
40
+
41
+ def on_step_end(
42
+ self,
43
+ args: TrainingArguments,
44
+ state: TrainerState,
45
+ control: TrainerControl,
46
+ **kwargs,
47
+ ):
48
+ # flag a save at every `save_steps` interval when using a step-based save strategy
49
+ if (
50
+ args.save_strategy == IntervalStrategy.STEPS
51
+ and args.save_steps > 0
52
+ and state.global_step % args.save_steps == 0
53
+ ):
54
+ control.should_save = True
55
+
56
+ if control.should_save:
57
+ checkpoint_folder = os.path.join(
58
+ args.output_dir,
59
+ f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
60
+ )
61
+
62
+ model = BetterTransformer.reverse(kwargs["model"])
63
+ model.save_pretrained(checkpoint_folder)
64
+ # FIXME - need to cleanup old checkpoints
65
+
66
+ # since we're saving here, we don't need the trainer loop to attempt to save too, because
67
+ # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
68
+ control.should_save = False
69
+ return control
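A minimal sketch of how such a callback would be attached to a trainer (illustrative; the actual wiring for `flash_optimum` lives in axolotl's trainer setup and is not shown in this diff):

```python
# hypothetical registration sketch; `trainer` is an existing transformers.Trainer
from axolotl.utils.callbacks import SaveBetterTransformerModelCallback

trainer.add_callback(SaveBetterTransformerModelCallback())
```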
src/axolotl/utils/data.py CHANGED
@@ -1,10 +1,11 @@
1
  """Module containing data utilities"""
2
-
3
  import logging
4
  from hashlib import md5
5
  from pathlib import Path
6
  from typing import List, Tuple, Union
7
 
 
8
  from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
9
  from huggingface_hub import hf_hub_download
10
  from transformers import PreTrainedTokenizerBase
@@ -101,13 +102,26 @@ def load_tokenized_prepared_datasets(
101
  pass
102
 
103
  # prefer local dataset, even if hub exists
104
- if Path(d.path).exists():
105
- ds = load_dataset(
106
- "json",
107
- data_files=d.path,
108
- streaming=False,
109
- split=None,
110
- )
111
  elif ds_from_hub:
112
  if d.data_files:
113
  ds = load_dataset(
@@ -394,8 +408,127 @@ def load_prepare_datasets(
394
  index=cfg.dataset_shard_idx,
395
  )
396
 
397
- dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
398
- train_dataset = dataset["train"]
399
- eval_dataset = dataset["test"]
 
 
 
 
400
 
401
  return train_dataset, eval_dataset
1
  """Module containing data utilities"""
2
+ import functools
3
  import logging
4
  from hashlib import md5
5
  from pathlib import Path
6
  from typing import List, Tuple, Union
7
 
8
+ import torch
9
  from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
10
  from huggingface_hub import hf_hub_download
11
  from transformers import PreTrainedTokenizerBase
 
102
  pass
103
 
104
  # prefer local dataset, even if hub exists
105
+ local_path = Path(d.path)
106
+ if local_path.exists():
107
+ if local_path.is_dir():
108
+ ds = load_dataset(
109
+ d.path,
110
+ data_files=d.data_files,
111
+ streaming=False,
112
+ split=None,
113
+ )
114
+ elif local_path.is_file():
115
+ ds = load_dataset(
116
+ "json",
117
+ data_files=d.path,
118
+ streaming=False,
119
+ split=None,
120
+ )
121
+ else:
122
+ raise ValueError(
123
+ "unhandled dataset load: local path exists, but is neither a directory or a file"
124
+ )
125
  elif ds_from_hub:
126
  if d.data_files:
127
  ds = load_dataset(
 
408
  index=cfg.dataset_shard_idx,
409
  )
410
 
411
+ if cfg.val_set_size:
412
+ dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
413
+ train_dataset = dataset["train"]
414
+ eval_dataset = dataset["test"]
415
+ else:
416
+ train_dataset = dataset
417
+ eval_dataset = None
418
 
419
  return train_dataset, eval_dataset
420
+
421
+
422
+ def encode_pretraining(tokenizer, max_tokens, examples):
423
+ res = tokenizer(
424
+ examples["text"],
425
+ truncation=True,
426
+ max_length=max_tokens - 2,
427
+ add_special_tokens=True,
428
+ )
429
+ # Convert to PyTorch tensors
430
+ input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
431
+ attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
432
+ new_input_ids = []
433
+ new_attention_mask = []
434
+ # Append EOS and PAD tokens to input_ids, and correct attention_mask
435
+ for i, _ in enumerate(input_ids):
436
+ input_ids[i] = torch.cat(
437
+ (
438
+ input_ids[i],
439
+ torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
440
+ ),
441
+ dim=0,
442
+ )
443
+ attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
444
+
445
+ # Concatenate tokens so that their lengths are less than max_tokens
446
+ buffer_input_ids = torch.tensor([], dtype=torch.long)
447
+ buffer_attention_mask = torch.tensor([], dtype=torch.long)
448
+
449
+ for ids, mask in zip(input_ids, attention_mask):
450
+ if buffer_input_ids.numel() == max_tokens:
451
+ new_input_ids.append(buffer_input_ids)
452
+ new_attention_mask.append(buffer_attention_mask)
453
+ buffer_input_ids = torch.tensor([], dtype=torch.long)
454
+ buffer_attention_mask = torch.tensor([], dtype=torch.long)
455
+ buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
456
+ buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
457
+ elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
458
+ buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
459
+ buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
460
+ else:
461
+ buffer_input_ids = torch.cat(
462
+ (
463
+ buffer_input_ids,
464
+ torch.full(
465
+ (max_tokens - buffer_input_ids.numel(),),
466
+ tokenizer.pad_token_id,
467
+ dtype=torch.long,
468
+ ),
469
+ ),
470
+ dim=0,
471
+ )
472
+ buffer_attention_mask = torch.cat(
473
+ (
474
+ buffer_attention_mask,
475
+ torch.full(
476
+ (max_tokens - buffer_attention_mask.numel(),),
477
+ 0,
478
+ dtype=torch.long,
479
+ ),
480
+ ),
481
+ dim=0,
482
+ )
483
+ new_input_ids.append(buffer_input_ids)
484
+ new_attention_mask.append(buffer_attention_mask)
485
+ buffer_input_ids = torch.tensor([], dtype=torch.long)
486
+ buffer_attention_mask = torch.tensor([], dtype=torch.long)
487
+
488
+ buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
489
+ buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
490
+
491
+ if buffer_input_ids.numel() > 0: # for any leftover tokens
492
+ while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
493
+ buffer_input_ids = torch.cat(
494
+ (
495
+ buffer_input_ids,
496
+ torch.full(
497
+ (max_tokens - buffer_input_ids.numel(),),
498
+ tokenizer.pad_token_id,
499
+ dtype=torch.long,
500
+ ),
501
+ ),
502
+ dim=0,
503
+ )
504
+ buffer_attention_mask = torch.cat(
505
+ (
506
+ buffer_attention_mask,
507
+ torch.full(
508
+ (max_tokens - buffer_attention_mask.numel(),),
509
+ 0,
510
+ dtype=torch.long,
511
+ ),
512
+ ),
513
+ dim=0,
514
+ )
515
+ new_input_ids.append(buffer_input_ids)
516
+ new_attention_mask.append(buffer_attention_mask)
517
+
518
+ ret = {
519
+ "input_ids": [seq.tolist() for seq in new_input_ids],
520
+ "labels": [seq.tolist() for seq in new_input_ids],
521
+ "attention_mask": [seq.tolist() for seq in new_attention_mask],
522
+ }
523
+
524
+ logging.debug(len(ret["input_ids"]))
525
+ return ret
526
+
527
+
528
+ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
529
+ encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
530
+ dataset = load_dataset(path, streaming=True, split="train")
531
+ dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
532
+ # TODO dynamically figure out which columns/features to remove
533
+ dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
534
+ return dataset
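A rough smoke-test sketch for the new packing helper (assumptions: the small `EleutherAI/pythia-70m` tokenizer is available and a pad token is set; the batch is a toy example):

```python
# hypothetical usage of encode_pretraining
from transformers import AutoTokenizer

from axolotl.utils.data import encode_pretraining

tok = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
tok.pad_token = tok.eos_token  # assumption: make sure pad_token_id is set

batch = {"text": ["a short document", "another, slightly longer document"]}
out = encode_pretraining(tok, 32, batch)

# every packed row ends up exactly max_tokens long; labels mirror input_ids
assert all(len(row) == 32 for row in out["input_ids"])
```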
src/axolotl/utils/models.py CHANGED
@@ -10,13 +10,15 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
10
  import bitsandbytes as bnb
11
  import torch
12
  import transformers
13
- from transformers import PreTrainedModel # noqa: F401
14
  from transformers import ( # noqa: F401
15
  AutoConfig,
16
  AutoModelForCausalLM,
17
  AutoTokenizer,
18
  BitsAndBytesConfig,
19
  LlamaConfig,
 
 
20
  )
21
 
22
  from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
@@ -32,15 +34,20 @@ def load_tokenizer(
32
  tokenizer_type,
33
  cfg,
34
  ):
 
 
 
35
  if tokenizer_type:
36
  tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
37
  tokenizer_config,
38
  trust_remote_code=cfg.trust_remote_code or False,
 
39
  )
40
  else:
41
  tokenizer = AutoTokenizer.from_pretrained(
42
  tokenizer_config,
43
  trust_remote_code=cfg.trust_remote_code or False,
 
44
  )
45
 
46
  logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
@@ -70,7 +77,7 @@ def load_tokenizer(
70
  def load_model(
71
  base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
72
  ):
73
- # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
74
  """
75
  Load a model from a base model and a model type.
76
  """
@@ -121,9 +128,9 @@ def load_model(
121
  logging.info("patching with xpos rope")
122
  replace_llama_rope_with_xpos_rope()
123
 
124
- if cfg.bf16:
125
  torch_dtype = torch.bfloat16
126
- elif cfg.load_in_8bit or cfg.fp16:
127
  torch_dtype = torch.float16
128
  else:
129
  torch_dtype = torch.float32
@@ -195,7 +202,7 @@ def load_model(
195
  else True,
196
  )
197
  load_in_8bit = False
198
- elif cfg.is_llama_derived_model:
199
  from transformers import LlamaForCausalLM
200
 
201
  config = LlamaConfig.from_pretrained(base_model_config)
@@ -234,7 +241,7 @@ def load_model(
234
  # device=cfg.device,
235
  # )
236
  # model.train() # sets to train instead of eval mode
237
- elif model_type:
238
  model = getattr(transformers, model_type).from_pretrained(
239
  base_model,
240
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
@@ -251,11 +258,16 @@ def load_model(
251
  )
252
  # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
253
  # when training starts
254
- if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
 
 
 
 
255
  config.max_seq_len = cfg.sequence_len
256
  logging.warning(f"increasing context length to {cfg.sequence_len}")
257
  elif (
258
  hasattr(config, "max_sequence_length")
 
259
  and cfg.sequence_len > config.max_sequence_length
260
  ):
261
  config.max_sequence_length = cfg.sequence_len
@@ -278,6 +290,7 @@ def load_model(
278
  model = AutoModelForCausalLM.from_pretrained(
279
  base_model,
280
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
 
281
  torch_dtype=torch_dtype,
282
  device_map=cfg.device_map,
283
  trust_remote_code=cfg.trust_remote_code or False,
@@ -287,6 +300,16 @@ def load_model(
287
  embeddings_len = math.ceil(len(tokenizer) / 32) * 32
288
  model.resize_token_embeddings(embeddings_len)
289
 
290
  if not cfg.gptq and (
291
  (cfg.adapter == "lora" and load_in_8bit)
292
  or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -332,6 +355,9 @@ def load_model(
332
  logging.warning("there are no parameters that require gradient updates")
333
  model.config.use_cache = False
334
 
 
 
 
335
  # TODO resume_from_checkpoint handling
336
  return model, lora_config
337
 
 
10
  import bitsandbytes as bnb
11
  import torch
12
  import transformers
13
+ from optimum.bettertransformer import BetterTransformer
14
  from transformers import ( # noqa: F401
15
  AutoConfig,
16
  AutoModelForCausalLM,
17
  AutoTokenizer,
18
  BitsAndBytesConfig,
19
  LlamaConfig,
20
+ PreTrainedModel,
21
+ PreTrainedTokenizerBase,
22
  )
23
 
24
  from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
 
34
  tokenizer_type,
35
  cfg,
36
  ):
37
+ use_fast = True # this is the default
38
+ if cfg.tokenizer_use_fast is not None:
39
+ use_fast = cfg.tokenizer_use_fast
40
  if tokenizer_type:
41
  tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
42
  tokenizer_config,
43
  trust_remote_code=cfg.trust_remote_code or False,
44
+ use_fast=use_fast,
45
  )
46
  else:
47
  tokenizer = AutoTokenizer.from_pretrained(
48
  tokenizer_config,
49
  trust_remote_code=cfg.trust_remote_code or False,
50
+ use_fast=use_fast,
51
  )
52
 
53
  logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
 
77
  def load_model(
78
  base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
79
  ):
80
+ # type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
81
  """
82
  Load a model from a base model and a model type.
83
  """
 
128
  logging.info("patching with xpos rope")
129
  replace_llama_rope_with_xpos_rope()
130
 
131
+ if cfg.bf16 or cfg.bfloat16:
132
  torch_dtype = torch.bfloat16
133
+ elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
134
  torch_dtype = torch.float16
135
  else:
136
  torch_dtype = torch.float32
 
202
  else True,
203
  )
204
  load_in_8bit = False
205
+ elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
206
  from transformers import LlamaForCausalLM
207
 
208
  config = LlamaConfig.from_pretrained(base_model_config)
 
241
  # device=cfg.device,
242
  # )
243
  # model.train() # sets to train instead of eval mode
244
+ elif model_type and not cfg.trust_remote_code:
245
  model = getattr(transformers, model_type).from_pretrained(
246
  base_model,
247
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
 
258
  )
259
  # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
260
  # when training starts
261
+ if (
262
+ hasattr(config, "max_seq_len")
263
+ and config.max_seq_len
264
+ and cfg.sequence_len > config.max_seq_len
265
+ ):
266
  config.max_seq_len = cfg.sequence_len
267
  logging.warning(f"increasing context length to {cfg.sequence_len}")
268
  elif (
269
  hasattr(config, "max_sequence_length")
270
+ and config.max_sequence_length
271
  and cfg.sequence_len > config.max_sequence_length
272
  ):
273
  config.max_sequence_length = cfg.sequence_len
 
290
  model = AutoModelForCausalLM.from_pretrained(
291
  base_model,
292
  load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
293
+ load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
294
  torch_dtype=torch_dtype,
295
  device_map=cfg.device_map,
296
  trust_remote_code=cfg.trust_remote_code or False,
 
300
  embeddings_len = math.ceil(len(tokenizer) / 32) * 32
301
  model.resize_token_embeddings(embeddings_len)
302
 
303
+ if (
304
+ hasattr(model.config, "max_position_embeddings")
305
+ and model.config.max_position_embeddings
306
+ and cfg.sequence_len >= model.config.max_position_embeddings
307
+ ):
308
+ logging.warning(
309
+ f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
310
+ )
311
+ model.config.max_position_embeddings = cfg.sequence_len
312
+
313
  if not cfg.gptq and (
314
  (cfg.adapter == "lora" and load_in_8bit)
315
  or (cfg.adapter == "qlora" and cfg.load_in_4bit)
 
355
  logging.warning("there are no parameters that require gradient updates")
356
  model.config.use_cache = False
357
 
358
+ if cfg.flash_optimum:
359
+ model = BetterTransformer.transform(model)
360
+
361
  # TODO resume_from_checkpoint handling
362
  return model, lora_config
363
 
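The dtype selection above now also honors the `bfloat16`/`float16` config keys, so weights can be loaded in half precision without enabling AMP. A self-contained sketch of that precedence, using a plain dict in place of the real `DictDefault` config:

```python
# Illustration of the torch_dtype precedence in load_model: bf16/bfloat16 win
# over load_in_8bit/fp16/float16, which win over the float32 fallback.
import torch


def pick_torch_dtype(cfg: dict) -> torch.dtype:
    if cfg.get("bf16") or cfg.get("bfloat16"):
        return torch.bfloat16
    if cfg.get("load_in_8bit") or cfg.get("fp16") or cfg.get("float16"):
        return torch.float16
    return torch.float32


assert pick_torch_dtype({"bfloat16": True, "fp16": True}) is torch.bfloat16
assert pick_torch_dtype({"load_in_8bit": True}) is torch.float16
assert pick_torch_dtype({}) is torch.float32
```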
src/axolotl/utils/tokenization.py CHANGED
@@ -34,3 +34,5 @@ def check_example_labels(example, tokenizer):
34
 
35
  logging.info(" ".join(colored_tokens))
36
  logging.info("\n\n\n")
 
 
 
34
 
35
  logging.info(" ".join(colored_tokens))
36
  logging.info("\n\n\n")
37
+
38
+ return " ".join(colored_tokens)
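With the added return value, callers can capture the colored token/label rendering instead of relying on the log output alone. A usage sketch; the tokenizer and example contents are illustrative assumptions:

```python
# Sketch: capture the colored token/label dump for one tokenized example.
from transformers import AutoTokenizer

from axolotl.utils.tokenization import check_example_labels

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
encoded = tokenizer("hello world")
example = {
    "input_ids": encoded["input_ids"],
    "labels": encoded["input_ids"],
    "attention_mask": encoded["attention_mask"],
}
rendered = check_example_labels(example, tokenizer)
print(rendered)  # space-joined tokens annotated with label/mask/id info
```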
src/axolotl/utils/trainer.py CHANGED
@@ -17,7 +17,10 @@ from torch.optim.lr_scheduler import OneCycleLR
17
  from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
18
  from transformers.trainer_pt_utils import get_parameter_names
19
 
20
- from axolotl.utils.callbacks import SavePeftModelCallback
 
 
 
21
  from axolotl.utils.schedulers import (
22
  InterpolatingLogScheduler,
23
  get_cosine_schedule_with_quadratic_warmup,
@@ -166,6 +169,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
166
  # TODO search Path("./") for one
167
  training_arguments_kwargs["deepspeed"] = "./ds_config.json"
168
 
 
  training_args = AxolotlTrainingArguments(
170
  per_device_train_batch_size=cfg.micro_batch_size,
171
  per_device_eval_batch_size=cfg.eval_batch_size
@@ -282,6 +298,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
282
  ]: # only save in rank 0
283
  callbacks.append(SavePeftModelCallback)
284
 
 
 
 
285
  data_collator_kwargs = {
286
  "padding": True,
287
  }
 
17
  from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
18
  from transformers.trainer_pt_utils import get_parameter_names
19
 
20
+ from axolotl.utils.callbacks import (
21
+ SaveBetterTransformerModelCallback,
22
+ SavePeftModelCallback,
23
+ )
24
  from axolotl.utils.schedulers import (
25
  InterpolatingLogScheduler,
26
  get_cosine_schedule_with_quadratic_warmup,
 
169
  # TODO search Path("./") for one
170
  training_arguments_kwargs["deepspeed"] = "./ds_config.json"
171
 
172
+ if cfg.adam_beta1:
173
+ training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
174
+ if cfg.adam_beta2:
175
+ training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
176
+ if cfg.adam_epsilon:
177
+ training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
178
+ if cfg.max_grad_norm:
179
+ training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
180
+
181
+ if cfg.hub_model_id:
182
+ training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
183
+ training_arguments_kwargs["push_to_hub"] = True
184
+
185
  training_args = AxolotlTrainingArguments(
186
  per_device_train_batch_size=cfg.micro_batch_size,
187
  per_device_eval_batch_size=cfg.eval_batch_size
 
298
  ]: # only save in rank 0
299
  callbacks.append(SavePeftModelCallback)
300
 
301
+ if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
302
+ callbacks.append(SaveBetterTransformerModelCallback)
303
+
304
  data_collator_kwargs = {
305
  "padding": True,
306
  }
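The new kwargs above forward optimizer and Hub settings straight into `transformers.TrainingArguments`, while the BetterTransformer save callback is only attached when the wrapped model advertises `use_bettertransformer`. A standalone sketch of the kwargs mapping with illustrative values (the Hub repo id is hypothetical):

```python
# Sketch of the cfg -> TrainingArguments passthrough added in setup_trainer.
from transformers import TrainingArguments

cfg = {
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "adam_epsilon": 1e-8,
    "max_grad_norm": 1.0,
    "hub_model_id": "your-org/your-model",  # hypothetical repo id
}

training_arguments_kwargs = {}
for key in ("adam_beta1", "adam_beta2", "adam_epsilon", "max_grad_norm"):
    if cfg.get(key):
        training_arguments_kwargs[key] = cfg[key]
if cfg.get("hub_model_id"):
    training_arguments_kwargs["hub_model_id"] = cfg["hub_model_id"]
    training_arguments_kwargs["push_to_hub"] = True

args = TrainingArguments(output_dir="./out", **training_arguments_kwargs)
print(args.adam_beta2, args.max_grad_norm, args.push_to_hub)
```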
src/axolotl/utils/validation.py CHANGED
@@ -2,6 +2,8 @@
2
 
3
  import logging
4
 
 
 
5
 
6
  def validate_config(cfg):
7
  if cfg.gradient_accumulation_steps and cfg.batch_size:
@@ -62,7 +64,47 @@ def validate_config(cfg):
62
  ) and cfg.gradient_checkpointing:
63
  raise ValueError("gradient_checkpointing is not supported for MPT models")
64
 
 
  # TODO
66
  # MPT 7b
67
  # https://github.com/facebookresearch/bitsandbytes/issues/25
68
- # no 8bit adamw w bf16
 
2
 
3
  import logging
4
 
5
+ import torch
6
+
7
 
8
  def validate_config(cfg):
9
  if cfg.gradient_accumulation_steps and cfg.batch_size:
 
64
  ) and cfg.gradient_checkpointing:
65
  raise ValueError("gradient_checkpointing is not supported for MPT models")
66
 
67
+ if cfg.flash_optimum is True:
68
+ if cfg.adapter:
69
+ logging.warning(
70
+ "BetterTransformers probably doesn't work with PEFT adapters"
71
+ )
72
+ if cfg.fp16 or cfg.bf16:
73
+ raise ValueError("AMP is not supported with BetterTransformer")
74
+ if cfg.float16 is not True and cfg.bfloat16 is not True:
75
+ logging.warning(
76
+ "You should probably set bfloat16 or float16 to true to "
77
+ "load the model in float16 for BetterTransformers"
78
+ )
79
+ if int(torch.__version__.split(".")[0]) < 2:
80
+ logging.warning("torch>=2.0.0 required")
81
+ raise ValueError(
82
+ f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
83
+ )
84
+
85
+ if cfg.pretraining_dataset and cfg.group_by_length:
86
+ logging.warning(
87
+ "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
88
+ )
89
+
90
+ if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
91
+ not cfg.optimizer or "adamw" not in cfg.optimizer
92
+ ):
93
+ logging.warning("adamw hyperparameters found, but no adamw optimizer set")
94
+
95
+ if cfg.push_to_hub_model_id:
96
+ raise ValueError(
97
+ "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
98
+ )
99
+
100
  # TODO
101
  # MPT 7b
102
  # https://github.com/facebookresearch/bitsandbytes/issues/25
103
+ # no 8bit adamw w bf16
104
+
105
+ # GPT-NeoX
106
+ # evals broken when extending context len
107
+ # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
108
+ # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
109
+ # attention_mask = causal_mask + attention_mask
110
+ # RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
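A short sketch exercising the new validation rules (config values are illustrative): AMP flags are rejected when `flash_optimum` is set, and the deprecated `push_to_hub_model_id` key now raises.

```python
# Sketch of the validate_config behavior added above.
from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config

try:
    validate_config(DictDefault({"flash_optimum": True, "fp16": True}))
except ValueError as err:
    print(err)  # AMP is not supported with BetterTransformer

try:
    validate_config(DictDefault({"push_to_hub_model_id": "some-org/some-model"}))  # placeholder repo id
except ValueError as err:
    print(err)  # push_to_hub_model_id is deprecated. Please use hub_model_id instead.
```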
tests/test_prompt_tokenizers.py CHANGED
@@ -6,8 +6,16 @@ from pathlib import Path
6
 
7
  from transformers import AutoTokenizer
8
 
9
- from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
10
- from axolotl.prompters import ShareGPTPrompter
 
11
 
12
  logging.basicConfig(level="INFO")
13
 
@@ -29,7 +37,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
29
  )
30
 
31
  def test_sharegpt_integration(self):
32
- print(Path(__file__).parent)
33
  with open(
34
  Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
35
  ) as fin:
@@ -53,6 +60,79 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
53
  self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
54
  self.assertEqual(example[fields], tokenized_conversation[fields])
55
 
56
 
57
  if __name__ == "__main__":
58
  unittest.main()
 
6
 
7
  from transformers import AutoTokenizer
8
 
9
+ from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
10
+ from axolotl.prompt_strategies.alpaca_w_system import (
11
+ InstructionWSystemPromptTokenizingStrategy,
12
+ SystemDataPrompter,
13
+ )
14
+ from axolotl.prompt_tokenizers import (
15
+ AlpacaPromptTokenizingStrategy,
16
+ ShareGPTPromptTokenizingStrategy,
17
+ )
18
+ from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
19
 
20
  logging.basicConfig(level="INFO")
21
 
 
37
  )
38
 
39
  def test_sharegpt_integration(self):
 
40
  with open(
41
  Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
42
  ) as fin:
 
60
  self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
61
  self.assertEqual(example[fields], tokenized_conversation[fields])
62
 
63
+ def test_no_sys_prompt(self):
64
+ """
65
+ tests the interface between the user and assistant parts
66
+ """
67
+ prompter = NoSystemPrompter()
68
+ # pylint: disable=duplicate-code
69
+ strat = AlpacaPromptTokenizingStrategy(
70
+ prompter,
71
+ self.tokenizer,
72
+ False,
73
+ 2048,
74
+ )
75
+ sample = {
76
+ "instruction": "hello cruel. lorem ipsum dolor sit amet.",
77
+ "output": "world!",
78
+ }
79
+ example = strat.tokenize_prompt(sample)
80
+ world_idx = example["input_ids"].index(3186)
81
+ assert example["labels"][world_idx] == 3186
82
+ assert example["labels"][world_idx - 1] == -100
83
+
84
+ def test_alpaca(self):
85
+ """
86
+ tests the interface between the user and assistant parts
87
+ """
88
+ # pylint: disable=duplicate-code
89
+ prompter = AlpacaPrompter()
90
+ strat = AlpacaPromptTokenizingStrategy(
91
+ prompter,
92
+ self.tokenizer,
93
+ False,
94
+ 2048,
95
+ )
96
+ sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
97
+ example = strat.tokenize_prompt(sample)
98
+ world_idx = example["input_ids"].index(6324)
99
+ assert example["labels"][world_idx] == 6324
100
+ assert example["labels"][world_idx - 1] == -100
101
+
102
+
103
+ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
104
+ """
105
+ Test class for prompt tokenization strategies with sys prompt from the dataset
106
+ """
107
+
108
+ def setUp(self) -> None:
109
+ # pylint: disable=duplicate-code
110
+ self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
111
+ self.tokenizer.add_special_tokens(
112
+ {
113
+ "bos_token": "<s>",
114
+ "eos_token": "</s>",
115
+ "unk_token": "<unk>",
116
+ }
117
+ )
118
+
119
+ def test_system_alpaca(self):
120
+ prompter = SystemDataPrompter(PromptStyle.CHAT.value)
121
+ strat = InstructionWSystemPromptTokenizingStrategy(
122
+ prompter,
123
+ self.tokenizer,
124
+ False,
125
+ 2048,
126
+ )
127
+ sample = {
128
+ "system": "use cot",
129
+ "instruction": "hello!",
130
+ "output": "Hi! How can I help?",
131
+ }
132
+ example = strat.tokenize_prompt(sample)
133
+ assert example["input_ids"][0:3] == [1, 671, 20118] # <s>use cot
134
+ assert example["input_ids"][3] == 11889 # USER
135
+
136
 
137
  if __name__ == "__main__":
138
  unittest.main()
tests/test_prompters.py CHANGED
@@ -2,7 +2,13 @@
2
 
3
  import unittest
4
 
5
- from axolotl.prompters import AlpacaPrompter, PromptStyle
 
6
 
7
 
8
  class AlpacaPrompterTest(unittest.TestCase):
@@ -55,3 +61,64 @@ class AlpacaPrompterTest(unittest.TestCase):
55
  assert "### Response:" not in res
56
  assert "USER:" in res
57
  assert "ASSISTANT:" in res
 
2
 
3
  import unittest
4
 
5
+ from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
6
+ from axolotl.prompters import (
7
+ AlpacaPrompter,
8
+ MultipleChoiceExplainPrompter,
9
+ PromptStyle,
10
+ UnpromptedPrompter,
11
+ )
12
 
13
 
14
  class AlpacaPrompterTest(unittest.TestCase):
 
61
  assert "### Response:" not in res
62
  assert "USER:" in res
63
  assert "ASSISTANT:" in res
64
+
65
+ def test_system_prompt(self):
66
+ prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
67
+ res = next(
68
+ prompter.build_prompt_w_system(
69
+ "use cot", "tell me a joke about the following", "alpacas"
70
+ )
71
+ )
72
+ assert "use cot" in res
73
+ assert res.startswith("use cot")
74
+ assert "### Instruction:" not in res
75
+ assert "### Input:" not in res
76
+ assert "alpacas" in res
77
+ assert "### Response:" not in res
78
+ assert "USER:" in res
79
+ assert "ASSISTANT:" in res
80
+
81
+
82
+ class UnpromptedPrompterTest(unittest.TestCase):
83
+ """
84
+ Test class for UnpromptedPrompter with no system prompts
85
+ """
86
+
87
+ def test_prompt_style_w_none(self):
88
+ prompter = UnpromptedPrompter(prompt_style=None)
89
+ res = next(prompter.build_prompt("tell me a joke"))
90
+ assert "### Instruction:" in res
91
+ assert "tell me a joke" in res
92
+ assert res.startswith("###")
93
+
94
+ def test_prompt_style_w_instruct(self):
95
+ prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
96
+ res = next(
97
+ prompter.build_prompt("tell me a joke about the following", "alpacas")
98
+ )
99
+ assert "### Instruction:" in res
100
+ assert "tell me a joke" in res
101
+ assert res.startswith("###")
102
+
103
+ def test_prompt_style_w_chat(self):
104
+ prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
105
+ res = next(
106
+ prompter.build_prompt("tell me a joke about the following", "alpacas")
107
+ )
108
+ assert "USER:" in res
109
+ assert "tell me a joke" in res
110
+ assert res.startswith("USER:")
111
+
112
+
113
+ class MultipleChoiceExplainPrompterTest(unittest.TestCase):
114
+ """
115
+ Test class for MultipleChoiceExplainPrompter
116
+ """
117
+
118
+ def test_prompt_style_w_chat(self):
119
+ prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
120
+ res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
121
+ assert "USER:" in res
122
+ assert "choose one" in res
123
+ assert "Choose the answer that best answers the question." in res
124
+ assert "- A\n- B\n- C" in res
tests/test_tokenizers.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ Test cases for the tokenizer loading
3
+ """
4
+ import unittest
5
+
6
+ from axolotl.utils.dict import DictDefault
7
+ from axolotl.utils.models import load_tokenizer
8
+
9
+
10
+ class TestTokenizers(unittest.TestCase):
11
+ """
12
+ test class for the load_tokenizer fn
13
+ """
14
+
15
+ def test_default_use_fast(self):
16
+ cfg = DictDefault({})
17
+ tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
18
+ assert "Fast" in tokenizer.__class__.__name__
19
+
20
+ def test_dont_use_fast(self):
21
+ cfg = DictDefault(
22
+ {
23
+ "tokenizer_use_fast": False,
24
+ }
25
+ )
26
+ tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
27
+ assert "Fast" not in tokenizer.__class__.__name__
28
+
29
+
30
+ if __name__ == "__main__":
31
+ unittest.main()
tests/test_validation.py CHANGED
@@ -212,3 +212,104 @@ class ValidationTest(unittest.TestCase):
212
 
213
  with pytest.raises(ValueError, match=regex_exp):
214
  validate_config(cfg)
 
212
 
213
  with pytest.raises(ValueError, match=regex_exp):
214
  validate_config(cfg)
215
+
216
+ def test_flash_optimum(self):
217
+ cfg = DictDefault(
218
+ {
219
+ "flash_optimum": True,
220
+ "adapter": "lora",
221
+ }
222
+ )
223
+
224
+ with self._caplog.at_level(logging.WARNING):
225
+ validate_config(cfg)
226
+ assert any(
227
+ "BetterTransformers probably doesn't work with PEFT adapters"
228
+ in record.message
229
+ for record in self._caplog.records
230
+ )
231
+
232
+ cfg = DictDefault(
233
+ {
234
+ "flash_optimum": True,
235
+ }
236
+ )
237
+
238
+ with self._caplog.at_level(logging.WARNING):
239
+ validate_config(cfg)
240
+ assert any(
241
+ "probably set bfloat16 or float16" in record.message
242
+ for record in self._caplog.records
243
+ )
244
+
245
+ cfg = DictDefault(
246
+ {
247
+ "flash_optimum": True,
248
+ "fp16": True,
249
+ }
250
+ )
251
+ regex_exp = r".*AMP is not supported.*"
252
+
253
+ with pytest.raises(ValueError, match=regex_exp):
254
+ validate_config(cfg)
255
+
256
+ cfg = DictDefault(
257
+ {
258
+ "flash_optimum": True,
259
+ "bf16": True,
260
+ }
261
+ )
262
+ regex_exp = r".*AMP is not supported.*"
263
+
264
+ with pytest.raises(ValueError, match=regex_exp):
265
+ validate_config(cfg)
266
+
267
+ def test_adamw_hyperparams(self):
268
+ cfg = DictDefault(
269
+ {
270
+ "optimizer": None,
271
+ "adam_epsilon": 0.0001,
272
+ }
273
+ )
274
+
275
+ with self._caplog.at_level(logging.WARNING):
276
+ validate_config(cfg)
277
+ assert any(
278
+ "adamw hyperparameters found, but no adamw optimizer set"
279
+ in record.message
280
+ for record in self._caplog.records
281
+ )
282
+
283
+ cfg = DictDefault(
284
+ {
285
+ "optimizer": "adafactor",
286
+ "adam_beta1": 0.0001,
287
+ }
288
+ )
289
+
290
+ with self._caplog.at_level(logging.WARNING):
291
+ validate_config(cfg)
292
+ assert any(
293
+ "adamw hyperparameters found, but no adamw optimizer set"
294
+ in record.message
295
+ for record in self._caplog.records
296
+ )
297
+
298
+ cfg = DictDefault(
299
+ {
300
+ "optimizer": "adamw_bnb_8bit",
301
+ "adam_beta1": 0.9,
302
+ "adam_beta2": 0.99,
303
+ "adam_epsilon": 0.0001,
304
+ }
305
+ )
306
+
307
+ validate_config(cfg)
308
+
309
+ cfg = DictDefault(
310
+ {
311
+ "optimizer": "adafactor",
312
+ }
313
+ )
314
+
315
+ validate_config(cfg)