""" | |
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention | |
""" | |
from axolotl.monkeypatch.utils import (
    patched_prepare_4d_causal_attention_mask,
    patched_prepare_4d_causal_attention_mask_for_sdpa,
)


def hijack_llama_prepare_4d_mask():
    """Swap transformers' 4D causal mask builders for the patched versions.

    Both the SDPA and eager mask builders are overridden, in two places each:
    the copy imported into modeling_llama and the original definition in
    modeling_attn_mask_utils.
    """
    import transformers.modeling_attn_mask_utils
    import transformers.models.llama.modeling_llama

    # SDPA path, used when the model runs with attn_implementation="sdpa"
    transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
        patched_prepare_4d_causal_attention_mask_for_sdpa
    )
    transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (  # pylint: disable=protected-access
        patched_prepare_4d_causal_attention_mask_for_sdpa
    )
    # Eager-attention path, patched at the same two locations
    transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
        patched_prepare_4d_causal_attention_mask
    )
    transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (  # pylint: disable=protected-access
        patched_prepare_4d_causal_attention_mask
    )
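

# Usage sketch (added for illustration, not part of the original module): the
# hijack only needs to run before a forward pass, because modeling_llama looks
# these helpers up as module attributes at call time rather than binding them
# at import. The checkpoint name below is an assumption for demonstration;
# substitute any Llama checkpoint you have access to.
if __name__ == "__main__":
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    hijack_llama_prepare_4d_mask()

    model_name = "NousResearch/Llama-2-7b-hf"  # illustrative checkpoint (assumption)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, attn_implementation="sdpa"
    )
    inputs = tokenizer("hello world", return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)  # masks now built by the patched functions
    print(outputs.logits.shape)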