# Listing metadata from the file viewer: file size 6,291 bytes, revision 426f887 (line-number gutter 1-160 omitted).
# Copyright 2022 MosaicML Examples authors
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
# Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
"""Helper functions for padding and unpadding batches.
These functions are used extensively throughout the Mosaic BERT implementation
in `bert_layers.py`.
"""
from typing import Tuple, cast
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
class IndexFirstAxis(torch.autograd.Function):
    """Select rows of a tensor along its first axis, with a hand-written backward.

    The forward pass returns `input[indices]` (computed with `torch.gather` on a
    2D view). The backward pass writes the incoming gradient rows back into a
    zero tensor whose first dimension matches the original input.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor,
                indices: torch.Tensor) -> torch.Tensor:
        """Get just the values of `input` which are at `indices`.

        Arguments:
            ctx: the autograd context object
            input: (b, ...) 2+ dimensional tensor
            indices: (num_idx) 1D tensor of positions along the first axis

        Returns:
            (num_idx, ...) tensor whose i-th row is `input[indices[i]]`.
        """
        assert input.ndim >= 2
        # Same 1D requirement that IndexPutFirstAxis enforces; fail early
        # instead of deep inside the gather.
        assert indices.ndim == 1
        ctx.save_for_backward(indices)
        # Remember the size of the first axis so backward can rebuild a
        # gradient with the same leading dimension as `input`.
        ctx.first_axis_dim = input.shape[0]
        other_shape = input.shape[1:]
        # product of sizes of all but first dimension
        second_dim = other_shape.numel()
        # (b, ...) -> (b, second_dim)
        flat_input = input.reshape(ctx.first_axis_dim, second_dim)
        # (num_idx,) -> (num_idx, second_dim); `expand` broadcasts without
        # copying the index data.
        gather_index = indices.unsqueeze(1).expand(-1, second_dim)
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        gathered = torch.gather(flat_input, 0, gather_index)
        # (num_idx, second_dim) -> (num_idx, ...)
        return gathered.reshape(indices.shape[0], *other_shape)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Scatter `grad_output` rows back to the positions they were read from.

        Arguments:
            ctx: the autograd context object populated by `forward`
            grad_output: (num_idx, ...) gradient w.r.t. the forward output

        Returns:
            A pair `(grad_input, None)`: `grad_input` has shape
            (first_axis_dim, ...) and is zero for every row that was not
            selected; `indices` receives no gradient.
        """
        indices, = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        second_dim = other_shape.numel()
        # (num_idx, ...) -> (num_idx, second_dim)
        flat_grad = grad_output.reshape(grad_output.shape[0], second_dim)
        grad_input = torch.zeros([ctx.first_axis_dim, second_dim],
                                 device=flat_grad.device,
                                 dtype=flat_grad.dtype)
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # Equivalent to `grad_input[indices] = flat_grad`. Note that scatter_
        # overwrites rather than accumulates, so repeated entries in `indices`
        # would not have their gradients summed.
        grad_input.scatter_(0,
                            indices.unsqueeze(1).expand(-1, second_dim),
                            flat_grad)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply
class IndexPutFirstAxis(torch.autograd.Function):
    """Write rows into a zero-filled tensor along its first axis.

    Forward builds a zero tensor with `first_axis_dim` rows and assigns
    `values` to the rows named by `indices`; backward reads the gradient back
    from those same rows.
    """

    @staticmethod
    def forward(ctx, values: torch.Tensor, indices: torch.Tensor,
                first_axis_dim) -> torch.Tensor:
        """Place `values` at `indices` inside a fresh zero tensor.

        Arguments:
            ctx: the autograd context object
            values: (num_idx, ...) 2+ dimensional tensor of rows to place
            indices: (num_idx) 1D tensor of destination rows
            first_axis_dim: size of the first axis of the result

        Returns:
            (first_axis_dim, ...) tensor with the same dtype and device as
            `values`; rows not named in `indices` are zero.
        """
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        padded_shape = (first_axis_dim, *values.shape[1:])
        padded = torch.zeros(padded_shape,
                             device=values.device,
                             dtype=values.dtype)
        padded[indices] = values
        return padded

    @staticmethod
    def backward(ctx,
                 grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
        """Return the gradient rows that correspond to the placed values.

        Only `values` receives a gradient; `indices` and `first_axis_dim`
        get `None`.
        """
        (row_indices,) = ctx.saved_tensors
        return grad_output[row_indices], None, None


index_put_first_axis = IndexPutFirstAxis.apply
def unpad_input(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Strip padded positions out of a batch of sequences.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.

    Returns:
        hidden_states: (total_nnz, ...), where total_nnz is the number of tokens selected by attention_mask.
        indices: (total_nnz), positions of the selected tokens in the flattened (batch * seqlen) axis.
        cu_seqlens: (batch + 1), int32 cumulative sequence lengths with a leading 0, used to index into hidden_states.
        max_seqlen_in_batch: int, the largest per-sequence token count.
    """
    # Positions of the valid tokens once batch and sequence axes are merged.
    flat_mask = attention_mask.flatten()
    indices = torch.nonzero(flat_mask, as_tuple=False).flatten()

    # Per-sequence token counts, their maximum, and their running total
    # (a zero is prepended so entry i is the offset where sequence i starts).
    tokens_per_seq = attention_mask.sum(dim=-1, dtype=torch.int32)
    max_seqlen_in_batch = int(tokens_per_seq.max().item())
    running_total = torch.cumsum(tokens_per_seq, dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(running_total, (1, 0))

    # TD [2022-03-04] Indexing with a bool mask is avoided on purpose: Pytorch
    # would expand the mask, call nonzero on it and index with the result, so
    # the indices end up @dim times larger than needed. Integer indices are
    # faster and use less memory, and the custom forward/backward in
    # index_first_axis is a bit faster than torch's own indexing.
    flat_states = rearrange(hidden_states, 'b s ... -> (b s) ...')
    unpadded = cast(torch.Tensor, index_first_axis(flat_states, indices))
    return unpadded, indices, cu_seqlens, max_seqlen_in_batch
def unpad_input_only(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Return only the unpadded hidden states.

    Does the same token selection as unpad_input but skips the sequence-length
    bookkeeping, which saves a small amount of overhead.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.

    Returns:
        hidden_states: (total_nnz, ...), where total_nnz is the number of tokens selected by attention_mask.
    """
    # Flat positions of the tokens to keep.
    keep = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # Merge batch and sequence axes so `keep` addresses a single axis.
    flat_states = rearrange(hidden_states, 'b s ... -> (b s) ...')
    return index_first_axis(flat_states, keep)  # type: ignore
def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int,
              seqlen: int) -> torch.Tensor:
    """Put unpadded tokens back into a zero-padded (batch, seqlen, ...) tensor.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz is the number of tokens selected by attention_mask.
        indices: (total_nnz), positions of those tokens in the flattened (batch * seqlen) axis.
        batch: int batch_size
        seqlen: int max sequence length

    Returns:
        hidden_states: (batch, seqlen, ...)
    """
    total_slots = batch * seqlen
    # Fill a flat (batch * seqlen, ...) tensor, then split the first axis.
    padded_flat = index_put_first_axis(hidden_states, indices, total_slots)
    return rearrange(padded_flat, '(b s) ... -> b s ...',
                     b=batch)  # type: ignore
|