File size: 298 Bytes
814aee6
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
"""
Patches to support multipack for phi2
"""
import transformers

from axolotl.monkeypatch.utils import get_unpad_data


def replace_phi_attn_with_multipack_flash_attn():
    """Point phi's private ``_get_unpad_data`` helper at axolotl's version.

    Rebinds the module-level ``_get_unpad_data`` attribute of
    ``transformers.models.phi.modeling_phi`` to ``get_unpad_data`` from
    ``axolotl.monkeypatch.utils``. This is the replacement used for multipack
    support. The patch mutates the imported module object, so it is
    process-wide. It only affects code that looks the helper up on that
    module at call time.
    """
    phi_modeling = transformers.models.phi.modeling_phi
    # NOTE(review): this overrides a private transformers helper; confirm on
    # transformers upgrades that phi still resolves it from this module.
    phi_modeling._get_unpad_data = get_unpad_data  # pylint: disable=protected-access