|
|
|
|
|
|
|
|
|
|
|
from functools import lru_cache |
|
|
|
import numpy as np |
|
import torch |
|
from fairseq.data import Dictionary, data_utils |
|
|
|
from . import BaseWrapperDataset, LRUCacheDataset |
|
|
|
|
|
class MaskTokensDataset(BaseWrapperDataset):
    """
    A wrapper Dataset for masked language modeling.

    Input items are masked according to the specified masking probability.

    Args:
        dataset: Dataset to wrap.
        vocab: Dictionary with the vocabulary and special tokens.
        pad_idx: Id of pad token in vocab.
        mask_idx: Id of mask token in vocab.
        return_masked_tokens: if False (the default), return the corrupted
            input, where selected positions hold *mask_idx*, a random token,
            or their original token. If True, return a tensor holding the
            original token IDs at the selected positions and *pad_idx*
            elsewhere. The latter is useful as targets for masked LM training.
        seed: Seed for random number generator for reproducibility.
        mask_prob: probability of selecting a token (or a whole word, see
            *mask_whole_words*) for masking; must lie strictly in (0, 1).
        leave_unmasked_prob: probability that a selected token is left
            unchanged in the corrupted input.
        random_token_prob: probability that a selected token is replaced by
            a random token from the vocabulary.
        freq_weighted_replacement: sample random replacement tokens in
            proportion to ``vocab.count`` instead of uniformly.
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        mask_multiple_length: repeat each mask index multiple times, i.e.
            mask spans of this length. Default value is 1.
        mask_stdev: standard deviation of the span-length distribution in
            case of multiple masking. Default value is 0 (fixed-length spans).
    """

    @classmethod
    def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs):
        """Return the source and target datasets for masked LM training.

        The wrapped dataset is placed behind an ``LRUCacheDataset`` and shared
        by two instances built with identical arguments except for
        ``return_masked_tokens``. The first returns the corrupted input and
        the second returns the target.

        Both instances use the same seed and epoch, and they make the same
        random draws up to the point where the mask is fixed. Assuming
        ``data_utils.numpy_seed`` is deterministic, they therefore select the
        same positions for any given index.
        """
        dataset = LRUCacheDataset(dataset)
        return (
            LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)),
            LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)),
        )

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        return_masked_tokens: bool = False,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
        freq_weighted_replacement: bool = False,
        mask_whole_words: torch.Tensor = None,
        mask_multiple_length: int = 1,
        mask_stdev: float = 0.0,
    ):
        assert 0.0 < mask_prob < 1.0
        assert 0.0 <= random_token_prob <= 1.0
        assert 0.0 <= leave_unmasked_prob <= 1.0
        # The two probabilities partition the selected positions, so together
        # they cannot exceed 1.
        assert random_token_prob + leave_unmasked_prob <= 1.0
        assert mask_multiple_length >= 1
        assert mask_stdev >= 0.0

        self.dataset = dataset
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.return_masked_tokens = return_masked_tokens
        self.seed = seed
        self.mask_prob = mask_prob
        self.leave_unmasked_prob = leave_unmasked_prob
        self.random_token_prob = random_token_prob
        self.mask_whole_words = mask_whole_words
        self.mask_multiple_length = mask_multiple_length
        self.mask_stdev = mask_stdev

        # self.weights is only needed (and only defined) when random
        # replacement can happen.
        if random_token_prob > 0.0:
            if freq_weighted_replacement:
                weights = np.array(self.vocab.count)
            else:
                weights = np.ones(len(self.vocab))
            # Never sample the first `nspecial` vocab entries as replacements.
            weights[: self.vocab.nspecial] = 0
            self.weights = weights / weights.sum()

        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """Always True: epoch iterators over this dataset may be reused."""
        return True

    def set_epoch(self, epoch, **unused):
        """Call the parent ``set_epoch`` and remember ``epoch`` for seeding."""
        super().set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index: int):
        """Return the (cached) masked item for ``index`` at the current seed/epoch."""
        return self.__getitem_cached__(self.seed, self.epoch, index)

    @lru_cache(maxsize=8)
    def __getitem_cached__(self, seed: int, epoch: int, index: int):
        """Compute the masked version of ``self.dataset[index]``.

        The result is memoized by ``lru_cache`` on ``(self, seed, epoch, index)``.
        Returns a tensor of the same length as the item. It is either the
        target (``return_masked_tokens=True``) or the corrupted input.
        """
        # Seed from the arguments, which form the lru_cache key, rather than
        # from self.seed / self.epoch. This keeps every cached result
        # consistent with the key it is stored under.
        # NOTE(review): assumes numpy_seed seeds numpy's global RNG from its
        # arguments for the duration of the block; all np.random draws below
        # rely on that.
        with data_utils.numpy_seed(seed, epoch, index):
            item = self.dataset[index]
            sz = len(item)

            assert (
                self.mask_idx not in item
            ), "Dataset contains mask_idx (={}), this is not expected!".format(
                self.mask_idx,
            )

            if self.mask_whole_words is not None:
                # Flag each token that starts a word. From here on, masking
                # decisions are made per word (sz becomes the word count) and
                # are expanded back to tokens with word_lens.
                # presumably `item` is a LongTensor of vocab ids (required by
                # gather) — TODO confirm
                word_begins_mask = self.mask_whole_words.gather(0, item)
                word_begins_idx = word_begins_mask.nonzero().view(-1)
                sz = len(word_begins_idx)
                # Split at word starts and drop the segment before the first one.
                words = np.split(word_begins_mask, word_begins_idx)[1:]
                assert len(words) == sz
                word_lens = list(map(len, words))

            # Decide which elements (tokens or words) to mask.
            mask = np.full(sz, False)
            # Adding U[0, 1) before truncating is probabilistic rounding: the
            # expected number of span starts equals
            # mask_prob * sz / mask_multiple_length.
            num_mask = int(
                self.mask_prob * sz / float(self.mask_multiple_length)
                + np.random.rand()
            )

            # Pick span start positions, then extend each start into a span.
            mask_idc = np.random.choice(sz, num_mask, replace=False)
            if self.mask_stdev > 0.0:
                # Variable span lengths ~ round(N(mask_multiple_length,
                # mask_stdev)), clipped at 0.
                lengths = np.random.normal(
                    self.mask_multiple_length, self.mask_stdev, size=num_mask
                )
                lengths = [max(0, int(round(x))) for x in lengths]
                mask_idc = np.asarray(
                    [
                        start + offset
                        for start, length in zip(mask_idc, lengths)
                        for offset in range(length)
                    ],
                    dtype=np.int64,
                )
            else:
                mask_idc = np.concatenate(
                    [mask_idc + i for i in range(self.mask_multiple_length)]
                )
            # Spans may run past the end of the sequence; clip them.
            mask_idc = mask_idc[mask_idc < len(mask)]
            try:
                mask[mask_idc] = True
            except Exception:
                print(
                    "Assigning mask indexes {} to mask {} failed!".format(
                        mask_idc, mask
                    )
                )
                raise

            if self.return_masked_tokens:
                # Target: original ids at selected positions, pad elsewhere.
                # Nothing below is needed, so no further RNG draws are made.
                if self.mask_whole_words is not None:
                    mask = np.repeat(mask, word_lens)
                new_item = np.full(len(mask), self.pad_idx)
                new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1]
                return torch.from_numpy(new_item)

            # Split the selected positions into three groups:
            # left unchanged (unmask), replaced randomly (rand_mask), and the
            # rest, which get mask_idx.
            rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
            if rand_or_unmask_prob > 0.0:
                rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
                if self.random_token_prob == 0.0:
                    unmask = rand_or_unmask
                    rand_mask = None
                elif self.leave_unmasked_prob == 0.0:
                    unmask = None
                    rand_mask = rand_or_unmask
                else:
                    # Conditional probability of "leave unchanged" given
                    # "unchanged or random".
                    unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                    decision = np.random.rand(sz) < unmask_prob
                    unmask = rand_or_unmask & decision
                    rand_mask = rand_or_unmask & (~decision)
            else:
                unmask = rand_mask = None

            if unmask is not None:
                # unmask is a subset of mask, so XOR removes those positions.
                mask = mask ^ unmask

            if self.mask_whole_words is not None:
                mask = np.repeat(mask, word_lens)

            new_item = np.copy(item)
            new_item[mask] = self.mask_idx
            if rand_mask is not None:
                num_rand = rand_mask.sum()
                if num_rand > 0:
                    if self.mask_whole_words is not None:
                        # Expand to token level and recount the positions.
                        rand_mask = np.repeat(rand_mask, word_lens)
                        num_rand = rand_mask.sum()

                    new_item[rand_mask] = np.random.choice(
                        len(self.vocab),
                        num_rand,
                        p=self.weights,
                    )

            return torch.from_numpy(new_item)
|
|