""" | |
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf | |
""" | |
from typing import Optional

import torch

from axolotl.monkeypatch.utils import mask_2d_to_4d


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """Expand a 2D (batch, seq_len) mask into the additive 4D mask used by attention."""
    # ``mask_2d_to_4d`` broadcasts the 2D mask into a 4D zero/one mask,
    # where 1.0 marks positions a query may attend to.
    masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
    # Invert so that positions to ignore become 1.0, then fill those positions
    # with the most negative representable value so they vanish after softmax.
    inverted_mask = 1.0 - masked_zero_one_mask

    return inverted_mask.masked_fill(
        inverted_mask.to(torch.bool), torch.finfo(dtype).min
    )
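

# Worked example (illustrative only, assuming ``mask_2d_to_4d`` marks attendable
# positions with 1.0 and masked positions with 0.0): a row of [1.0, 1.0, 0.0] is
# inverted to [0.0, 0.0, 1.0] and then filled to [0.0, 0.0, finfo(dtype).min],
# i.e. an additive bias that leaves allowed positions untouched and drives
# masked attention logits toward -inf before the softmax.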


def hijack_expand_mask():
    """Patch transformers' Llama mask expansion with the packing-aware version above."""
    import transformers  # local import so transformers is only pulled in when the patch is applied

    transformers.models.llama.modeling_llama._expand_mask = (  # pylint: disable=protected-access
        _expand_mask
    )
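

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the patch): apply
    # the hijack and confirm the module-level helper now points at the
    # replacement defined above. This assumes a transformers release that still
    # defines ``_expand_mask`` inside ``modeling_llama``; newer versions may
    # have refactored their mask helpers.
    import transformers

    hijack_expand_mask()
    assert (
        transformers.models.llama.modeling_llama._expand_mask  # pylint: disable=protected-access
        is _expand_mask
    )
    print("llama _expand_mask patched successfully")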