"""
Wrap torch's flex attention and handle mess info or potentially refactor
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
    flex_attention_available = True
except ImportError:
    print(f"[Warning] flex attention needs PyTorch 2.5.0+ but your version is {torch.__version__}")
    flex_attention_available = False
def _causal_mask(b, h, q_idx, kv_idx):
    # Standard causal mask: a query position may attend to itself and to earlier positions.
    return q_idx >= kv_idx
def _length_to_offsets(lengths, device):
    """Converts a list of lengths to a tensor of cumulative offsets.

    Args:
        lengths: A list of lengths.
        device: Device on which to create the offsets tensor.
    """
    offsets = [0]
    offsets.extend(lengths)
    offsets = torch.tensor(offsets, device=device, dtype=torch.int32)
    offsets = torch.cumsum(offsets, dim=-1)
    return offsets
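# Worked example (illustrative values, not part of the original module):
# three stacked sequences of lengths [2, 4, 3] yield offsets [0, 2, 6, 9],
# i.e. sequence i occupies positions offsets[i] .. offsets[i + 1] - 1.
#   _length_to_offsets([2, 4, 3], device='cpu')
#   -> a tensor containing [0, 2, 6, 9]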
def _generate_var_mask_mod(offsets):
    """Generates mask mods that apply to inputs to flex attention in the sequence stacked
    format.

    Args:
        offsets: This tensor should be of shape (num_documents + 1,) and contain the
            cumulative counts of document tokens, e.g. if you have 3 documents of
            length 2, 4, 3 then offsets = [0, 2, 6, 9].

    Note:
        What is the sequence stacked format? When assembling batches of inputs, we
        take multiple sequences and stack them together to form 1 large sequence. We then
        use masking to ensure that the attention scores are only applied to tokens within
        the same document.
    """
    def _offsets_to_doc_ids_tensor(offsets):
        device = offsets.device
        counts = offsets[1:] - offsets[:-1]
        return torch.repeat_interleave(
            torch.arange(len(counts), device=device, dtype=torch.int32), counts
        )

    document_id = _offsets_to_doc_ids_tensor(offsets)

    def var_mask_mod(b, h, q_idx, kv_idx):
        # Allow full (bidirectional) attention within the same document/scale,
        # plus causal attention to all earlier positions across documents.
        same_doc = document_id[q_idx] == document_id[kv_idx]
        causal_mask = _causal_mask(b, h, q_idx, kv_idx)
        return same_doc | causal_mask

    return var_mask_mod
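# Usage sketch (assumed values, not from the original file): with scales of
# lengths 2, 4 and 3 stacked into one 9-token sequence, a query in the second
# scale (positions 2-5) may attend bidirectionally within that scale and
# causally to everything before it.
#   offsets = _length_to_offsets([2, 4, 3], device='cpu')
#   mask_mod = _generate_var_mask_mod(offsets)
#   mask_mod(0, 0, torch.tensor(3), torch.tensor(5))  # True: same scale, later position allowed
#   mask_mod(0, 0, torch.tensor(1), torch.tensor(5))  # False: different scale and not causal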
def _generate_var_infer_mask_with_kv_cache(lengths):
    # During KV-cache inference every query attends to the full cached sequence,
    # so the mask only bounds kv_idx by the total (unpadded) KV length.
    kv_len = sum(lengths)

    def var_mask_mod(b, h, q_idx, kv_idx):
        return kv_idx < kv_len

    return var_mask_mod
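# Sketch (illustrative lengths, not part of the original module): with cached
# scales of lengths [2, 4, 3] the real KV length is 9, so positions past it
# (e.g. 128-alignment padding) are masked out.
#   mask_mod = _generate_var_infer_mask_with_kv_cache([2, 4, 3])
#   mask_mod(0, 0, 0, 8)  # True  (8 < 9)
#   mask_mod(0, 0, 0, 9)  # False (beyond the real KV length)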
class FlexAttn(nn.Module):
    def __init__(
        self, block_scales: list, mask_type: str, B, H, L: int, auto_padding=False
    ):
        """
        :param block_scales: VAR-style (t, h, w) block sizes, e.g. [(1, 1, 1), (1, 2, 2), (1, 3, 3)]
        :param mask_type: one of "var", "causal", "var_infer_mask_with_kv_cache"
        :param B: batch size
        :param H: number of heads
        :param L: sequence length
        :param auto_padding: if True, pad q/k/v along the sequence dim to a multiple of 128 in forward()
        """
        super().__init__()
        if not flex_attention_available:
            raise NotImplementedError(f"[Error] flex attention needs PyTorch 2.5.0+ but your version is {torch.__version__}")
        self.support_mask_type = ["var", "causal", "var_infer_mask_with_kv_cache"]
        self.auto_padding = auto_padding
        self.flex_attention = torch.compile(flex_attention)
        self.block_scales = block_scales
        self.lengths = [x * y * z for x, y, z in block_scales]
        self.offsets = _length_to_offsets(self.lengths, device='cuda')
        # If L is padded to align to 128, the block mask needs to cover the padding area as well.
        if self.offsets[-1] < L:
            self.offsets = torch.cat((self.offsets, torch.tensor([L], device='cuda')), dim=0)
        if mask_type == "var":
            self.mask_mod = _generate_var_mask_mod(self.offsets)
        elif mask_type == "causal":
            self.mask_mod = _causal_mask
        elif mask_type == "var_infer_mask_with_kv_cache":
            self.mask_mod = _generate_var_infer_mask_with_kv_cache(self.lengths)
        else:
            raise NotImplementedError(f"{mask_type} not supported in FlexAttn, supported types: {self.support_mask_type}")
        self.block_mask = create_block_mask(self.mask_mod, B=B, H=H, Q_LEN=L, KV_LEN=L, device='cuda', _compile=True)
    def forward(self, q, k, v, scale=None):
        # q/k/v: (B, H, L, head_dim). With auto_padding, pad the sequence dim up to a
        # multiple of 128 so it matches the precomputed block mask, then strip the
        # padding from the output.
        if self.auto_padding:
            q_pad_len = (128 - q.shape[-2] % 128) % 128
            kv_pad_len = (128 - k.shape[-2] % 128) % 128
            q_pad = F.pad(q, (0, 0, 0, q_pad_len))
            k_pad = F.pad(k, (0, 0, 0, kv_pad_len))
            v_pad = F.pad(v, (0, 0, 0, kv_pad_len))
            oup = self.flex_attention(q_pad.to(v_pad.dtype), k_pad.to(v_pad.dtype), v_pad, block_mask=self.block_mask, scale=scale)
            if q_pad_len > 0:
                oup = oup[:, :, :-q_pad_len]
        else:
            oup = self.flex_attention(q.to(v.dtype), k.to(v.dtype), v, block_mask=self.block_mask, scale=scale)
        return oup
    def extra_repr(self) -> str:
        return f'block size:{self.block_scales}'
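

if __name__ == "__main__":
    # Minimal usage sketch (assumed shapes and scale list, not from the original
    # module): build a VAR-style block mask for three stacked scales and run one
    # attention call. Requires a CUDA device and PyTorch 2.5+.
    block_scales = [(1, 1, 1), (1, 2, 2), (1, 3, 3)]        # hypothetical (t, h, w) scales
    seq_len = sum(x * y * z for x, y, z in block_scales)    # 1 + 4 + 9 = 14 tokens
    L = ((seq_len + 127) // 128) * 128                      # mask length padded up to a multiple of 128
    B, H, head_dim = 2, 8, 64
    attn = FlexAttn(block_scales, mask_type="var", B=B, H=H, L=L, auto_padding=True)
    q = torch.randn(B, H, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, H, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(B, H, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    out = attn(q, k, v)  # forward pads q/k/v to L, applies the block mask, strips the padding
    print(out.shape)     # expected: torch.Size([2, 8, 14, 64])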