import torch


def subsequent_mask(
    size: int,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in the decoder, which works in an auto-regressive
    mode: the current step may only attend to the steps to its left.

    In the encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in the encoder.
    See subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    arange = torch.arange(size, device=device)
    mask = arange.expand(size, size)
    arange = arange.unsqueeze(-1)
    mask = mask <= arange
    return mask


def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size;
    this is for the streaming encoder.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full left context
            >=0: use num_left_chunks left chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret
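
# Illustration (derived from the loop above): num_left_chunks limits how far
# back a chunk can attend. For example, subsequent_chunk_mask(6, 2, 1) keeps
# only the current chunk and one chunk to its left:
#   [[1, 1, 0, 0, 0, 0],
#    [1, 1, 0, 0, 0, 0],
#    [1, 1, 1, 1, 0, 0],
#    [1, 1, 1, 1, 0, 0],
#    [0, 0, 1, 1, 1, 1],
#    [0, 0, 1, 1, 1, 1]]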


def add_optional_chunk_mask(
    xs: torch.Tensor,
    masks: torch.Tensor,
    use_dynamic_chunk: bool,
    use_dynamic_left_chunk: bool,
    decoding_chunk_size: int,
    static_chunk_size: int,
    num_decoding_left_chunks: int,
) -> torch.Tensor:
    """Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training
        decoding_chunk_size (int): decoding chunk size for dynamic chunk
            0: default for training, use random dynamic chunk
            <0: for decoding, use full chunk
            >0: for decoding, use the fixed chunk size as set
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0; ignored if use_dynamic_chunk is true
        num_decoding_left_chunks (int): number of left chunks, this is for
            decoding, the chunk size is decoding_chunk_size
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # During training, sample a random chunk size: if the draw exceeds
            # max_len // 2, fall back to full context; otherwise fold it into
            # the range [1, 25].
            chunk_size = torch.randint(1, max_len, (1,)).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(
                        0, max_left_chunks, (1,)
                    ).item()
        chunk_masks = subsequent_chunk_mask(
            xs.size(1), chunk_size, num_left_chunks, xs.device
        )
        chunk_masks = chunk_masks.unsqueeze(0)
        chunk_masks = masks & chunk_masks
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(
            xs.size(1), static_chunk_size, num_left_chunks, xs.device
        )
        chunk_masks = chunk_masks.unsqueeze(0)
        chunk_masks = masks & chunk_masks
    else:
        chunk_masks = masks
    return chunk_masks
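
# A minimal usage sketch (shapes assumed, not part of the public API): build
# the (B, 1, L) non-pad mask and combine it with a static chunk mask, as a
# streaming encoder would:
#   lengths = torch.tensor([4, 2])
#   xs = torch.zeros(2, 4, 8)                        # (B, L, D)
#   masks = make_non_pad_mask(lengths).unsqueeze(1)  # (B, 1, L)
#   chunk_masks = add_optional_chunk_mask(
#       xs, masks, use_dynamic_chunk=False, use_dynamic_left_chunk=False,
#       decoding_chunk_size=0, static_chunk_size=2,
#       num_decoding_left_chunks=-1)                 # (B, L, L)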


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): maximum length; if 0, use lengths.max() instead.

    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask


def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Make mask tensor containing indices of non-padded part.

    The sequences in a batch may have different lengths. To enable batch
    computing, padding is needed to make all sequences the same size. To
    keep the padding part from passing values into context-dependent blocks
    such as attention or convolution, the padding part is masked.

    This pad_mask is used in both the encoder and the decoder.

    1 for non-padded part and 0 for padded part.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).

    Returns:
        torch.Tensor: mask tensor containing indices of non-padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    return ~make_pad_mask(lengths)
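
# Note: make_non_pad_mask(lengths).unsqueeze(1) yields the (B, 1, L) `masks`
# tensor that add_optional_chunk_mask expects, with True marking valid frames.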


def mask_finished_scores(score: torch.Tensor, flag: torch.Tensor) -> torch.Tensor:
    """
    If a sequence is finished, we only allow one alive branch. This function
    gives one branch a zero score and the rest a -inf score.

    Args:
        score (torch.Tensor): A real value array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = score.size(-1)
    zero_mask = torch.zeros_like(flag, dtype=torch.bool)
    if beam_size > 1:
        unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])), dim=1)
        finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])), dim=1)
    else:
        unfinished = zero_mask
        finished = flag
    score.masked_fill_(unfinished, -float("inf"))
    score.masked_fill_(finished, 0)
    return score
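
# Illustration (derived from the masks above): for a finished hypothesis
# (flag=True) with beam_size=3, the score row becomes [0, -inf, -inf], so only
# the first branch stays alive; unfinished rows (flag=False) are unchanged.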


def mask_finished_preds(
    pred: torch.Tensor, flag: torch.Tensor, eos: int
) -> torch.Tensor:
    """
    If a sequence is finished, all of its branches should be <eos>.

    Args:
        pred (torch.Tensor): An int array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = pred.size(-1)
    finished = flag.repeat([1, beam_size])
    return pred.masked_fill_(finished, eos)
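

if __name__ == "__main__":
    # Smoke-test sketch (illustrative only): exercise the mask helpers and
    # print the resulting boolean masks.
    print(subsequent_mask(3))
    print(subsequent_chunk_mask(4, 2))
    lengths = torch.tensor([5, 3, 2])
    print(make_pad_mask(lengths))
    print(make_non_pad_mask(lengths))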