File size: 298 Bytes
814aee6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
"""
Patches to support multipack for phi2
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_phi_attn_with_multipack_flash_attn():
    """Swap Phi's unpadding helper for axolotl's multipack-aware version.

    Rebinds the module-level name ``_get_unpad_data`` inside
    ``transformers.models.phi.modeling_phi`` to
    ``axolotl.monkeypatch.utils.get_unpad_data``. Only code that looks the
    name up in that module at call time sees the replacement. Any reference
    captured earlier (e.g. ``from ... import _get_unpad_data``) is unaffected.
    """
    phi_modeling = transformers.models.phi.modeling_phi
    # Overwriting a private transformers helper is the whole point of this
    # patch, hence the protected-access suppression.
    phi_modeling._get_unpad_data = (  # pylint: disable=protected-access
        get_unpad_data
    )
|