""" Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention """ from axolotl.monkeypatch.utils import ( patched_prepare_4d_causal_attention_mask, patched_prepare_4d_causal_attention_mask_for_sdpa, ) def hijack_llama_prepare_4d_mask(): import transformers.modeling_attn_mask_utils import transformers.models.llama.modeling_llama transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask_for_sdpa ) transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask_for_sdpa ) transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask ) transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access patched_prepare_4d_causal_attention_mask )