winglian commited on
Commit
8df7b88
1 Parent(s): 6366b0c

beta support for multipack with gemmoe: (#1402)

Browse files
src/axolotl/monkeypatch/multipack.py CHANGED
@@ -1,6 +1,9 @@
1
  """multipack patching for v2 of sample packing"""
 
2
 
3
  import transformers
 
 
4
  from transformers.integrations import is_deepspeed_zero3_enabled
5
 
6
  from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
@@ -12,11 +15,12 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
12
  "falcon",
13
  "phi",
14
  "gemma",
 
15
  "starcoder2",
16
  ]
17
 
18
 
19
- def patch_for_multipack(model_type):
20
  if model_type == "mixtral":
21
  transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
22
  get_unpad_data
@@ -43,3 +47,15 @@ def patch_for_multipack(model_type):
43
  transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
44
  get_unpad_data
45
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """multipack patching for v2 of sample packing"""
2
+ import importlib
3
 
4
  import transformers
5
+ from accelerate import init_empty_weights
6
+ from transformers import AutoConfig, AutoModelForCausalLM
7
  from transformers.integrations import is_deepspeed_zero3_enabled
8
 
9
  from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
 
15
  "falcon",
16
  "phi",
17
  "gemma",
18
+ "gemmoe",
19
  "starcoder2",
20
  ]
21
 
22
 
23
+ def patch_for_multipack(model_type, model_name=None):
24
  if model_type == "mixtral":
25
  transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
26
  get_unpad_data
 
47
  transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
48
  get_unpad_data
49
  )
50
+ elif model_type == "gemmoe":
51
+ model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
52
+ # we need to load the model here in order for modeling_gemmoe to be available
53
+ with init_empty_weights():
54
+ AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
55
+ module_name = model_config.__class__.__module__.replace(
56
+ ".configuration_gemmoe", ".modeling_gemmoe"
57
+ )
58
+ modeling_gemmoe = importlib.import_module(module_name)
59
+ modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access
60
+ get_unpad_data
61
+ )
src/axolotl/utils/models.py CHANGED
@@ -429,7 +429,7 @@ def load_model(
429
  and cfg.flash_attention
430
  and cfg.sample_packing
431
  ):
432
- patch_for_multipack(cfg.model_config_type)
433
  elif cfg.is_llama_derived_model:
434
  # Modify all llama derived models in one block
435
 
 
429
  and cfg.flash_attention
430
  and cfg.sample_packing
431
  ):
432
+ patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
433
  elif cfg.is_llama_derived_model:
434
  # Modify all llama derived models in one block
435