"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
from axolotl.monkeypatch.utils import (
patched_prepare_4d_causal_attention_mask,
patched_prepare_4d_causal_attention_mask_for_sdpa,
)
def hijack_llama_prepare_4d_mask():
import transformers.modeling_attn_mask_utils
import transformers.models.llama.modeling_llama
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)
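
# Usage sketch (not part of this file; a hedged example of how the hijack is meant
# to be applied): call hijack_llama_prepare_4d_mask() before any forward pass runs,
# so the mask-building calls resolve to the patched functions. `model_name` below
# is a placeholder.
#
#     from transformers import AutoModelForCausalLM
#
#     hijack_llama_prepare_4d_mask()
#     model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="sdpa")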