"""
Shared utils for the monkeypatches
"""
import torch
import torch.nn.functional as F


@torch.jit.script
def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
    """Return the length of every packed sequence in `attention_mask`.

    The mask is expected to label each packed sequence with a distinct
    positive integer (1, 2, ...), with 0 marking padding.
    """
    max_num = int(torch.max(attention_mask).item())
    batch_size, _ = attention_mask.shape
    # Allocate the counts on the same device as the mask; a plain
    # torch.zeros(...) would live on the CPU and the assignment below
    # would fail for CUDA inputs.
    counts = torch.zeros(
        (batch_size, max_num), dtype=torch.int32, device=attention_mask.device
    )

    # Count how many positions carry each sequence label.
    for i in range(1, max_num + 1):
        mask = attention_mask == i
        counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)

    # Drop the zero counts left by rows that pack fewer than max_num sequences.
    result = counts.flatten()
    nonzero_indices = torch.nonzero(result).squeeze(-1)
    return result[nonzero_indices]
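
# Illustrative example of the helper above (values worked out by hand, not
# taken from a test suite): for a mask packing two sequences into one row,
#   >>> get_max_seqlen_in_batch(torch.tensor([[1, 1, 2, 2, 2, 0]]))
#   tensor([2, 3], dtype=torch.int32)
# Note that, despite its name, the function returns the length of every
# packed sequence, not just the maximum.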


@torch.jit.script
def get_unpad_data(attention_mask: torch.Tensor):
    """Build the (indices, cu_seqlens, max_seqlen) triple that flash-attn's
    varlen (unpadded) kernels expect."""
    device = attention_mask.device
    seqlens_in_batch = get_max_seqlen_in_batch(attention_mask)
    # Flat indices of every non-padding token in the batch.
    indices = torch.nonzero(attention_mask.flatten()).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    # Cumulative sequence lengths with a leading 0, e.g. [2, 3] -> [0, 2, 5].
    cu_seqlens = (
        F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        .to(device=device)
        .detach()
    )
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
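
# Illustrative example (hand-computed): with the same mask as above,
#   >>> indices, cu_seqlens, max_seqlen = get_unpad_data(
#   ...     torch.tensor([[1, 1, 2, 2, 2, 0]])
#   ... )
# indices is tensor([0, 1, 2, 3, 4]) (the five non-padding positions),
# cu_seqlens is tensor([0, 2, 5], dtype=torch.int32), and max_seqlen is 3.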


def get_cu_seqlens(attn_mask):
    """Generate per-row cumulative sequence lengths for flash attention from a
    packed attention mask."""
    if len(attn_mask.shape) == 1:
        attn_mask = attn_mask.unsqueeze(0)

    device = attn_mask.device
    results = []
    max_seq_lens = []

    for row in attn_mask:
        # Keep only non-padding positions; each packed sequence carries its
        # own integer label.
        t_non_zeros = row[row != 0]
        # 1 wherever a new sequence starts (label differs from the previous one).
        seq_change = torch.cat(
            [
                torch.tensor([1], dtype=torch.int32, device=device),
                t_non_zeros[1:] != t_non_zeros[:-1],
            ]
        )
        # Start index of every sequence, with a sentinel at the end.
        change_indices = torch.cat(
            [
                (seq_change == 1).nonzero(as_tuple=True)[0],
                torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
            ]
        )

        seq_lengths = change_indices[1:] - change_indices[:-1]

        # Count any trailing padding as one final segment so the boundaries
        # cover the full row.
        final_seq_length = len(row) - change_indices[-1]
        if final_seq_length.item():
            seq_lengths = torch.cat(
                [
                    seq_lengths,
                    torch.tensor(
                        [final_seq_length.item()], dtype=torch.int32, device=device
                    ),
                ]
            )

        cu_seqlens = torch.cat(
            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
        )
        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        results.append(cu_seqlens)
        max_seq_lens.append(max_seq_len)

    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
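
# Illustrative example (hand-computed): for a single packed row,
#   >>> get_cu_seqlens(torch.tensor([[1, 1, 2, 2, 2, 0]]))
#   (tensor([[0, 2, 5, 6]], dtype=torch.int32), tensor([3]))
# The trailing padding shows up as its own final segment (5 -> 6); note that
# torch.stack requires every row to yield the same number of boundaries.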


@torch.jit.script
def get_cu_seqlens_from_pos_ids(position_ids):
    """Generate per-row cumulative sequence lengths for flash attention from
    packed position ids."""
    if len(position_ids.shape) == 1:
        position_ids = position_ids.unsqueeze(0)

    device = position_ids.device
    results = []
    max_seq_lens = []

    for row in position_ids:
        # Count the trailing zeros (right padding): flip the row, then take a
        # cumulative product so only the leading run of zeros survives.
        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()

        # Strip the padding before looking for sequence boundaries.
        adjusted_row = row[:-padding_length] if padding_length else row.clone()

        # A new sequence starts wherever the position ids reset to 0.
        seq_starts = torch.cat(
            [
                torch.tensor([True], dtype=torch.bool, device=device),
                adjusted_row[1:] == 0,
            ]
        )
        # Start index of every sequence, with a sentinel at the end.
        start_indices = torch.cat(
            [
                torch.nonzero(seq_starts).unbind(dim=1)[0],
                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
            ]
        )

        seq_lengths = start_indices[1:] - start_indices[:-1]

        cu_seqlens = torch.cat(
            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
        )

        # Re-attach the padding as one final boundary so the result covers the
        # whole row.
        if padding_length:
            cu_seqlens = torch.cat(
                [cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
            )
        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        results.append(cu_seqlens)
        max_seq_lens.append(max_seq_len)

    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
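
# Illustrative example (hand-computed): position ids restart at 0 for each
# packed sequence, with trailing zeros acting as padding,
#   >>> get_cu_seqlens_from_pos_ids(torch.tensor([[0, 1, 0, 1, 2, 0]]))
#   (tensor([[0, 2, 5, 6]], dtype=torch.int32), tensor([3]))
# i.e. two sequences of lengths 2 and 3, plus one position of padding.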


def set_module_name(model, name, value):
    """Replace the submodule of `model` at the dotted path `name` with `value`."""
    if "." in name:
        parent_name = name.rsplit(".", 1)[0]
        child_name = name[len(parent_name) + 1 :]
        parent = model.get_submodule(parent_name)
    else:
        parent_name = ""
        parent = model
        child_name = name

    setattr(parent, child_name, value)
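
# Illustrative usage (the module path is hypothetical): swap a patched
# attention implementation into a Hugging Face transformer, e.g.
#   set_module_name(model, "model.layers.0.self_attn", patched_attn)
# which resolves the parent via nn.Module.get_submodule and then assigns the
# child attribute with setattr.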