"""Shared utils for the monkeypatches."""
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.utils import is_torch_bf16_gpu_available


@torch.jit.script
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
    """
    Return the length of every packed sequence in ``attention_mask``.

    Despite the name, this returns all per-sequence lengths, not only the
    maximum. For each row and each sequence id ``1..max(attention_mask)`` the
    number of matching tokens is counted; the counts are flattened row-major
    and zero counts (ids absent from a row) are dropped.
    """
    max_num = int(torch.max(attention_mask).item())
    batch_size, _ = attention_mask.shape
    # NOTE(review): allocated without an explicit device, so it may live on a
    # different device than attention_mask; get_unpad_data moves the derived
    # cu_seqlens back to the mask's device.
    counts = torch.zeros((batch_size, max_num), dtype=torch.int32)

    for i in range(1, max_num + 1):
        mask = attention_mask == i
        counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)

    result = counts.flatten()
    nonzero_indices = torch.nonzero(result).squeeze(-1)
    return result[nonzero_indices]


@torch.jit.script
def get_unpad_data(attention_mask: torch.Tensor):
    """
    Compute unpadding metadata for a packed attention mask.

    Returns a tuple of:
      - indices: positions of the non-zero entries of the flattened mask,
      - cu_seqlens: int32 cumulative sequence lengths with a leading 0,
        placed on the mask's device,
      - max_seqlen_in_batch: the longest packed sequence length, as a number.
    """
    device = attention_mask.device
    seqlens_in_batch = get_max_seqlen_in_batch(attention_mask)
    indices = torch.nonzero(attention_mask.flatten()).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = (
        F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        .to(device=device)
        .detach()
    )
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def _pad_and_stack_cu_seqlens(cu_seqlens_list: List[torch.Tensor]) -> torch.Tensor:
    """Right-pad per-row cu_seqlens to a common length and stack them as int32."""
    # Pad with the largest boundary so the padded entries describe
    # zero-length segments instead of introducing new boundaries.
    max_value = max(t.max() for t in cu_seqlens_list).item()
    max_length = max(t.size(0) for t in cu_seqlens_list)
    padded_results = [
        F.pad(t, (0, max_length - t.size(0)), "constant", max_value)
        for t in cu_seqlens_list
    ]
    return torch.stack(padded_results).to(dtype=torch.int32)


