""" expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf """ from typing import Optional import torch from axolotl.monkeypatch.utils import mask_2d_to_4d def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len) inverted_mask = 1.0 - masked_zero_one_mask return inverted_mask.masked_fill( inverted_mask.to(torch.bool), torch.finfo(dtype).min ) def hijack_expand_mask(): import transformers transformers.models.llama.modeling_llama._expand_mask = ( # pylint: disable=protected-access _expand_mask )