winglian committed
Commit 5698943
1 Parent(s): 411293b

simplify handling for newer multipack patches so they can be added in a single place (#1270)
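In effect, the per-model replace_*_attn_with_multipack_flash_attn helpers are collapsed into a single patch_for_multipack(model_type) dispatcher keyed on SUPPORTED_MULTIPACK_MODEL_TYPES, both defined in the new src/axolotl/monkeypatch/multipack.py below. A minimal sketch of the resulting call pattern, mirroring the load_model() change further down; maybe_patch_for_multipack is a hypothetical wrapper used only for illustration:

from axolotl.monkeypatch.multipack import (
    SUPPORTED_MULTIPACK_MODEL_TYPES,
    patch_for_multipack,
)


def maybe_patch_for_multipack(cfg):
    # Hypothetical wrapper: the real guard lives inline in load_model() (see the diff below).
    if (
        cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
        and cfg.flash_attention
        and cfg.sample_packing
    ):
        patch_for_multipack(cfg.model_config_type)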

src/axolotl/core/trainer_builder.py CHANGED
@@ -28,6 +28,7 @@ from transformers import (
 from transformers.trainer_utils import seed_worker
 from trl import DPOTrainer
 
+from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
@@ -994,7 +995,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ]
         ]
         if use_batch_sampler_collator:
-            if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
+            if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
                 collator = V2BatchSamplerDataCollatorForSeq2Seq
             elif (
                 self.cfg.model_config_type in ["llama"]
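With the hard-coded ["mixtral", "qwen2", "falcon", "phi"] list gone, the collator choice in the trainer builder and the attention patching in load_model() are driven by the same constant. A rough sketch of the dispatch under that assumption; choose_collator is a hypothetical helper, the collator import path is assumed rather than shown in this diff, and the extra llama/flash-attention branch from the hunk above is elided:

from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.utils.collators import (  # import path assumed, not shown in this diff
    BatchSamplerDataCollatorForSeq2Seq,
    V2BatchSamplerDataCollatorForSeq2Seq,
)


def choose_collator(cfg, use_batch_sampler_collator):
    # Hypothetical helper condensing the branch in the hunk above.
    if use_batch_sampler_collator and cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
        return V2BatchSamplerDataCollatorForSeq2Seq
    return BatchSamplerDataCollatorForSeq2Seq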
src/axolotl/monkeypatch/falcon/__init__.py DELETED
@@ -1,12 +0,0 @@
-"""
-Patches to support multipack for falcon
-"""
-import transformers
-
-from axolotl.monkeypatch.utils import get_unpad_data
-
-
-def replace_falcon_attn_with_multipack_flash_attn():
-    transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
-        get_unpad_data
-    )
src/axolotl/monkeypatch/mixtral/__init__.py CHANGED
@@ -2,9 +2,6 @@
 Patches to support multipack for mixtral
 """
 import torch
-import transformers
-
-from axolotl.monkeypatch.utils import get_unpad_data
 
 
 def patch_mixtral_moe_forward_zero3() -> None:
@@ -51,11 +48,3 @@ def patch_mixtral_moe_forward_zero3() -> None:
 
     MixtralBLockSparseTop2MLP.forward = mlp_forward
     MixtralSparseMoeBlock.forward = moe_forward
-
-
-def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
-    transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
-        get_unpad_data
-    )
-    if for_zero3:
-        patch_mixtral_moe_forward_zero3()
src/axolotl/monkeypatch/multipack.py ADDED
@@ -0,0 +1,30 @@
+"""multipack patching for v2 of sample packing"""
+
+import transformers
+from transformers.integrations import is_deepspeed_zero3_enabled
+
+from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
+from axolotl.monkeypatch.utils import get_unpad_data
+
+SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
+
+
+def patch_for_multipack(model_type):
+    if model_type == "mixtral":
+        transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+        if is_deepspeed_zero3_enabled():
+            patch_mixtral_moe_forward_zero3()
+    elif model_type == "qwen2":
+        transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "falcon":
+        transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "phi":
+        transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
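The point of the new module is that future multipack support touches exactly one file. A hypothetical sketch of what registering one more architecture would look like; "newmodel" and transformers.models.newmodel are placeholders, not a real integration:

import transformers

from axolotl.monkeypatch.utils import get_unpad_data

# Hypothetical: "newmodel" is a placeholder architecture, shown only to illustrate
# that adding multipack support is now (1) one list entry plus (2) one dispatch branch.
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "newmodel"]


def patch_for_multipack(model_type):
    if model_type == "newmodel":  # placeholder branch alongside the existing ones
        transformers.models.newmodel.modeling_newmodel._get_unpad_data = (  # pylint: disable=protected-access
            get_unpad_data
        )
    # ...the existing mixtral / qwen2 / falcon / phi branches stay unchanged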
src/axolotl/monkeypatch/phi/__init__.py DELETED
@@ -1,12 +0,0 @@
-"""
-Patches to support multipack for phi2
-"""
-import transformers
-
-from axolotl.monkeypatch.utils import get_unpad_data
-
-
-def replace_phi_attn_with_multipack_flash_attn():
-    transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
-        get_unpad_data
-    )
src/axolotl/monkeypatch/qwen2/__init__.py DELETED
@@ -1,12 +0,0 @@
-"""
-Patches to support multipack for qwen2
-"""
-import transformers
-
-from axolotl.monkeypatch.utils import get_unpad_data
-
-
-def replace_qwen2_attn_with_multipack_flash_attn():
-    transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
-        get_unpad_data
-    )
src/axolotl/utils/models.py CHANGED
@@ -29,6 +29,10 @@ from transformers import (  # noqa: F401
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.models.mamba import fix_mamba_attn_for_loss
+from axolotl.monkeypatch.multipack import (
+    SUPPORTED_MULTIPACK_MODEL_TYPES,
+    patch_for_multipack,
+)
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import chat_templates
@@ -299,8 +303,15 @@ def load_model(
             shifted-sparse attention does not currently support sample packing."
         )
 
-    # Modify all llama derived models in one block
-    if cfg.is_llama_derived_model:
+    if (
+        cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
+        and cfg.flash_attention
+        and cfg.sample_packing
+    ):
+        patch_for_multipack(cfg.model_config_type)
+    elif cfg.is_llama_derived_model:
+        # Modify all llama derived models in one block
+
         if cfg.flash_attention:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 replace_llama_attn_with_flash_attn,
@@ -354,43 +365,6 @@ def load_model(
         LOG.info("patching mistral with flash attention")
         replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
-    if (
-        cfg.model_config_type == "mixtral"
-        and cfg.flash_attention
-        and cfg.sample_packing
-    ):
-        from axolotl.monkeypatch.mixtral import (
-            replace_mixtral_attn_with_multipack_flash_attn,
-        )
-
-        LOG.info("patching mixtral with flash attention")
-        mixtral_patch_kwargs = {}
-        if is_deepspeed_zero3_enabled():
-            mixtral_patch_kwargs["for_zero3"] = True
-        replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
-
-    if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
-        from axolotl.monkeypatch.falcon import (
-            replace_falcon_attn_with_multipack_flash_attn,
-        )
-
-        LOG.info("patching falcon with flash attention")
-        replace_falcon_attn_with_multipack_flash_attn()
-
-    if cfg.model_config_type == "phi" and cfg.flash_attention and cfg.sample_packing:
-        from axolotl.monkeypatch.phi import replace_phi_attn_with_multipack_flash_attn
-
-        LOG.info("patching phi with flash attention")
-        replace_phi_attn_with_multipack_flash_attn()
-
-    if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
-        from axolotl.monkeypatch.qwen2 import (
-            replace_qwen2_attn_with_multipack_flash_attn,
-        )
-
-        LOG.info("patching qwen2 with flash attention")
-        replace_qwen2_attn_with_multipack_flash_attn()
-
     if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
 
@@ -501,7 +475,7 @@ def load_model(
                 "flash_attention_2"
             )
         else:
-            if model_config.model_type in ["mixtral", "qwen2", "falcon", "phi"]:
+            if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "flash_attention_2"