|
"""multipack patching for v2 of sample packing""" |
|
|
|
import transformers |
|
from transformers.integrations import is_deepspeed_zero3_enabled |
|
|
|
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3 |
|
from axolotl.monkeypatch.utils import get_unpad_data |
|
|
|
# Model-type identifiers for which ``patch_for_multipack`` swaps the
# transformers ``_get_unpad_data`` helper for axolotl's multipack-aware
# ``get_unpad_data``.
SUPPORTED_MULTIPACK_MODEL_TYPES = [
    "mixtral", "qwen2", "falcon", "phi", "gemma", "starcoder2",
]
|
|
|
|
|
def patch_for_multipack(model_type):
    """Monkeypatch transformers so ``model_type`` works with multipack packing.

    For every type listed in ``SUPPORTED_MULTIPACK_MODEL_TYPES`` this replaces
    ``transformers.models.<model_type>.modeling_<model_type>._get_unpad_data``
    with axolotl's ``get_unpad_data``. For ``"mixtral"`` it additionally calls
    ``patch_mixtral_moe_forward_zero3()`` when DeepSpeed ZeRO-3 is enabled.

    The patchable set is read from ``SUPPORTED_MULTIPACK_MODEL_TYPES`` so the
    advertised list and the actual patching cannot drift apart.

    Args:
        model_type: model type identifier such as ``"mixtral"`` or ``"qwen2"``.

    Model types not in ``SUPPORTED_MULTIPACK_MODEL_TYPES`` are ignored (no-op).
    """
    if model_type not in SUPPORTED_MULTIPACK_MODEL_TYPES:
        return

    # All supported types follow the transformers layout
    # ``transformers.models.<type>.modeling_<type>``; resolve it through the
    # same attribute path as ``transformers.models.mixtral.modeling_mixtral``.
    modeling_module = getattr(
        getattr(transformers.models, model_type), f"modeling_{model_type}"
    )
    modeling_module._get_unpad_data = get_unpad_data

    # Mixtral's MoE forward needs an extra patch under DeepSpeed ZeRO-3.
    if model_type == "mixtral" and is_deepspeed_zero3_enabled():
        patch_mixtral_moe_forward_zero3()
|
|