simplify haldning for newer multipack patches so they can be added in a single place (#1270)
Browse files- src/axolotl/core/trainer_builder.py +2 -1
- src/axolotl/monkeypatch/falcon/__init__.py +0 -12
- src/axolotl/monkeypatch/mixtral/__init__.py +0 -11
- src/axolotl/monkeypatch/multipack.py +30 -0
- src/axolotl/monkeypatch/phi/__init__.py +0 -12
- src/axolotl/monkeypatch/qwen2/__init__.py +0 -12
- src/axolotl/utils/models.py +14 -40
src/axolotl/core/trainer_builder.py
CHANGED
@@ -28,6 +28,7 @@ from transformers import (
|
|
28 |
from transformers.trainer_utils import seed_worker
|
29 |
from trl import DPOTrainer
|
30 |
|
|
|
31 |
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
32 |
from axolotl.utils.callbacks import (
|
33 |
EvalFirstStepCallback,
|
@@ -994,7 +995,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
994 |
]
|
995 |
]
|
996 |
if use_batch_sampler_collator:
|
997 |
-
if self.cfg.model_config_type in
|
998 |
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
999 |
elif (
|
1000 |
self.cfg.model_config_type in ["llama"]
|
|
|
28 |
from transformers.trainer_utils import seed_worker
|
29 |
from trl import DPOTrainer
|
30 |
|
31 |
+
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
32 |
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
33 |
from axolotl.utils.callbacks import (
|
34 |
EvalFirstStepCallback,
|
|
|
995 |
]
|
996 |
]
|
997 |
if use_batch_sampler_collator:
|
998 |
+
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
999 |
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
1000 |
elif (
|
1001 |
self.cfg.model_config_type in ["llama"]
|
src/axolotl/monkeypatch/falcon/__init__.py
DELETED
@@ -1,12 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Patches to support multipack for falcon
|
3 |
-
"""
|
4 |
-
import transformers
|
5 |
-
|
6 |
-
from axolotl.monkeypatch.utils import get_unpad_data
|
7 |
-
|
8 |
-
|
9 |
-
def replace_falcon_attn_with_multipack_flash_attn():
|
10 |
-
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
11 |
-
get_unpad_data
|
12 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/axolotl/monkeypatch/mixtral/__init__.py
CHANGED
@@ -2,9 +2,6 @@
|
|
2 |
Patches to support multipack for mixtral
|
3 |
"""
|
4 |
import torch
|
5 |
-
import transformers
|
6 |
-
|
7 |
-
from axolotl.monkeypatch.utils import get_unpad_data
|
8 |
|
9 |
|
10 |
def patch_mixtral_moe_forward_zero3() -> None:
|
@@ -51,11 +48,3 @@ def patch_mixtral_moe_forward_zero3() -> None:
|
|
51 |
|
52 |
MixtralBLockSparseTop2MLP.forward = mlp_forward
|
53 |
MixtralSparseMoeBlock.forward = moe_forward
|
54 |
-
|
55 |
-
|
56 |
-
def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
|
57 |
-
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
58 |
-
get_unpad_data
|
59 |
-
)
|
60 |
-
if for_zero3:
|
61 |
-
patch_mixtral_moe_forward_zero3()
|
|
|
2 |
Patches to support multipack for mixtral
|
3 |
"""
|
4 |
import torch
|
|
|
|
|
|
|
5 |
|
6 |
|
7 |
def patch_mixtral_moe_forward_zero3() -> None:
|
|
|
48 |
|
49 |
MixtralBLockSparseTop2MLP.forward = mlp_forward
|
50 |
MixtralSparseMoeBlock.forward = moe_forward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/axolotl/monkeypatch/multipack.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""multipack patching for v2 of sample packing"""
|
2 |
+
|
3 |
+
import transformers
|
4 |
+
from transformers.integrations import is_deepspeed_zero3_enabled
|
5 |
+
|
6 |
+
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
7 |
+
from axolotl.monkeypatch.utils import get_unpad_data
|
8 |
+
|
9 |
+
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
|
10 |
+
|
11 |
+
|
12 |
+
def patch_for_multipack(model_type):
|
13 |
+
if model_type == "mixtral":
|
14 |
+
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
15 |
+
get_unpad_data
|
16 |
+
)
|
17 |
+
if is_deepspeed_zero3_enabled():
|
18 |
+
patch_mixtral_moe_forward_zero3()
|
19 |
+
elif model_type == "qwen2":
|
20 |
+
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
21 |
+
get_unpad_data
|
22 |
+
)
|
23 |
+
elif model_type == "falcon":
|
24 |
+
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
25 |
+
get_unpad_data
|
26 |
+
)
|
27 |
+
elif model_type == "phi":
|
28 |
+
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
29 |
+
get_unpad_data
|
30 |
+
)
|
src/axolotl/monkeypatch/phi/__init__.py
DELETED
@@ -1,12 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Patches to support multipack for phi2
|
3 |
-
"""
|
4 |
-
import transformers
|
5 |
-
|
6 |
-
from axolotl.monkeypatch.utils import get_unpad_data
|
7 |
-
|
8 |
-
|
9 |
-
def replace_phi_attn_with_multipack_flash_attn():
|
10 |
-
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
11 |
-
get_unpad_data
|
12 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/axolotl/monkeypatch/qwen2/__init__.py
DELETED
@@ -1,12 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
Patches to support multipack for qwen2
|
3 |
-
"""
|
4 |
-
import transformers
|
5 |
-
|
6 |
-
from axolotl.monkeypatch.utils import get_unpad_data
|
7 |
-
|
8 |
-
|
9 |
-
def replace_qwen2_attn_with_multipack_flash_attn():
|
10 |
-
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
11 |
-
get_unpad_data
|
12 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/axolotl/utils/models.py
CHANGED
@@ -29,6 +29,10 @@ from transformers import ( # noqa: F401
|
|
29 |
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
30 |
|
31 |
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
|
|
|
|
|
|
|
|
32 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
33 |
from axolotl.utils.bench import log_gpu_memory_usage
|
34 |
from axolotl.utils.chat_templates import chat_templates
|
@@ -299,8 +303,15 @@ def load_model(
|
|
299 |
shifted-sparse attention does not currently support sample packing."
|
300 |
)
|
301 |
|
302 |
-
|
303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
if cfg.flash_attention:
|
305 |
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
306 |
replace_llama_attn_with_flash_attn,
|
@@ -354,43 +365,6 @@ def load_model(
|
|
354 |
LOG.info("patching mistral with flash attention")
|
355 |
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
356 |
|
357 |
-
if (
|
358 |
-
cfg.model_config_type == "mixtral"
|
359 |
-
and cfg.flash_attention
|
360 |
-
and cfg.sample_packing
|
361 |
-
):
|
362 |
-
from axolotl.monkeypatch.mixtral import (
|
363 |
-
replace_mixtral_attn_with_multipack_flash_attn,
|
364 |
-
)
|
365 |
-
|
366 |
-
LOG.info("patching mixtral with flash attention")
|
367 |
-
mixtral_patch_kwargs = {}
|
368 |
-
if is_deepspeed_zero3_enabled():
|
369 |
-
mixtral_patch_kwargs["for_zero3"] = True
|
370 |
-
replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
|
371 |
-
|
372 |
-
if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
|
373 |
-
from axolotl.monkeypatch.falcon import (
|
374 |
-
replace_falcon_attn_with_multipack_flash_attn,
|
375 |
-
)
|
376 |
-
|
377 |
-
LOG.info("patching falcon with flash attention")
|
378 |
-
replace_falcon_attn_with_multipack_flash_attn()
|
379 |
-
|
380 |
-
if cfg.model_config_type == "phi" and cfg.flash_attention and cfg.sample_packing:
|
381 |
-
from axolotl.monkeypatch.phi import replace_phi_attn_with_multipack_flash_attn
|
382 |
-
|
383 |
-
LOG.info("patching phi with flash attention")
|
384 |
-
replace_phi_attn_with_multipack_flash_attn()
|
385 |
-
|
386 |
-
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
|
387 |
-
from axolotl.monkeypatch.qwen2 import (
|
388 |
-
replace_qwen2_attn_with_multipack_flash_attn,
|
389 |
-
)
|
390 |
-
|
391 |
-
LOG.info("patching qwen2 with flash attention")
|
392 |
-
replace_qwen2_attn_with_multipack_flash_attn()
|
393 |
-
|
394 |
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
395 |
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
396 |
|
@@ -501,7 +475,7 @@ def load_model(
|
|
501 |
"flash_attention_2"
|
502 |
)
|
503 |
else:
|
504 |
-
if model_config.model_type in
|
505 |
model_kwargs["attn_implementation"] = "flash_attention_2"
|
506 |
model_config._attn_implementation = ( # pylint: disable=protected-access
|
507 |
"flash_attention_2"
|
|
|
29 |
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
30 |
|
31 |
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
32 |
+
from axolotl.monkeypatch.multipack import (
|
33 |
+
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
34 |
+
patch_for_multipack,
|
35 |
+
)
|
36 |
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
37 |
from axolotl.utils.bench import log_gpu_memory_usage
|
38 |
from axolotl.utils.chat_templates import chat_templates
|
|
|
303 |
shifted-sparse attention does not currently support sample packing."
|
304 |
)
|
305 |
|
306 |
+
if (
|
307 |
+
cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
308 |
+
and cfg.flash_attention
|
309 |
+
and cfg.sample_packing
|
310 |
+
):
|
311 |
+
patch_for_multipack(cfg.model_config_type)
|
312 |
+
elif cfg.is_llama_derived_model:
|
313 |
+
# Modify all llama derived models in one block
|
314 |
+
|
315 |
if cfg.flash_attention:
|
316 |
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
317 |
replace_llama_attn_with_flash_attn,
|
|
|
365 |
LOG.info("patching mistral with flash attention")
|
366 |
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
367 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
369 |
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
370 |
|
|
|
475 |
"flash_attention_2"
|
476 |
)
|
477 |
else:
|
478 |
+
if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
479 |
model_kwargs["attn_implementation"] = "flash_attention_2"
|
480 |
model_config._attn_implementation = ( # pylint: disable=protected-access
|
481 |
"flash_attention_2"
|