Merge branch 'main' into quadratic-warmup
- .github/workflows/base.yml +2 -1
- .github/workflows/main.yml +3 -2
- .github/workflows/tests.yml +1 -0
- .pre-commit-config.yaml +1 -1
- README.md +52 -10
- data/README.md +4 -4
- docker/Dockerfile-base +2 -2
- examples/openllama-3b/config.yml +6 -5
- examples/pythia-12b/README.md +9 -0
- examples/pythia-12b/config.yml +49 -0
- examples/redpajama/config-3b.yml +1 -1
- requirements.txt +1 -0
- scripts/finetune.py +40 -17
- src/axolotl/datasets.py +1 -0
- src/axolotl/prompt_strategies/alpaca_chat.py +42 -6
- src/axolotl/prompt_strategies/alpaca_instruct.py +10 -1
- src/axolotl/prompt_strategies/alpaca_w_system.py +120 -0
- src/axolotl/prompt_tokenizers.py +21 -17
- src/axolotl/prompters.py +29 -37
- src/axolotl/utils/callbacks.py +38 -1
- src/axolotl/utils/data.py +144 -11
- src/axolotl/utils/models.py +33 -7
- src/axolotl/utils/tokenization.py +2 -0
- src/axolotl/utils/trainer.py +20 -1
- src/axolotl/utils/validation.py +43 -1
- tests/test_prompt_tokenizers.py +83 -3
- tests/test_prompters.py +68 -1
- tests/test_tokenizers.py +31 -0
- tests/test_validation.py +101 -0
.github/workflows/base.yml
CHANGED
@@ -12,6 +12,7 @@ jobs:
     # this job needs to be run on self-hosted GPU runners...
     runs-on: self-hosted
     strategy:
+      fail-fast: false
       matrix:
         include:
           - cuda: "118"
@@ -25,7 +26,7 @@ jobs:
             pytorch: 2.0.0
            axolotl_extras:
          - cuda: "117"
-           cuda_version: 11.7.
+           cuda_version: 11.7.1
            python_version: "3.9"
            pytorch: 1.13.1
            axolotl_extras:
.github/workflows/main.yml
CHANGED
@@ -11,6 +11,7 @@ jobs:
     if: github.repository_owner == 'OpenAccess-AI-Collective'
     # this job needs to be run on self-hosted GPU runners...
     strategy:
+      fail-fast: false
      matrix:
        include:
          - cuda: cu118
@@ -29,7 +30,7 @@ jobs:
            pytorch: 2.0.0
            axolotl_extras: gptq
          - cuda: cu117
-           cuda_version: 11.7.
+           cuda_version: 11.7.1
            python_version: "3.9"
            pytorch: 1.13.1
            axolotl_extras:
@@ -84,7 +85,7 @@ jobs:
            pytorch: 2.0.0
            axolotl_extras: gptq
          - cuda: cu117
-           cuda_version: 11.7.
+           cuda_version: 11.7.1
            python_version: "3.9"
            pytorch: 1.13.1
            axolotl_extras:
.github/workflows/tests.yml
CHANGED
@@ -7,6 +7,7 @@ jobs:
   test:
     runs-on: ubuntu-latest
     strategy:
+      fail-fast: false
      matrix:
        python_version: ["3.9", "3.10"]
    timeout-minutes: 10
.pre-commit-config.yaml
CHANGED
@@ -1,5 +1,5 @@
 default_language_version:
-  python: python3
+  python: python3
 
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
README.md
CHANGED
@@ -138,7 +138,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
   ```json
   {"instruction": "...", "input": "...", "output": "..."}
   ```
-- `sharegpt`: conversations
+- `sharegpt:chat`: conversations
   ```json
   {"conversations": [{"from": "...", "value": "..."}]}
   ```
@@ -195,6 +195,10 @@ Have dataset(s) in one of the following format (JSONL recommended):
   ```json
   {"message_1": "...", "message_2": "..."}
   ```
+- `alpaca_w_system.load_open_orca`: support for open orca datasets with included system prompts, instruct
+  ```json
+  {"system_prompt": "...", "question": "...", "response": "..."}
+  ```
 - `context_qa`: in context question answering from an article
   ```json
   {"article": "...", "question": "...", "answer": "..."}
@@ -233,7 +237,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
 #### How to add custom prompts
 
 1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
-2. Use your custom file name as the dataset type
+2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
 
 Optionally, download some datasets, see [data/README.md](data/README.md)
 
@@ -251,10 +255,18 @@ See sample configs in [configs](configs) folder or [examples](examples) for quick start
 
 - dataset
   ```yaml
-  datasets:
-    - path: vicgalle/alpaca-gpt4
-      type: alpaca # format from earlier
-  sequence_len: 2048 # max token length / prompt
+  sequence_len: 2048 # max token length for prompt
+
+  # huggingface repo
+  datasets:
+    - path: vicgalle/alpaca-gpt4
+      type: alpaca # format from earlier
+
+  # local
+  datasets:
+    - path: json
+      data_files: data.jsonl # or json
+      type: alpaca # format from earlier
   ```
 
 - loading
@@ -264,6 +276,8 @@ See sample configs in [configs](configs) folder or [examples](examples) for quick start
   bf16: true # require >=ampere
   fp16: true
   tf32: true # require >=ampere
+  bfloat16: true # require >=ampere, use instead of bf16 when you don't want AMP (automatic mixed precision)
+  float16: true # use instead of fp16 when you don't want AMP
   ```
   Note: Repo does not do 4-bit quantization.
 
@@ -300,6 +314,8 @@ model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
 # Trust remote code for untrusted source
 trust_remote_code:
+# use_fast option for tokenizer loading from_pretrained, default to True
+tokenizer_use_fast:
 
 # whether you are training a 4-bit GPTQ quantized model
 gptq: true
@@ -320,10 +336,10 @@ tf32: true # require >=ampere
 
 # a list of one or more datasets to finetune the model with
 datasets:
-  #
+  # hf dataset repo | "json" for local dataset, make sure to fill data_files
   - path: vicgalle/alpaca-gpt4
     # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
-    type: alpaca # format
+    type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
     data_files: # path to source data files
     shards: # number of shards to split data into
 
@@ -332,6 +348,8 @@ datasets:
 dataset_prepared_path: data/last_run_prepared
 # push prepared dataset to hub
 push_dataset_to_hub: # repo path
+# push checkpoints to hub
+hub_model_id: # repo path
 # whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
 # required to be true when used in combination with `push_dataset_to_hub`
 hf_use_auth_token: # boolean
@@ -420,7 +438,15 @@ log_sweep_max_lr:
 optimizer:
 # specify weight decay
 weight_decay:
-
+# adamw hyperparams
+adam_beta1:
+adam_beta2:
+adam_epsilon:
+# Gradient clipping max norm
+max_grad_norm:
+
+# whether to bettertransformers
+flash_optimum:
 # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
 xformers_attention:
 # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
@@ -500,16 +526,16 @@ Pass the appropriate flag to the train command:
 
 - Pretrained LORA:
   ```bash
-  --inference --lora_model_dir
+  --inference --lora_model_dir="./lora-output-dir"
   ```
 - Full weights finetune:
   ```bash
-  --inference --base_model
+  --inference --base_model="./completed-model"
   ```
 - Full weights finetune w/ a prompt from a text file:
   ```bash
   cat /tmp/prompt.txt | python scripts/finetune.py configs/your_config.yml \
-    --base_model
+    --base_model="./completed-model" --inference --prompter=None --load_in_8bit=True
   ```
 
 ### Merge LORA to base
@@ -520,6 +546,12 @@ Add below flag to train command above
 --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
 ```
 
+If you run out of CUDA memory, you can try to merge in system RAM with
+
+```bash
+CUDA_VISIBLE_DEVICES="" python3 scripts/finetune.py ...
+```
+
 ## Common Errors 🧰
 
 > Cuda out of memory
@@ -552,6 +584,16 @@ Building something cool with Axolotl? Consider adding a badge to your model card
 
 [<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)
 
+## Community Showcase
+
+Open Access AI Collective
+- [Minotaur 13b](https://huggingface.co/openaccess-ai-collective/minotaur-13b)
+- [Manticore 13b](https://huggingface.co/openaccess-ai-collective/manticore-13b)
+- [Hippogriff 30b](https://huggingface.co/openaccess-ai-collective/hippogriff-30b-chat)
+
+PocketDoc Labs
+- [Dan's PersonalityEngine 13b LoRA](https://huggingface.co/PocketDoc/Dans-PersonalityEngine-13b-LoRA)
+
 ## Contributing 🤝
 
 Bugs? Please check for open issue else create a new [Issue](https://github.com/OpenAccess-AI-Collective/axolotl/issues/new).
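As context for the `<prompt_strategies_file>.load_<load_fn>` dataset type described in the README change above, a minimal custom strategy module might look like the sketch below. It mirrors the loader functions changed elsewhere in this PR (`alpaca_instruct.py`, `alpaca_chat.py`); the module and function names here are hypothetical, not part of the commit.

```python
# src/axolotl/prompt_strategies/my_strategy.py (hypothetical file name)
# Minimal sketch of a custom prompt-strategy loader, referenced from a config as:
#   datasets:
#     - path: some/dataset
#       type: my_strategy.load_simple
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter, PromptStyle


def load_simple(tokenizer, cfg):
    # same constructor signature the built-in loaders use:
    # (prompter, tokenizer, train_on_inputs, sequence_len)
    return AlpacaPromptTokenizingStrategy(
        AlpacaPrompter(PromptStyle.CHAT.value),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
```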
data/README.md
CHANGED
@@ -10,10 +10,10 @@ curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity
 ## Convert the JSON data files to JSONL.
 
 ```shell
-python3 ./scripts/alpaca_json_to_jsonl.py --
-python3 ./scripts/alpaca_json_to_jsonl.py --
-python3 ./scripts/alpaca_json_to_jsonl.py --
-python3 ./scripts/alpaca_json_to_jsonl.py --
+python3 ./scripts/alpaca_json_to_jsonl.py --file data/alpaca_data_gpt4.json --output data/alpaca_data_gpt4.jsonl
+python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/vicuna_cleaned.json --output data/vicuna_cleaned.jsonl
+python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/roleplay-similarity_0.6-instruct-dataset.json --output data/roleplay-similarity_0.6-instruct-dataset.jsonl
+python3 ./scripts/alpaca_json_to_jsonl.py --file data/raw/gpt4-instruct-similarity-0.6-dataset.json --output data/gpt4-instruct-similarity-0.6-dataset.jsonl
 ```
 ---
 
docker/Dockerfile-base
CHANGED
@@ -77,7 +77,7 @@ FROM base-builder
 RUN python3 -m pip uninstall -y apex
 RUN git clone https://github.com/NVIDIA/apex
 # `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building image on GitHub Action (standard) runners
-RUN cd apex && MAX_JOBS=1 python3 -m pip install
+RUN cd apex && MAX_JOBS=1 python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
 
 RUN mkdir -p /workspace/builds
 COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
@@ -97,4 +97,4 @@ RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
 RUN git lfs install --skip-repo
 RUN pip3 install awscli && \
     # The base image ships with `pydantic==1.8.2` which is not working
-    pip3 install -U --no-cache-dir pydantic
+    pip3 install -U --no-cache-dir pydantic==1.10.10
examples/openllama-3b/config.yml
CHANGED
@@ -26,17 +26,18 @@ wandb_watch:
 wandb_run_id:
 wandb_log_model:
 output_dir: ./openllama-out
-
-micro_batch_size:
+gradient_accumulation_steps: 1
+micro_batch_size: 1
 num_epochs: 3
 optimizer: adamw_bnb_8bit
 torchdistx_path:
 lr_scheduler: cosine
-learning_rate: 0.
+learning_rate: 0.00001
 train_on_inputs: false
 group_by_length: false
+float16: true
 bf16: false
-fp16:
+fp16: false
 tf32: false
 gradient_checkpointing: true
 early_stopping_patience:
@@ -52,7 +53,7 @@ eval_steps: 50
 save_steps:
 debug:
 deepspeed:
-weight_decay: 0.
+weight_decay: 0.1
 fsdp:
 fsdp_config:
 special_tokens:
examples/pythia-12b/README.md
ADDED
@@ -0,0 +1,9 @@
+# Pythia 12B
+
+- Single-GPU A100 only (?)
+
+```shell
+python scripts/finetune.py examples/pythia-12b/config.yml
+```
+
+⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️
examples/pythia-12b/config.yml
ADDED
@@ -0,0 +1,49 @@
+base_model: EleutherAI/pythia-12b-deduped
+base_model_config: EleutherAI/pythia-12b-deduped
+base_model_ignore_patterns: pytorch*  # prefer safetensors
+model_type: GPTNeoXForCausalLM
+tokenizer_type: AutoTokenizer
+load_in_8bit: false
+load_in_4bit: false
+gptq: false
+device_map: auto
+datasets:
+  - path: vicgalle/alpaca-gpt4
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.05
+adapter:
+lora_model_dir:
+sequence_len: 2048
+max_packed_sequence_len: 2048
+lora_r: 64
+lora_alpha: 32
+lora_dropout: 0.0
+lora_target_modules:
+lora_target_linear: true
+lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
+wandb_project:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+output_dir: ./pythia-12b
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 5
+learning_rate: 0.00003
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+train_on_inputs: false
+group_by_length: false
+bf16: false
+fp16: false
+float16: true
+tf32: true
+flash_optimum: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+gradient_checkpointing: true
+fsdp:
+fsdp_config:
+collator_pad_to_longest: true
examples/redpajama/config-3b.yml
CHANGED
@@ -1,7 +1,7 @@
 base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
 base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
 model_type: GPTNeoXForCausalLM
-tokenizer_type:
+tokenizer_type: AutoTokenizer
 trust_remote_code:
 load_in_8bit: false
 datasets:
requirements.txt
CHANGED
@@ -11,6 +11,7 @@ sentencepiece
 wandb
 einops
 xformers
+optimum
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0
scripts/finetune.py
CHANGED
@@ -12,13 +12,14 @@ from typing import Any, Dict, List, Optional, Union
 import fire
 import torch
 import yaml
+
+# add src to the pythonpath so we don't need to pip install this
+from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer
 
-from axolotl.utils.data import load_prepare_datasets
+from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
-
-# add src to the pythonpath so we don't need to pip install this
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import setup_trainer
 from axolotl.utils.validation import validate_config
@@ -63,7 +64,7 @@ def get_multi_line_input() -> Optional[str]:
     return instruction
 
 
-def do_inference(cfg, model, tokenizer, prompter
+def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
     default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
 
     for token, symbol in default_tokens.items():
@@ -217,9 +218,20 @@ def train(
     if (
         check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
     ):  # don't need to load dataset for these
-        train_dataset, eval_dataset = load_prepare_datasets(
-            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
-        )
+        if not cfg.pretraining_dataset:
+            train_dataset, eval_dataset = load_prepare_datasets(
+                tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
+            )
+        else:
+            train_dataset = load_pretraining_dataset(
+                cfg.pretraining_dataset,
+                tokenizer,
+                max_tokens=cfg.sequence_len,
+                seed=cfg.seed,
+            )
+            # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
+            train_dataset = train_dataset.with_format("torch")
+            eval_dataset = None
 
     if cfg.debug or "debug" in kwargs:
         logging.info("check_dataset_labels...")
@@ -257,13 +269,13 @@ def train(
 
     if cfg.inference:
         logging.info("calling do_inference function")
-
+        prompter: Optional[str] = "AlpacaPrompter"
         if "prompter" in kwargs:
             if kwargs["prompter"] == "None":
-
+                prompter = None
             else:
-
-        do_inference(cfg, model, tokenizer,
+                prompter = kwargs["prompter"]
+        do_inference(cfg, model, tokenizer, prompter=prompter)
         return
 
     if "shard" in kwargs:
@@ -285,12 +297,15 @@ def train(
 
     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
+
+        def terminate_handler(_, __, model):
+            if cfg.flash_optimum:
+                model = BetterTransformer.reverse(model)
+            model.save_pretrained(cfg.output_dir)
+            sys.exit(0)
+
         signal.signal(
-            signal.SIGINT,
-            lambda signal, frame: (
-                model.save_pretrained(cfg.output_dir),
-                sys.exit(0),
-            ),
+            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
         )
 
     logging.info("Starting trainer...")
@@ -313,13 +328,21 @@ def train(
 
     if not Path(cfg.output_dir).is_dir():
         os.makedirs(cfg.output_dir, exist_ok=True)
-    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    if cfg.flash_optimum:
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=True, enable_math=True, enable_mem_efficient=True
+        ):
+            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
+    else:
+        trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
     logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.local_rank == 0:
+        if cfg.flash_optimum:
+            model = BetterTransformer.reverse(model)
         model.save_pretrained(cfg.output_dir)
 
     # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
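The `flash_optimum` path above depends on Optimum's BetterTransformer wrapper being reversible before the model is saved. A minimal sketch of that wrap/reverse round-trip, with an arbitrary small model chosen here only for illustration:

```python
# Sketch of the BetterTransformer round-trip used by the flash_optimum path;
# "gpt2" is just a small stand-in checkpoint, not something from this PR.
from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
model = BetterTransformer.transform(model)  # swap in the fused/SDPA attention modules
# ... train or run inference with the wrapped model ...
model = BetterTransformer.reverse(model)    # restore the original modules
model.save_pretrained("./out")              # only safe to save after reversing
```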
src/axolotl/datasets.py
CHANGED
@@ -126,6 +126,7 @@ class ConstantLengthDataset(IterableDataset):
                 buffer_len = 0
 
             if example:
+                # FIXME
                 # just going to drop data points that are too long
                 if len(example["input_ids"]) <= self.seq_length:
                     input_ids = example["input_ids"]
src/axolotl/prompt_strategies/alpaca_chat.py
CHANGED
@@ -6,7 +6,7 @@ from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     InstructionPromptTokenizingStrategy,
 )
-from axolotl.prompters import AlpacaPrompter, PromptStyle
+from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
 
 
 def load(tokenizer, cfg):
@@ -20,11 +20,38 @@ def load(tokenizer, cfg):
 
 class AlpacaConcisePrompter(AlpacaPrompter):
     """
-    Alpaca Prompter extending the system prompt to ask for concise answers
+    Alpaca Prompter extending the system prompt to ask for concise chat-instruct answers
     """
 
-    system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context.
-    system_no_input_prompt = "Below is an instruction that describes a task.
+    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
+    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
+
+
+class AlpacaChatPrompter(AlpacaPrompter):
+    """
+    Alpaca Chat Prompter extending the system prompt to for chat-instruct answers
+    """
+
+    system_prompt = "Below is an instruction from a USER that describes a task, paired with an input that provides further context. The ASSISTANT writes a response that concisely and appropriately completes the request.\n\n"
+    system_no_input_prompt = "Below is an instruction from a USER that describes a task. The ASSISTANT writes a response that appropriately and concisely completes the request.\n\n"
+
+    def __init__(self):  # pylint: disable=super-init-not-called
+        self.prompt_style = PromptStyle.CHAT.value
+        self.match_prompt_style()
+
+
+class NoSystemPrompter(AlpacaPrompter):
+    """
+    Null Prompter with no system prompts
+    """
+
+    system_prompt = ""
+    system_no_input_prompt = ""
+    turn_format = "{instruction} {input} "
+    turn_no_input_format = "{instruction} "
+
+    def __init__(self):  # pylint: disable=super-init-not-called
+        pass
 
 
 class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
@@ -64,7 +91,7 @@ def load_concise(tokenizer, cfg):
 
 def load_qa(tokenizer, cfg):
     return AlpacaQAPromptTokenizingStrategy(
-
+        AlpacaChatPrompter(),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
@@ -73,7 +100,16 @@ def load_qa(tokenizer, cfg):
 
 def load_camel_ai(tokenizer, cfg):
     return CamelAIPromptTokenizingStrategy(
-
+        AlpacaChatPrompter(),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+def load_no_prompt(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        UnpromptedPrompter(PromptStyle.CHAT.value),
         tokenizer,
         cfg.train_on_inputs,
         cfg.sequence_len,
src/axolotl/prompt_strategies/alpaca_instruct.py
CHANGED
@@ -1,7 +1,7 @@
 """Module loading the AlpacaInstructPromptTokenizingStrategy class"""
 
 from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
-from axolotl.prompters import AlpacaPrompter, PromptStyle
+from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
 
 
 def load(tokenizer, cfg):
@@ -11,3 +11,12 @@ def load(tokenizer, cfg):
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+
+
+def load_no_prompt(tokenizer, cfg):
+    return AlpacaPromptTokenizingStrategy(
+        UnpromptedPrompter(PromptStyle.INSTRUCT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
src/axolotl/prompt_strategies/alpaca_w_system.py
ADDED
@@ -0,0 +1,120 @@
+"""
+Prompt strategies loader for alpaca instruction datasets with system prompts
+"""
+from typing import Generator, Tuple, Union
+
+from axolotl.prompt_tokenizers import PromptTokenizingStrategy
+from axolotl.prompters import AlpacaPrompter, PromptStyle
+
+
+class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):
+    """
+    Tokenizing strategy for instruction-based prompts.
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
+        return (
+            prompt["instruction"],
+            prompt["input"] if "input" in prompt else "",
+            prompt["output"],
+            prompt["system"],
+        )
+
+    def tokenize_prompt(self, prompt):
+        # pylint: disable=duplicate-code
+        (
+            instruction,
+            input,  # pylint: disable=redefined-builtin
+            response,
+            system,
+        ) = self.parse_instruction_fields(prompt)
+        user_prompt = next(
+            iter(
+                self.prompter.build_prompt_w_system(
+                    system,
+                    instruction,
+                    input,
+                )
+            )
+        )
+        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
+        if not self.train_on_inputs:
+            user_prompt_len = len(tokenized_prompt["input_ids"])
+            # TODO this could be sped up using numpy array slicing
+            tokenized_prompt["labels"] = [-100] * user_prompt_len
+        tokenized_res_prompt = self._tokenize(
+            response, strip_bos_token=True, add_eos_token=True
+        )
+        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
+        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
+        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
+
+        return tokenized_prompt
+
+
+class SystemDataPrompter(AlpacaPrompter):
+    """
+    Alpaca Style Prompter that uses system prompts from the dataset
+    """
+
+    def build_prompt_w_system(
+        self,
+        system: str,
+        instruction: str,
+        input: Union[None, str] = None,  # pylint: disable=redefined-builtin
+        output: Union[None, str] = None,
+    ) -> Generator[str, None, None]:
+        # returns the full prompt from instruction and optional input
+        # if a label (=response, =output) is provided, it's also appended.
+        if input:
+            res = system + self.turn_format.format(instruction=instruction, input=input)
+        else:
+            res = system + self.turn_no_input_format.format(instruction=instruction)
+        if output:
+            res = f"{res}{output}"
+        yield res
+
+
+class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
+    """
+    Tokenizing strategy for OpenOrca datasets
+    """
+
+    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
+        return (
+            prompt["question"],
+            "",
+            prompt["response"],
+            prompt["system_prompt"],
+        )
+
+
+def load(tokenizer, cfg):
+    return load_chat(tokenizer, cfg)
+
+
+def load_instruct(tokenizer, cfg):
+    return InstructionWSystemPromptTokenizingStrategy(
+        SystemDataPrompter(PromptStyle.INSTRUCT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+def load_chat(tokenizer, cfg):
+    return InstructionWSystemPromptTokenizingStrategy(
+        SystemDataPrompter(PromptStyle.CHAT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+def load_open_orca(tokenizer, cfg):
+    return OpenOrcaPromptTokenizingStrategy(
+        SystemDataPrompter(PromptStyle.INSTRUCT.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
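For reference, a rough usage sketch of the new OpenOrca strategy; the tokenizer checkpoint and the `cfg` stand-in below are illustrative placeholders, not part of the commit:

```python
# Rough usage sketch for alpaca_w_system.load_open_orca; the tokenizer checkpoint
# and the minimal cfg object are placeholders for illustration only.
from types import SimpleNamespace
from transformers import AutoTokenizer

from axolotl.prompt_strategies.alpaca_w_system import load_open_orca

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-12b-deduped")
cfg = SimpleNamespace(train_on_inputs=False, sequence_len=2048)

strategy = load_open_orca(tokenizer, cfg)
sample = {
    "system_prompt": "You are a helpful assistant.",
    "question": "What is 2 + 2?",
    "response": "4",
}
tokenized = strategy.tokenize_prompt(sample)  # dict of input_ids / attention_mask / labels
```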
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -87,7 +87,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
     Tokenizing strategy for instruction-based prompts.
     """
 
-    def parse_instruction_fields(
+    def parse_instruction_fields(
+        self, prompt
+    ) -> Union[Tuple[str, str, str], Tuple[str, str, str, str]]:
         raise NotImplementedError
 
     def tokenize_prompt(self, prompt):
@@ -96,25 +98,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
             input,  # pylint: disable=redefined-builtin
             response,
         ) = self.parse_instruction_fields(prompt)
-            self.prompter.build_prompt(
-                instruction,
-                input,
-            )
-        # TODO this could be sped up using numpy array slicing
-        return
+        user_prompt = next(
+            iter(
+                self.prompter.build_prompt(
+                    instruction,
+                    input,
+                )
+            )
+        )
+        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
+        if not self.train_on_inputs:
+            user_prompt_len = len(tokenized_prompt["input_ids"])
+            # TODO this could be sped up using numpy array slicing
+            tokenized_prompt["labels"] = [-100] * user_prompt_len
+        tokenized_res_prompt = self._tokenize(
+            response, strip_bos_token=True, add_eos_token=True
+        )
+        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
+        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
+        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
 
+        return tokenized_prompt
 
     def _build_full_prompt(
         self, instruction, input, response  # pylint: disable=redefined-builtin
@@ -436,7 +440,7 @@ def parse_tokenized_to_result(
     result: Dict[str, List[int]],
     current_len: int,
     res: Dict[str, List[int]],
-    labels:
+    labels: List[int],
     pad_token_id: Union[int, None] = None,
 ) -> Tuple[Dict[str, List[int]], int]:
     """
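The net effect of the refactored `tokenize_prompt` above: the prompt tokens are masked out of the loss with -100 and the response tokens (with EOS) are appended. A small schematic illustration, not axolotl code:

```python
# Schematic illustration of the label masking produced when train_on_inputs is False.
prompt_ids = [1, 343, 89, 1055]   # tokenized user prompt (with BOS)
response_ids = [4874, 29889, 2]   # tokenized response (BOS stripped, EOS appended)

input_ids = prompt_ids + response_ids
labels = [-100] * len(prompt_ids) + response_ids  # loss is computed only on the response
attention_mask = [1] * len(input_ids)
assert len(input_ids) == len(labels) == len(attention_mask)
```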
src/axolotl/prompters.py
CHANGED
@@ -24,6 +24,8 @@ class AlpacaPrompter:
 
     system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
     system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
+    turn_format: str
+    turn_no_input_format: str
     prompt_style: Optional[PromptStyle] = None
 
     def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
@@ -32,23 +34,13 @@ class AlpacaPrompter:
 
     def match_prompt_style(self):
         if self.prompt_style == PromptStyle.INSTRUCT.value:
-            self.prompt_input = (
-                self.system_prompt
-                + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
-            )
-            self.prompt_no_input = (
-                self.system_no_input_prompt
-                + "### Instruction:\n{instruction}\n\n### Response:\n"
+            self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+            self.turn_no_input_format = (
+                "### Instruction:\n{instruction}\n\n### Response:\n"
             )
-            self.response_split = "### Response:"
         if self.prompt_style == PromptStyle.CHAT.value:
-            self.prompt_input = (
-                self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
-            )
-            self.prompt_no_input = (
-                self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
-            )
-            self.response_split = "ASSISTANT:"
+            self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
+            self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
 
     def build_prompt(
         self,
@@ -59,16 +51,17 @@ class AlpacaPrompter:
         # returns the full prompt from instruction and optional input
         # if a label (=response, =output) is provided, it's also appended.
         if input:
-            res = self.prompt_input.format(instruction=instruction, input=input)
+            res = self.system_prompt + self.turn_format.format(
+                instruction=instruction, input=input
+            )
         else:
-            res = self.prompt_no_input.format(instruction=instruction)
+            res = self.system_no_input_prompt + self.turn_no_input_format.format(
+                instruction=instruction
+            )
         if output:
             res = f"{res}{output}"
         yield res
 
-    def get_response(self, output: str) -> str:
-        return output.split(self.response_split)[1].strip()
-
 
 class UnpromptedPrompter(AlpacaPrompter):
     """
@@ -93,7 +86,10 @@ class MultipleChoiceExplainPrompter(AlpacaPrompter):
     """
 
     system_prompt = (
-        "Choose the answer that best answers the question. Explain your reasoning
+        "Choose the answer that best answers the question. Explain your reasoning.\n"
+    )
+    system_no_input_prompt = (
+        "Choose the answer that best answers the question. Explain your reasoning.\n"
     )
 
 
@@ -102,7 +98,12 @@ class MultipleChoiceConcisePrompter(AlpacaPrompter):
     Prompter for multiple choice concise
     """
 
-
+    system_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
+    system_no_input_prompt = "Choose the answer that best answers the question. Be concise in your response.\n\n"
+
+    def match_prompt_style(self):
+        self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
+        self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
 
 
 class SummarizeTLDRPrompter(AlpacaPrompter):
@@ -110,9 +111,12 @@ class SummarizeTLDRPrompter(AlpacaPrompter):
     Prompter for summarize TLDR
     """
 
-
-
-
+    system_prompt = ""
+    system_no_input_prompt = ""
+
+    def match_prompt_style(self):
+        self.turn_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\n{input}\nASSISTANT:"
+        self.turn_no_input_format = "USER: Summarize the following article as a TL;DR.\n{instruction}\nASSISTANT:"
 
 
 class CompletionPrompter:
@@ -128,9 +132,6 @@ class CompletionPrompter:
     ) -> Generator[str, None, None]:
         yield instruction
 
-    def get_response(self, output: str) -> str:
-        return output.strip()
-
 
 class GPTeacherPrompter(AlpacaPrompter):
     """
@@ -210,9 +211,6 @@ class ReflectAlpacaPrompter:
         res = f"{res}{label}"
         yield res
 
-    def get_response(self, output: str) -> str:
-        return output.split(self.response_split)[1].strip()
-
 
 class SeparatorStyle(Enum):
     """Different separator style."""
@@ -289,12 +287,6 @@ class ShareGPTPrompter:  # pylint: disable=too-few-public-methods
         sep2=" ",
     )
 
-    # def match_prompt_style(self):
-    #     if self.prompt_style == PromptStyle.chat.value:
-    #         self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:"
-    #         self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:"
-    #         self.response_split = "ASSISTANT:"
-
     def build_prompt(self, source) -> Generator[str, None, None]:
         # ignore the system prompt if provided
         if source[0]["from"] == "system":
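To make the refactor above concrete: a prompter now stores `turn_format` / `turn_no_input_format` and prepends the system prompt inside `build_prompt()` instead of caching fully merged templates. A small illustrative sketch:

```python
# Illustrative sketch of the refactored prompt assembly in this PR.
from axolotl.prompters import AlpacaPrompter, PromptStyle

prompter = AlpacaPrompter(PromptStyle.CHAT.value)
prompt = next(prompter.build_prompt("Name three primes.", "2, 3, 5, 7, 9"))
# Expected shape (system prompt elided):
#   "...USER: Name three primes.\n2, 3, 5, 7, 9\nASSISTANT:"
print(prompt)
```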
src/axolotl/utils/callbacks.py
CHANGED
@@ -2,13 +2,14 @@
 
 import os
 
+from optimum.bettertransformer import BetterTransformer
 from transformers import (
     TrainerCallback,
     TrainerControl,
     TrainerState,
     TrainingArguments,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
 
 
 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -30,3 +31,39 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
         kwargs["model"].save_pretrained(peft_model_path)
 
         return control
+
+
+class SaveBetterTransformerModelCallback(
+    TrainerCallback
+):  # pylint: disable=too-few-public-methods
+    """Callback to save the BetterTransformer wrapped model"""
+
+    def on_step_end(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        **kwargs,
+    ):
+        # Save
+        if (
+            args.save_strategy == IntervalStrategy.STEPS
+            and args.save_steps > 0
+            and state.global_step % args.save_steps == 0
+        ):
+            control.should_save = True
+
+        if control.should_save:
+            checkpoint_folder = os.path.join(
+                args.output_dir,
+                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+            )
+
+            model = BetterTransformer.reverse(kwargs["model"])
+            model.save_pretrained(checkpoint_folder)
+            # FIXME - need to cleanup old checkpoints
+
+            # since we're saving here, we don't need the trainer loop to attempt to save too b/c
+            # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
+            control.should_save = False
+        return control
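Presumably this callback is wired into the trainer setup elsewhere (not shown in this diff); registering it on a `transformers` Trainer would look roughly like the sketch below.

```python
# Sketch of attaching the new callback to a transformers Trainer; the helper
# function here is illustrative and not taken from this PR.
from transformers import Trainer

from axolotl.utils.callbacks import SaveBetterTransformerModelCallback


def add_bettertransformer_saving(trainer: Trainer) -> Trainer:
    # periodic saves then un-wrap the BetterTransformer model before writing to disk
    trainer.add_callback(SaveBetterTransformerModelCallback)
    return trainer
```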
src/axolotl/utils/data.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
"""Module containing data utilities"""
|
2 |
-
|
3 |
import logging
|
4 |
from hashlib import md5
|
5 |
from pathlib import Path
|
6 |
from typing import List, Tuple, Union
|
7 |
|
|
|
8 |
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
9 |
from huggingface_hub import hf_hub_download
|
10 |
from transformers import PreTrainedTokenizerBase
|
@@ -101,13 +102,26 @@ def load_tokenized_prepared_datasets(
|
|
101 |
pass
|
102 |
|
103 |
# prefer local dataset, even if hub exists
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
elif ds_from_hub:
|
112 |
if d.data_files:
|
113 |
ds = load_dataset(
|
@@ -394,8 +408,127 @@ def load_prepare_datasets(
|
|
394 |
index=cfg.dataset_shard_idx,
|
395 |
)
|
396 |
|
397 |
-
|
398 |
-
|
399 |
-
|
|
|
|
|
|
|
|
|
400 |
|
401 |
return train_dataset, eval_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
"""Module containing data utilities"""
|
2 |
+
import functools
|
3 |
import logging
|
4 |
from hashlib import md5
|
5 |
from pathlib import Path
|
6 |
from typing import List, Tuple, Union
|
7 |
|
8 |
+
import torch
|
9 |
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
|
10 |
from huggingface_hub import hf_hub_download
|
11 |
from transformers import PreTrainedTokenizerBase
|
|
|
102 |
pass
|
103 |
|
104 |
# prefer local dataset, even if hub exists
|
105 |
+
local_path = Path(d.path)
|
106 |
+
if local_path.exists():
|
107 |
+
if local_path.is_dir():
|
108 |
+
ds = load_dataset(
|
109 |
+
d.path,
|
110 |
+
data_files=d.data_files,
|
111 |
+
streaming=False,
|
112 |
+
split=None,
|
113 |
+
)
|
114 |
+
elif local_path.is_file():
|
115 |
+
ds = load_dataset(
|
116 |
+
"json",
|
117 |
+
data_files=d.path,
|
118 |
+
streaming=False,
|
119 |
+
split=None,
|
120 |
+
)
|
121 |
+
else:
|
122 |
+
raise ValueError(
|
123 |
+
"unhandled dataset load: local path exists, but is neither a directory or a file"
|
124 |
+
)
|
125 |
elif ds_from_hub:
|
126 |
if d.data_files:
|
127 |
ds = load_dataset(
|
|
|
408 |
index=cfg.dataset_shard_idx,
|
409 |
)
|
410 |
|
411 |
+
if cfg.val_set_size:
|
412 |
+
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
|
413 |
+
train_dataset = dataset["train"]
|
414 |
+
eval_dataset = dataset["test"]
|
415 |
+
else:
|
416 |
+
train_dataset = dataset
|
417 |
+
eval_dataset = None
|
418 |
|
419 |
return train_dataset, eval_dataset
|
420 |
+
|
421 |
+
|
422 |
+
def encode_pretraining(tokenizer, max_tokens, examples):
|
423 |
+
res = tokenizer(
|
424 |
+
examples["text"],
|
425 |
+
truncation=True,
|
426 |
+
max_length=max_tokens - 2,
|
427 |
+
add_special_tokens=True,
|
428 |
+
)
|
429 |
+
# Convert to PyTorch tensors
|
430 |
+
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
|
431 |
+
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
|
432 |
+
new_input_ids = []
|
433 |
+
new_attention_mask = []
|
434 |
+
# Append EOS and PAD tokens to input_ids, and correct attention_mask
|
435 |
+
for i, _ in enumerate(input_ids):
|
436 |
+
input_ids[i] = torch.cat(
|
437 |
+
(
|
438 |
+
input_ids[i],
|
439 |
+
torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
|
440 |
+
),
|
441 |
+
dim=0,
|
442 |
+
)
|
443 |
+
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
|
444 |
+
|
445 |
+
# Concatenate tokens so that their lengths are less than max_tokens
|
446 |
+
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
447 |
+
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
448 |
+
|
449 |
+
for ids, mask in zip(input_ids, attention_mask):
|
450 |
+
if buffer_input_ids.numel() == max_tokens:
|
451 |
+
new_input_ids.append(buffer_input_ids)
|
452 |
+
new_attention_mask.append(buffer_attention_mask)
|
453 |
+
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
454 |
+
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
455 |
+
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
456 |
+
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
457 |
+
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
|
458 |
+
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
459 |
+
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
460 |
+
else:
|
461 |
+
buffer_input_ids = torch.cat(
|
462 |
+
(
|
463 |
+
buffer_input_ids,
|
464 |
+
torch.full(
|
465 |
+
(max_tokens - buffer_input_ids.numel(),),
|
466 |
+
tokenizer.pad_token_id,
|
467 |
+
dtype=torch.long,
|
468 |
+
),
|
469 |
+
),
|
470 |
+
dim=0,
|
471 |
+
)
|
472 |
+
buffer_attention_mask = torch.cat(
|
473 |
+
(
|
474 |
+
buffer_attention_mask,
|
475 |
+
torch.full(
|
476 |
+
(max_tokens - buffer_attention_mask.numel(),),
|
477 |
+
0,
|
478 |
+
dtype=torch.long,
|
479 |
+
),
|
480 |
+
),
|
481 |
+
dim=0,
|
482 |
+
)
|
483 |
+
new_input_ids.append(buffer_input_ids)
|
484 |
+
new_attention_mask.append(buffer_attention_mask)
|
485 |
+
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
486 |
+
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
487 |
+
|
488 |
+
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
489 |
+
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
490 |
+
|
491 |
+
if buffer_input_ids.numel() > 0: # for any leftover tokens
|
492 |
+
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
|
493 |
+
buffer_input_ids = torch.cat(
|
494 |
+
(
|
495 |
+
buffer_input_ids,
|
496 |
+
torch.full(
|
497 |
+
(max_tokens - buffer_input_ids.numel(),),
|
498 |
+
tokenizer.pad_token_id,
|
499 |
+
dtype=torch.long,
|
500 |
+
),
|
501 |
+
),
|
502 |
+
dim=0,
|
503 |
+
)
|
504 |
+
buffer_attention_mask = torch.cat(
|
505 |
+
(
|
506 |
+
buffer_attention_mask,
|
507 |
+
torch.full(
|
508 |
+
(max_tokens - buffer_attention_mask.numel(),),
|
509 |
+
0,
|
510 |
+
dtype=torch.long,
|
511 |
+
),
|
512 |
+
),
|
513 |
+
dim=0,
|
514 |
+
)
|
515 |
+
new_input_ids.append(buffer_input_ids)
|
516 |
+
new_attention_mask.append(buffer_attention_mask)
|
517 |
+
|
518 |
+
ret = {
|
519 |
+
"input_ids": [seq.tolist() for seq in new_input_ids],
|
520 |
+
"labels": [seq.tolist() for seq in new_input_ids],
|
521 |
+
"attention_mask": [seq.tolist() for seq in new_attention_mask],
|
522 |
+
}
|
523 |
+
|
524 |
+
logging.debug(len(ret["input_ids"]))
|
525 |
+
return ret
|
526 |
+
|
527 |
+
|
528 |
+
def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
|
529 |
+
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
530 |
+
dataset = load_dataset(path, streaming=True, split="train")
|
531 |
+
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
|
532 |
+
# TODO dynamically figure out which columns/features to remove
|
533 |
+
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
|
534 |
+
return dataset
|
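The new load_pretraining_dataset streams a corpus from the hub, shuffles it, and packs the tokenized text into fixed-length blocks of max_tokens via encode_pretraining (EOS/PAD appended per document, buffers padded out to the block size, labels mirroring input_ids). A minimal usage sketch, not part of the commit — the dataset id and the pad-token workaround are assumptions for illustration; any streaming dataset exposing "text" and "meta" columns behaves the same way:

from transformers import AutoTokenizer

from axolotl.utils.data import load_pretraining_dataset

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = tokenizer.eos_token  # encode_pretraining needs a pad token id

dataset = load_pretraining_dataset(
    "togethercomputer/RedPajama-Data-1T-Sample",  # placeholder dataset with text/meta columns
    tokenizer,
    max_tokens=2048,
)
sample = next(iter(dataset))
# every packed example comes out exactly max_tokens long, labels == input_ids
assert len(sample["input_ids"]) == 2048
assert sample["labels"] == sample["input_ids"]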
src/axolotl/utils/models.py
CHANGED
@@ -10,13 +10,15 @@ from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401
 import bitsandbytes as bnb
 import torch
 import transformers
-from
+from optimum.bettertransformer import BetterTransformer
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
     BitsAndBytesConfig,
     LlamaConfig,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
 )

 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN
@@ -32,15 +34,20 @@ def load_tokenizer(
     tokenizer_type,
     cfg,
 ):
+    use_fast = True  # this is the default
+    if cfg.tokenizer_use_fast is not None:
+        use_fast = cfg.tokenizer_use_fast
     if tokenizer_type:
         tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
             tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
+            use_fast=use_fast,
         )
     else:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
+            use_fast=use_fast,
         )

     logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
@@ -70,7 +77,7 @@ def load_tokenizer(
 def load_model(
     base_model, base_model_config, model_type, tokenizer, cfg, adapter="lora"
 ):
-    # type: (str, str, str,
+    # type: (str, str, str, PreTrainedTokenizerBase, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
     """
     Load a model from a base model and a model type.
     """
@@ -121,9 +128,9 @@ def load_model(
         logging.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()

-    if cfg.bf16:
+    if cfg.bf16 or cfg.bfloat16:
         torch_dtype = torch.bfloat16
-    elif cfg.load_in_8bit or cfg.fp16:
+    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
         torch_dtype = torch.float16
     else:
         torch_dtype = torch.float32
@@ -195,7 +202,7 @@ def load_model(
             else True,
         )
         load_in_8bit = False
-    elif cfg.is_llama_derived_model:
+    elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
        from transformers import LlamaForCausalLM

        config = LlamaConfig.from_pretrained(base_model_config)
@@ -234,7 +241,7 @@ def load_model(
        #     device=cfg.device,
        # )
        # model.train() # sets to train instead of eval mode
-    elif model_type:
+    elif model_type and not cfg.trust_remote_code:
        model = getattr(transformers, model_type).from_pretrained(
            base_model,
            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
@@ -251,11 +258,16 @@ def load_model(
        )
        # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
        # when training starts
-        if
+        if (
+            hasattr(config, "max_seq_len")
+            and config.max_seq_len
+            and cfg.sequence_len > config.max_seq_len
+        ):
            config.max_seq_len = cfg.sequence_len
            logging.warning(f"increasing context length to {cfg.sequence_len}")
        elif (
            hasattr(config, "max_sequence_length")
+            and config.max_sequence_length
            and cfg.sequence_len > config.max_sequence_length
        ):
            config.max_sequence_length = cfg.sequence_len
@@ -278,6 +290,7 @@ def load_model(
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
            torch_dtype=torch_dtype,
            device_map=cfg.device_map,
            trust_remote_code=cfg.trust_remote_code or False,
@@ -287,6 +300,16 @@ def load_model(
    embeddings_len = math.ceil(len(tokenizer) / 32) * 32
    model.resize_token_embeddings(embeddings_len)

+    if (
+        hasattr(model.config, "max_position_embeddings")
+        and model.config.max_position_embeddings
+        and cfg.sequence_len >= model.config.max_position_embeddings
+    ):
+        logging.warning(
+            f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
+        )
+        model.config.max_position_embeddings = cfg.sequence_len
+
    if not cfg.gptq and (
        (cfg.adapter == "lora" and load_in_8bit)
        or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -332,6 +355,9 @@ def load_model(
        logging.warning("there are no parameters that require gradient updates")
        model.config.use_cache = False

+    if cfg.flash_optimum:
+        model = BetterTransformer.transform(model)
+
    # TODO resume_from_checkpoint handling
    return model, lora_config

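load_tokenizer now honors a tokenizer_use_fast config key, and load_model wraps the loaded model with optimum's BetterTransformer when flash_optimum is set. A rough sketch of the BetterTransformer round-trip this relies on (the model id is a placeholder, not taken from the commit):

from optimum.bettertransformer import BetterTransformer
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m")  # placeholder model
model = BetterTransformer.transform(model)  # what load_model does when cfg.flash_optimum is set
# ... train or generate using torch 2.0 scaled-dot-product attention kernels ...
model = BetterTransformer.reverse(model)  # restore vanilla HF modules before saving
model.save_pretrained("./out")

Reversing the transform before save_pretrained is presumably what the new SaveBetterTransformerModelCallback takes care of during training.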
src/axolotl/utils/tokenization.py
CHANGED
@@ -34,3 +34,5 @@ def check_example_labels(example, tokenizer):

     logging.info(" ".join(colored_tokens))
     logging.info("\n\n\n")
+
+    return " ".join(colored_tokens)
src/axolotl/utils/trainer.py
CHANGED
@@ -17,7 +17,10 @@ from torch.optim.lr_scheduler import OneCycleLR
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names

-from axolotl.utils.callbacks import
+from axolotl.utils.callbacks import (
+    SaveBetterTransformerModelCallback,
+    SavePeftModelCallback,
+)
 from axolotl.utils.schedulers import (
     InterpolatingLogScheduler,
     get_cosine_schedule_with_quadratic_warmup,
@@ -166,6 +169,19 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         # TODO search Path("./") for one
         training_arguments_kwargs["deepspeed"] = "./ds_config.json"

+    if cfg.adam_beta1:
+        training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
+    if cfg.adam_beta2:
+        training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
+    if cfg.adam_epsilon:
+        training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
+    if cfg.max_grad_norm:
+        training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
+
+    if cfg.hub_model_id:
+        training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
+        training_arguments_kwargs["push_to_hub"] = True
+
     training_args = AxolotlTrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
@@ -282,6 +298,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     ]:  # only save in rank 0
         callbacks.append(SavePeftModelCallback)

+    if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
+        callbacks.append(SaveBetterTransformerModelCallback)
+
     data_collator_kwargs = {
         "padding": True,
     }
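The new optimizer and hub keys pass straight through to transformers.TrainingArguments. A hedged sketch of the equivalent direct call (output dir, hub id, and values are placeholders, not from the commit), showing which fields a config with adam_beta2, max_grad_norm, and hub_model_id ends up setting:

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./out",            # placeholder
    adam_beta2=0.95,               # cfg.adam_beta2
    max_grad_norm=1.0,             # cfg.max_grad_norm
    hub_model_id="my-org/my-run",  # cfg.hub_model_id (placeholder id)
    push_to_hub=True,              # set automatically whenever hub_model_id is present
)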
src/axolotl/utils/validation.py
CHANGED
@@ -2,6 +2,8 @@

 import logging

+import torch
+

 def validate_config(cfg):
     if cfg.gradient_accumulation_steps and cfg.batch_size:
@@ -62,7 +64,47 @@ def validate_config(cfg):
     ) and cfg.gradient_checkpointing:
         raise ValueError("gradient_checkpointing is not supported for MPT models")

+    if cfg.flash_optimum is True:
+        if cfg.adapter:
+            logging.warning(
+                "BetterTransformers probably doesn't work with PEFT adapters"
+            )
+        if cfg.fp16 or cfg.bf16:
+            raise ValueError("AMP is not supported with BetterTransformer")
+        if cfg.float16 is not True and cfg.bloat16 is not True:
+            logging.warning(
+                "You should probably set bfloat16 or float16 to true to "
+                "load the model in float16 for BetterTransformers"
+            )
+        if int(torch.__version__.split(".")[0]) < 2:
+            logging.warning("torch>=2.0.0 required")
+            raise ValueError(
+                f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
+            )
+
+    if cfg.pretraining_dataset and cfg.group_by_length:
+        logging.warning(
+            "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
+        )
+
+    if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
+        not cfg.optimizer or "adamw" not in cfg.optimizer
+    ):
+        logging.warning("adamw hyperparameters found, but no adamw optimizer set")
+
+    if cfg.push_to_hub_model_id:
+        raise ValueError(
+            "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
-    # no 8bit
+    # no 8bit adaAmw w bf16
+
+    # GPT-NeoX
+    # evals broken when extending context len
+    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
+    #     attention_mask = causal_mask + attention_mask
+    # RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
tests/test_prompt_tokenizers.py
CHANGED
@@ -6,8 +6,16 @@ from pathlib import Path

 from transformers import AutoTokenizer

-from axolotl.
-from axolotl.
+from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
+from axolotl.prompt_strategies.alpaca_w_system import (
+    InstructionWSystemPromptTokenizingStrategy,
+    SystemDataPrompter,
+)
+from axolotl.prompt_tokenizers import (
+    AlpacaPromptTokenizingStrategy,
+    ShareGPTPromptTokenizingStrategy,
+)
+from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter

 logging.basicConfig(level="INFO")

@@ -29,7 +37,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
         )

     def test_sharegpt_integration(self):
-        print(Path(__file__).parent)
         with open(
             Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
         ) as fin:
@@ -53,6 +60,79 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
             self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
             self.assertEqual(example[fields], tokenized_conversation[fields])

+    def test_no_sys_prompt(self):
+        """
+        tests the interface between the user and assistant parts
+        """
+        prompter = NoSystemPrompter()
+        # pylint: disable=duplicate-code
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        sample = {
+            "instruction": "hello cruel. lorem ipsum dolor sit amet.",
+            "output": "world!",
+        }
+        example = strat.tokenize_prompt(sample)
+        world_idx = example["input_ids"].index(3186)
+        assert example["labels"][world_idx] == 3186
+        assert example["labels"][world_idx - 1] == -100
+
+    def test_alpaca(self):
+        """
+        tests the interface between the user and assistant parts
+        """
+        # pylint: disable=duplicate-code
+        prompter = AlpacaPrompter()
+        strat = AlpacaPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
+        example = strat.tokenize_prompt(sample)
+        world_idx = example["input_ids"].index(6324)
+        assert example["labels"][world_idx] == 6324
+        assert example["labels"][world_idx - 1] == -100
+
+
+class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
+    """
+    Test class for prompt tokenization strategies with sys prompt from the dataset
+    """
+
+    def setUp(self) -> None:
+        # pylint: disable=duplicate-code
+        self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
+        self.tokenizer.add_special_tokens(
+            {
+                "bos_token": "<s>",
+                "eos_token": "</s>",
+                "unk_token": "<unk>",
+            }
+        )
+
+    def test_system_alpaca(self):
+        prompter = SystemDataPrompter(PromptStyle.CHAT.value)
+        strat = InstructionWSystemPromptTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            2048,
+        )
+        sample = {
+            "system": "use cot",
+            "instruction": "hello!",
+            "output": "Hi! How can I help?",
+        }
+        example = strat.tokenize_prompt(sample)
+        assert example["input_ids"][0:3] == [1, 671, 20118]  # <s>use cot
+        assert example["input_ids"][3] == 11889  # USER
+

 if __name__ == "__main__":
     unittest.main()
tests/test_prompters.py
CHANGED
@@ -2,7 +2,13 @@

 import unittest

-from axolotl.
+from axolotl.prompt_strategies.alpaca_w_system import SystemDataPrompter
+from axolotl.prompters import (
+    AlpacaPrompter,
+    MultipleChoiceExplainPrompter,
+    PromptStyle,
+    UnpromptedPrompter,
+)


 class AlpacaPrompterTest(unittest.TestCase):
@@ -55,3 +61,64 @@ class AlpacaPrompterTest(unittest.TestCase):
         assert "### Response:" not in res
         assert "USER:" in res
         assert "ASSISTANT:" in res
+
+    def test_system_prompt(self):
+        prompter = SystemDataPrompter(prompt_style=PromptStyle.CHAT.value)
+        res = next(
+            prompter.build_prompt_w_system(
+                "use cot", "tell me a joke about the following", "alpacas"
+            )
+        )
+        assert "use cot" in res
+        assert res.startswith("use cot")
+        assert "### Instruction:" not in res
+        assert "### Input:" not in res
+        assert "alpacas" in res
+        assert "### Response:" not in res
+        assert "USER:" in res
+        assert "ASSISTANT:" in res
+
+
+class UnpromptedPrompterTest(unittest.TestCase):
+    """
+    Test class for UnpromptedPrompter with no system prompts
+    """
+
+    def test_prompt_style_w_none(self):
+        prompter = UnpromptedPrompter(prompt_style=None)
+        res = next(prompter.build_prompt("tell me a joke"))
+        assert "### Instruction:" in res
+        assert "tell me a joke" in res
+        assert res.startswith("###")
+
+    def test_prompt_style_w_instruct(self):
+        prompter = UnpromptedPrompter(prompt_style=PromptStyle.INSTRUCT.value)
+        res = next(
+            prompter.build_prompt("tell me a joke about the following", "alpacas")
+        )
+        assert "### Instruction:" in res
+        assert "tell me a joke" in res
+        assert res.startswith("###")
+
+    def test_prompt_style_w_chat(self):
+        prompter = UnpromptedPrompter(prompt_style=PromptStyle.CHAT.value)
+        res = next(
+            prompter.build_prompt("tell me a joke about the following", "alpacas")
+        )
+        assert "USER:" in res
+        assert "tell me a joke" in res
+        assert res.startswith("USER:")
+
+
+class MultipleChoiceExplainPrompterTest(unittest.TestCase):
+    """
+    Test class for MultipleChoiceExplainPrompter
+    """
+
+    def test_prompt_style_w_chat(self):
+        prompter = MultipleChoiceExplainPrompter(prompt_style=PromptStyle.CHAT.value)
+        res = next(prompter.build_prompt("choose one", "- A\n- B\n- C", "C"))
+        assert "USER:" in res
+        assert "choose one" in res
+        assert "Choose the answer that best answers the question." in res
+        assert "- A\n- B\n- C" in res
tests/test_tokenizers.py
ADDED
@@ -0,0 +1,31 @@
+"""
+Test cases for the tokenizer loading
+"""
+import unittest
+
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_tokenizer
+
+
+class TestTokenizers(unittest.TestCase):
+    """
+    test class for the load_tokenizer fn
+    """
+
+    def test_default_use_fast(self):
+        cfg = DictDefault({})
+        tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
+        assert "Fast" in tokenizer.__class__.__name__
+
+    def test_dont_use_fast(self):
+        cfg = DictDefault(
+            {
+                "tokenizer_use_fast": False,
+            }
+        )
+        tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
+        assert "Fast" not in tokenizer.__class__.__name__
+
+
+if __name__ == "__main__":
+    unittest.main()
tests/test_validation.py
CHANGED
@@ -212,3 +212,104 @@ class ValidationTest(unittest.TestCase):

         with pytest.raises(ValueError, match=regex_exp):
             validate_config(cfg)
+
+    def test_flash_optimum(self):
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "adapter": "lora",
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "BetterTransformers probably doesn't work with PEFT adapters"
+                in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "probably set bfloat16 or float16" in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "fp16": True,
+            }
+        )
+        regex_exp = r".*AMP is not supported.*"
+
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "bf16": True,
+            }
+        )
+        regex_exp = r".*AMP is not supported.*"
+
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
+
+    def test_adamw_hyperparams(self):
+        cfg = DictDefault(
+            {
+                "optimizer": None,
+                "adam_epsilon": 0.0001,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "adamw hyperparameters found, but no adamw optimizer set"
+                in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "optimizer": "adafactor",
+                "adam_beta1": 0.0001,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "adamw hyperparameters found, but no adamw optimizer set"
+                in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "optimizer": "adamw_bnb_8bit",
+                "adam_beta1": 0.9,
+                "adam_beta2": 0.99,
+                "adam_epsilon": 0.0001,
+            }
+        )
+
+        validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "optimizer": "adafactor",
+            }
+        )
+
+        validate_config(cfg)