bump transformers and update attention class map name (#1023)
* bump transformers and update attention class map name
* also run the tests in docker
* add mixtral e2e smoke test
* fix base name for docker image in test
* mixtral lora doesn't seem to work, at least check qlora
* add testcase for mixtral w sample packing
* check monkeypatch for flash attn multipack
* also run the e2e tests in docker
* use all gpus to run tests in docker ci
* use privileged mode too for docker w gpus
* rename the docker e2e actions for gh ci
* set privileged mode for docker and update mixtral model self attn check
* use fp16/bf16 for mixtral w fa2
* skip e2e tests on docker w gpus for now
* tests to validate mistral and mixtral patches
* fix rel import
- .github/workflows/tests-docker.yml +62 -0
- requirements.txt +1 -1
- src/axolotl/monkeypatch/mixtral/__init__.py +1 -1
- src/axolotl/monkeypatch/mixtral/modeling_mixtral.py +6 -2
- src/axolotl/utils/models.py +3 -0
- tests/e2e/test_mixtral.py +109 -0
- tests/e2e/test_mixtral_samplepack.py +123 -0
- tests/e2e/test_model_patches.py +99 -0
.github/workflows/tests-docker.yml
ADDED
@@ -0,0 +1,62 @@
+name: e2e-docker-tests
+
+on:
+  pull_request:
+    paths:
+      - '**.py'
+      - 'requirements.txt'
+  workflow_dispatch:
+
+jobs:
+  build-axolotl:
+    if: github.repository_owner == 'OpenAccess-AI-Collective'
+    # this job needs to be run on self-hosted GPU runners...
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - cuda: 118
+            cuda_version: 11.8.0
+            python_version: "3.10"
+            pytorch: 2.0.1
+            axolotl_extras:
+            is_latest: true
+          - cuda: 121
+            cuda_version: 12.1.0
+            python_version: "3.10"
+            pytorch: 2.1.1
+            axolotl_extras:
+    runs-on: [self-hosted, gpu, docker]
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Docker metadata
+        id: metadata
+        uses: docker/metadata-action@v5
+        with:
+          images: winglian/axolotl
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+      - name: Login to Docker Hub
+        uses: docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+      # guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
+      - name: Build and export to Docker
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          load: true
+          build-args: |
+            BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
+            CUDA=${{ matrix.cuda }}
+            PYTORCH_VERSION=${{ matrix.pytorch }}
+          file: ./docker/Dockerfile
+          tags: |
+            ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+            ${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
+          labels: ${{ steps.metadata.outputs.labels }}
+      - name: Unit Tests
+        run: |
+          docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
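The "Unit Tests" step above runs the suite inside the freshly built image; a minimal local equivalent, using pytest's Python API from a checkout at the repository root (rather than the /workspace/axolotl path baked into the image), might look like this sketch:

import sys

import pytest

if __name__ == "__main__":
    # Same selection as the Docker step: everything under tests/ except the e2e suite.
    sys.exit(pytest.main(["--ignore=tests/e2e/", "tests/"]))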
requirements.txt
CHANGED
@@ -2,7 +2,7 @@
 auto-gptq==0.5.1
 packaging
 peft==0.6.0
-transformers
+transformers @ git+https://github.com/huggingface/transformers.git@3cefac1d974db5e2825a0cb2b842883a628be7a0
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.24.1
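The pinned revision is what provides the renamed attention class map used by the monkeypatch below; a quick sanity check (assuming the environment was installed from this requirements.txt) is to confirm the symbol exists:

import transformers
import transformers.models.mixtral.modeling_mixtral as modeling_mixtral

# A git install reports a dev version string rather than a tagged release.
print(transformers.__version__)
# The attribute this PR switches to; it is absent on older releases.
print(hasattr(modeling_mixtral, "MIXTRAL_ATTENTION_CLASSES"))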
src/axolotl/monkeypatch/mixtral/__init__.py
CHANGED
@@ -17,6 +17,6 @@ def replace_mixtral_attn_with_multipack_flash_attn():
     transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
         mixtral_model_forward
     )
-    transformers.models.mixtral.modeling_mixtral.
+    transformers.models.mixtral.modeling_mixtral.MIXTRAL_ATTENTION_CLASSES[
         "flash_attention_2"
     ] = MixtralMultipackFlashAttention2
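In effect, the patch swaps which class the "flash_attention_2" key resolves to, so decoder layers built with that attention implementation pick up the multipack variant. A minimal sketch, assuming the pinned transformers revision above:

import transformers.models.mixtral.modeling_mixtral as modeling_mixtral

from axolotl.monkeypatch.mixtral import replace_mixtral_attn_with_multipack_flash_attn

replace_mixtral_attn_with_multipack_flash_attn()

# After patching, the class map entry points at the multipack attention class, so it is
# instantiated for every layer when _attn_implementation is "flash_attention_2".
print(modeling_mixtral.MIXTRAL_ATTENTION_CLASSES["flash_attention_2"].__name__)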
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
CHANGED
@@ -261,7 +261,11 @@ def mixtral_model_forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if
+        if (
+            attention_mask is not None
+            and self._attn_implementation == "flash_attention_2"
+            and use_cache
+        ):
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -270,7 +274,7 @@ def mixtral_model_forward(
                     " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                 )
 
-        if self.
+        if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = (
                 attention_mask
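The patched check mirrors the upstream requirement that flash attention 2 with a KV cache needs left padding. A small illustrative sketch of satisfying it (the tokenizer name follows the e2e tests; assigning a pad token is an assumption, since the Mixtral tokenizer ships without one):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
tokenizer.pad_token = tokenizer.eos_token  # assumption: reuse EOS as the pad token
tokenizer.padding_side = "left"  # avoids the padding-side ValueError raised above
batch = tokenizer(
    ["short prompt", "a noticeably longer prompt"],
    padding=True,
    return_tensors="pt",
)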
src/axolotl/utils/models.py
CHANGED
@@ -332,15 +332,18 @@ def load_model(
             or cfg.is_mistral_derived_model
             or model_config.model_type == "mixtral"
         ):
+            model_kwargs["attn_implementation"] = "flash_attention_2"
             model_config._attn_implementation = (  # pylint: disable=protected-access
                 "flash_attention_2"
             )
         else:
             if model_config.model_type == "mixtral":
+                model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "flash_attention_2"
                 )
             else:
+                model_kwargs["attn_implementation"] = "eager"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "eager"
                 )
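Setting both model_kwargs["attn_implementation"] and the private config attribute covers the newer transformers code path, which reads the keyword passed to from_pretrained, as well as code that still inspects the config. A minimal sketch of the call those kwargs ultimately feed (model name and dtype are illustrative, not axolotl's exact invocation):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/Mixtral-tiny",       # tiny checkpoint also used by the e2e tests
    attn_implementation="flash_attention_2",  # or "eager" where FA2 is unavailable
    torch_dtype=torch.bfloat16,               # FA2 expects fp16/bf16, per the commit notes
)
print(model.config._attn_implementation)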
tests/e2e/test_mixtral.py
ADDED
@@ -0,0 +1,109 @@
+"""
+E2E tests for mixtral
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMixtral(unittest.TestCase):
+    """
+    Test case for Llama models using LoRA
+    """
+
+    @with_temp_dir
+    def test_qlora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sequence_len": 1024,
+                "load_in_4bit": True,
+                "adapter": "qlora",
+                "lora_r": 16,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_ft(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sequence_len": 1024,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
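Both test modules lean on the with_temp_dir decorator from tests/e2e/utils.py, which is not part of this diff; a hypothetical sketch of such a helper, just to clarify how the temp_dir argument reaches each test, could look like:

import functools
import shutil
import tempfile


def with_temp_dir(test_func):
    # Hypothetical stand-in for tests/e2e/utils.with_temp_dir: create a scratch
    # directory, pass it to the test as temp_dir, and clean it up afterwards.
    @functools.wraps(test_func)
    def wrapper(*args, **kwargs):
        temp_dir = tempfile.mkdtemp()
        try:
            return test_func(*args, temp_dir, **kwargs)
        finally:
            shutil.rmtree(temp_dir)

    return wrapper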
tests/e2e/test_mixtral_samplepack.py
ADDED
@@ -0,0 +1,123 @@
+"""
+E2E tests for mixtral
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMixtral(unittest.TestCase):
+    """
+    Test case for Llama models using LoRA
+    """
+
+    @with_temp_dir
+    def test_qlora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sequence_len": 2048,
+                "load_in_4bit": True,
+                "adapter": "qlora",
+                "lora_r": 16,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+                "sample_packing": True,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()
+
+    @with_temp_dir
+    def test_ft(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+                "sample_packing": True,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            "axolotl.monkeypatch.mixtral.modeling_mixtral"
+            in model.model.layers[0].self_attn.__class__.__module__
+        )
+        assert (
+            "MixtralMultipackFlashAttention2"
+            in model.model.layers[0].self_attn.__class__.__name__
+        )
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()
tests/e2e/test_model_patches.py
ADDED
@@ -0,0 +1,99 @@
+"""
+E2E smoke tests to check that the monkeypatches are in place for certain configurations
+"""
+
+import unittest
+
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.models import load_model, load_tokenizer
+
+from .utils import with_temp_dir
+
+
+class TestModelPatches(unittest.TestCase):
+    """
+    TestCases for the multipack monkey patches
+    """
+
+    @with_temp_dir
+    def test_mixtral_multipack(self, temp_dir):
+        cfg = DictDefault(
+            {
+                "base_model": "hf-internal-testing/Mixtral-tiny",
+                "tokenizer_config": "mistralai/Mixtral-8x7B-v0.1",
+                "flash_attention": True,
+                "sample_packing": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        tokenizer = load_tokenizer(cfg)
+        model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
+
+        assert (
+            "axolotl.monkeypatch.mixtral.modeling_mixtral"
+            in model.model.layers[0].self_attn.__class__.__module__
+        )
+        assert (
+            "MixtralMultipackFlashAttention2"
+            in model.model.layers[0].self_attn.__class__.__name__
+        )
+
+    @with_temp_dir
+    def test_mistral_multipack(self, temp_dir):
+        cfg = DictDefault(
+            {
+                "base_model": "openaccess-ai-collective/tiny-mistral",
+                "flash_attention": True,
+                "sample_packing": True,
+                "sequence_len": 2048,
+                "val_set_size": 0.1,
+                "special_tokens": {},
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_bnb_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        tokenizer = load_tokenizer(cfg)
+        model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
+
+        assert (
+            "axolotl.monkeypatch.mistral_attn_hijack_flash"
+            in model.model.layers[0].self_attn.forward.__module__
+        )