def get_cu_seqlens(attn_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate cumulative sequence lengths for flash attention from an attn mask.

    Each row labels the tokens of one packed sequence with a shared non-zero
    value; zeros are treated as padding. Any zeros are counted and, if
    present, appended as one extra trailing segment, so each row's
    cu_seqlens ends at the row length (this presumes right padding).

    Args:
        attn_mask: 1D (single row) or 2D (batch of rows) mask.

    Returns:
        A tuple ``(cu_seqlens, max_seq_lens)``:
          - cu_seqlens: int32 tensor of shape ``[batch, max_boundaries]``;
            rows with fewer segments are right-padded with the largest
            boundary value so rows with different segment counts can be stacked.
          - max_seq_lens: tensor of shape ``[batch]`` with each row's longest
            segment.
    """
    if len(attn_mask.shape) == 1:
        attn_mask = attn_mask.unsqueeze(0)

    device = attn_mask.device
    results = []
    max_seq_lens = []

    for row in attn_mask:
        # Exclude zeros to avoid adding their positions to the mask
        t_non_zeros = row[row != 0]
        # Flag the first position and every position where the sequence id changes
        seq_change = torch.cat(
            [
                torch.ones(1, dtype=torch.bool, device=device),
                t_non_zeros[1:] != t_non_zeros[:-1],
            ]
        )
        # Segment start indices, closed by the number of non-padding tokens
        change_indices = torch.cat(
            [
                seq_change.nonzero(as_tuple=True)[0],
                torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
            ]
        )
        seq_lengths = change_indices[1:] - change_indices[:-1]

        # Remaining positions (the zeros) form one extra segment so the
        # boundaries span the whole row.
        final_seq_length = int((len(row) - change_indices[-1]).item())
        if final_seq_length:
            seq_lengths = torch.cat(
                [
                    seq_lengths,
                    torch.tensor([final_seq_length], dtype=torch.int32, device=device),
                ]
            )

        cu_seqlens = torch.cat(
            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
        )
        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        results.append(cu_seqlens)
        max_seq_lens.append(max_seq_len)

    # Rows may contain different numbers of segments; pad before stacking.
    return _pad_and_stack_cu_seqlens(results), torch.stack(max_seq_lens)


def get_cu_seqlens_from_pos_ids(
    position_ids: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate cumulative sequence lengths for flash attention from position ids.

    A new packed sequence starts wherever the position id resets to 0. The
    run of zeros at the end of a row is treated as padding and, if present,
    closed with a final boundary at the row length.

    Args:
        position_ids: 1D (single row) or 2D (batch of rows) position ids.

    Returns:
        A tuple ``(cu_seqlens, max_seq_lens)``:
          - cu_seqlens: int32 tensor of shape ``[batch, max_boundaries]``,
            right-padded with the largest boundary value.
          - max_seq_lens: tensor of shape ``[batch]`` with each row's longest
            segment.
    """
    if len(position_ids.shape) == 1:
        position_ids = position_ids.unsqueeze(0)

    device = position_ids.device
    results = []
    max_seq_lens = []

    for row in position_ids:
        # Count consecutive zeros from the right (trailing padding)
        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
        adjusted_row = row[:-padding_length] if padding_length else row

        # A position reset to 0 marks the start of a new sequence
        seq_starts = torch.cat(
            [
                torch.tensor([True], dtype=torch.bool, device=device),
                adjusted_row[1:] == 0,
            ]
        )
        start_indices = torch.cat(
            [
                seq_starts.nonzero(as_tuple=True)[0],
                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
            ]
        )
        seq_lengths = start_indices[1:] - start_indices[:-1]

        cu_seqlens = torch.cat(
            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
        )
        # Close the padding segment at the full row length
        if padding_length:
            cu_seqlens = torch.cat(
                [cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
            )
        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        results.append(cu_seqlens)
        max_seq_lens.append(max_seq_len)

    return _pad_and_stack_cu_seqlens(results), torch.stack(max_seq_lens)


def set_module_name(model, name, value):
    """
    Set the attribute addressed by the dotted path ``name`` on ``model``.

    The parent is resolved with ``model.get_submodule`` on everything before
    the last dot; without a dot the attribute is set on ``model`` itself.
    """
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, value)


def mask_2d_to_4d(
    mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
    """
    Expand attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.

    This expansion handles packed sequences: positions attend to each other
    only when they share the same non-zero mask value, so each packed
    sequence forms its own block. Zeros (padding) attend to nothing. The
    result is then made lower triangular to prevent attending to future
    positions. Values are 1/0 in ``dtype``.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    mask = mask.unsqueeze(1).unsqueeze(2)
    mask = mask.expand(bsz, 1, tgt_len, src_len)

    # 1 where the original mask is non-zero, 0 where it is zero
    binary_mask = (mask != 0).to(dtype)

    # Block-diagonal mask: 1 where query and key carry the same mask value.
    # Multiplying by binary_mask excludes zeros in the original mask.
    # NOTE(review): comparing [bsz,1,tgt,src] against its transpose
    # [bsz,1,src,tgt] only broadcasts when tgt_len == src_len (or one is 1).
    zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask

    # Zero out the upper triangle so no position sees later positions
    lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
        mask.device
    )
    return zero_one_mask * lower_triangular_ones


def _expand_packed_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Expand a packed 2D mask with mask_2d_to_4d; pass ``None`` through unchanged."""
    if attention_mask is None:
        # The wrapped HF helpers accept a missing mask; don't crash on .size()
        return None
    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
    return mask_2d_to_4d(attention_mask, dtype=dtype)


def patched_prepare_4d_causal_attention_mask(
    attention_mask: Optional[torch.Tensor],
    *args,
):
    """Packed-sequence aware drop-in for ``_prepare_4d_causal_attention_mask``."""
    return _prepare_4d_causal_attention_mask(
        _expand_packed_mask(attention_mask),
        *args,
    )


def patched_prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask: Optional[torch.Tensor],
    *args,
):
    """Packed-sequence aware drop-in for ``_prepare_4d_causal_attention_mask_for_sdpa``."""
    return _prepare_4d_causal_attention_mask_for_sdpa(
        _expand_packed_mask(attention_mask),
        *args,
    )