File size: 12,168 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Dict, List, Tuple
import numpy as np
import torch
from fairseq.data import Dictionary, FairseqDataset, data_utils
from fairseq.data.concat_dataset import ConcatDataset
from fairseq.data.legacy.block_pair_dataset import BlockPairDataset
from fairseq.data.token_block_dataset import TokenBlockDataset
class MaskedLMDataset(FairseqDataset):
    """
    A wrapper Dataset for masked language modelling. The dataset
    wraps around TokenBlockDataset or BlockPairDataset and creates a batch
    where the input blocks are masked according to the specified masking
    probability. Additionally the batch can also contain sentence level targets
    if this is specified.

    Args:
        dataset: Dataset which generates blocks of data. Only
            TokenBlockDataset, BlockPairDataset and ConcatDataset instances
            pass the type check in the constructor.
        sizes: Sentence lengths
        vocab: Dictionary with the vocabulary and special tokens.
        pad_idx: Id of padding token in dictionary
        mask_idx: Id of mask token in dictionary
        classif_token_idx: Id of classification token in dictionary. This is the
            token associated with the sentence embedding (Eg: CLS for BERT)
        sep_token_idx: Id of separator token in dictionary
            (Eg: SEP in BERT)
        seed: Seed for random number generator for reproducibility.
        shuffle: Shuffle the elements before batching.
        has_pairs: Specifies whether the underlying dataset
            generates a pair of blocks along with a sentence_target or not.
            Setting it to True (the default) assumes that the underlying
            dataset generates a label for the pair of sentences which is
            surfaced as sentence_target. Setting it to False assumes a single
            block with no sentence target.
        segment_id: An optional segment id for filling in the segment labels
            when we are in the single block setting (Eg: XLM). Default is 0.
        masking_ratio: fraction of the tokens of each block that are selected
            for prediction; ceil(len(block) * masking_ratio) positions are
            sampled without replacement.
        masking_prob: probability that a selected token is replaced with the
            "MASK" token.
        random_token_prob: probability that a selected token is replaced by a
            random token from the vocabulary. Selected tokens that fall in
            neither case are left unchanged in the input.
    """

    def __init__(
        self,
        dataset: FairseqDataset,
        sizes: np.ndarray,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        classif_token_idx: int,
        sep_token_idx: int,
        seed: int = 1,
        shuffle: bool = True,
        has_pairs: bool = True,
        segment_id: int = 0,
        masking_ratio: float = 0.15,
        masking_prob: float = 0.8,
        random_token_prob: float = 0.1,
    ):
        # Make sure the input datasets are the ones supported
        assert isinstance(
            dataset, (TokenBlockDataset, BlockPairDataset, ConcatDataset)
        ), (
            "MaskedLMDataset only wraps TokenBlockDataset or BlockPairDataset or "
            "ConcatDataset"
        )

        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.classif_token_idx = classif_token_idx
        self.sep_token_idx = sep_token_idx
        self.shuffle = shuffle
        self.seed = seed
        self.has_pairs = has_pairs
        self.segment_id = segment_id
        self.masking_ratio = masking_ratio
        self.masking_prob = masking_prob
        self.random_token_prob = random_token_prob

        # If we have only one block then sizes needs to be updated to include
        # the classification token
        if not has_pairs:
            self.sizes = self.sizes + 1

    def __getitem__(self, index: int):
        """
        Return the raw (unmasked) example at `index` as a dict with the keys
        "id", "block_one", "block_two" and "sentence_target". The last two
        are None when has_pairs is False.
        """
        # if has_pairs, then expect 2 blocks and a sentence target
        if self.has_pairs:
            block_one, block_two, sentence_target = self.dataset[index]
        else:
            block_one = self.dataset[index]
            block_two = None
            sentence_target = None

        return {
            "id": index,
            "block_one": block_one,
            "block_two": block_two,
            "sentence_target": sentence_target,
        }

    def __len__(self):
        """Number of examples, as reported by the wrapped dataset."""
        return len(self.dataset)

    def _mask_block(
        self,
        sentence: np.ndarray,
        mask_idx: int,
        pad_idx: int,
        dictionary_token_range: Tuple,
    ):
        """
        Mask tokens for Masked Language Model training
        Samples ceil(len(sentence) * masking_ratio) positions that will be
        predicted by the LM.

        Note: the per-position random draws are kept in a Python loop (in
        ascending position order) so that, for a given numpy PRNG state, the
        produced mask is the same as a left-to-right scan over the sentence.

        Args:
            sentence: 1d tensor to be masked
            mask_idx: index to use for masking the sentence
            pad_idx: index to use for masking the target for tokens we aren't
                predicting
            dictionary_token_range: (low, high) range of indices in dictionary
                which can be used for random word replacement
                (e.g. without special characters). It is passed to
                np.random.randint, so `high` is exclusive.
        Return:
            masked_sent: masked sentence
            target: target with words which we are not predicting replaced
                by pad_idx
        """
        masked_sent = np.copy(sentence)
        target = np.copy(sentence)
        sent_length = len(sentence)

        mask_num = math.ceil(sent_length * self.masking_ratio)
        mask = np.random.choice(sent_length, mask_num, replace=False)

        # Boolean view of the sampled positions: gives O(1) membership instead
        # of scanning the `mask` array once per token.
        selected = np.zeros(sent_length, dtype=bool)
        selected[mask] = True

        # Positions we are not predicting are blanked out in the target.
        target[~selected] = pad_idx

        low, high = dictionary_token_range[0], dictionary_token_range[1]
        mask_or_random_prob = self.masking_prob + self.random_token_prob

        # flatnonzero yields the selected positions in ascending order, which
        # keeps the order of PRNG calls identical to scanning every position.
        for pos in np.flatnonzero(selected):
            rand = np.random.random()
            # replace with mask if probability is less than masking_prob
            # (Eg: 0.8)
            if rand < self.masking_prob:
                masked_sent[pos] = mask_idx
            # replace with random token if probability is less than
            # masking_prob + random_token_prob (Eg: 0.9)
            elif rand < mask_or_random_prob:
                # sample random token from dictionary
                masked_sent[pos] = np.random.randint(low, high)
            # otherwise the original token is kept in the input

        return masked_sent, target

    def _collate(self, samples: List[Dict], pad_idx: int, eos_idx: int):
        """
        Does the heavy lifting for creating a batch from the input list of
        examples. The logic is as follows:
            1. Mask the input blocks. In case has_pair is True then we have 2
               blocks to mask.
            2. Prepend the first masked block tensor with the special token
               used as sentence embedding. Eg: CLS in BERT. This happens
               irrespective of the value of has_pair.
            3. If has_pair is True, then append the first masked block with the
               special separator token (eg: SEP for BERT) and compute segment
               label accordingly. In this case, also append the second masked
               block with this special separator token and compute its segment
               label.
            4. For the targets tensor, prepend and append with padding index
               accordingly.
            5. Concatenate all tensors.

        Note that the per-sample tensors are stored back into each sample dict
        under the keys "source", "segment_labels" and "lm_target".

        Args:
            samples: examples as produced by __getitem__
            pad_idx: padding index handed to data_utils.collate_tokens
            eos_idx: end-of-sentence index handed to data_utils.collate_tokens
        Return:
            an empty dict for an empty `samples`, otherwise the batch dict
            with "id", "ntokens", "net_input", "lm_target", "sentence_target"
            (None when has_pairs is False) and "nsentences".
        """
        if not samples:
            return {}

        # token range is needed for replacing with random token during
        # masking; it does not depend on the sample, so compute it once.
        token_range = (self.vocab.nspecial, len(self.vocab))

        # To ensure determinism, we reset the state of the PRNG after every
        # batch based on the seed and the first id of the batch. This ensures
        # that across epochs we get the same mask for the same example. This
        # is needed for reproducibility and is how BERT does masking
        # TODO: Can we add determinism without this constraint?
        with data_utils.numpy_seed(self.seed + samples[0]["id"]):
            for s in samples:
                # mask according to specified probabilities.
                masked_blk_one, masked_tgt_one = self._mask_block(
                    s["block_one"],
                    self.mask_idx,
                    self.pad_idx,
                    token_range,
                )

                tokens = np.concatenate([[self.classif_token_idx], masked_blk_one])
                targets = np.concatenate([[self.pad_idx], masked_tgt_one])
                segments = np.ones(len(tokens)) * self.segment_id

                # if has_pairs is True then we need to add the SEP token to both
                # the blocks after masking and re-compute segments based on the new
                # lengths.
                if self.has_pairs:
                    tokens_one = np.concatenate([tokens, [self.sep_token_idx]])
                    targets_one = np.concatenate([targets, [self.pad_idx]])

                    masked_blk_two, masked_tgt_two = self._mask_block(
                        s["block_two"], self.mask_idx, self.pad_idx, token_range
                    )
                    tokens_two = np.concatenate(
                        [masked_blk_two, [self.sep_token_idx]]
                    )
                    targets_two = np.concatenate([masked_tgt_two, [self.pad_idx]])

                    # block + 1 sep + 1 special (CLS)
                    segments_one = np.zeros(len(tokens_one))
                    # block + 1 sep
                    segments_two = np.ones(len(tokens_two))

                    tokens = np.concatenate([tokens_one, tokens_two])
                    targets = np.concatenate([targets_one, targets_two])
                    segments = np.concatenate([segments_one, segments_two])

                s["source"] = torch.LongTensor(tokens)
                s["segment_labels"] = torch.LongTensor(segments)
                s["lm_target"] = torch.LongTensor(targets)

        def merge(key):
            # Pad the per-sample tensors stored under `key` on the right and
            # stack them into a single batch tensor.
            return data_utils.collate_tokens(
                [s[key] for s in samples], pad_idx, eos_idx, left_pad=False
            )

        return {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "ntokens": sum(len(s["source"]) for s in samples),
            "net_input": {
                "src_tokens": merge("source"),
                "segment_labels": merge("segment_labels"),
            },
            "lm_target": merge("lm_target"),
            "sentence_target": torch.LongTensor(
                [s["sentence_target"] for s in samples]
            )
            if self.has_pairs
            else None,
            "nsentences": len(samples),
        }

    def collater(self, samples: List[Dict]):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch of data
        """
        return self._collate(samples, self.vocab.pad(), self.vocab.eos())

    def num_tokens(self, index: int):
        """
        Return the number of tokens in a sample. This value is used to
        enforce max-tokens during batching.
        """
        return self.sizes[index]

    def size(self, index: int):
        """
        Return an example's size as a float or tuple. This value is used when
        filtering a dataset with max-positions.
        """
        return self.sizes[index]

    def ordered_indices(self):
        """
        Return an ordered list of indices. Batches will be constructed based
        on this order.

        With shuffle enabled this is a random permutation; otherwise indices
        are sorted by size, with ties broken by the index itself.
        """
        if self.shuffle:
            return np.random.permutation(len(self))

        # np.lexsort treats the last key as the primary one: sort by size,
        # then by position.
        return np.lexsort([np.arange(len(self)), self.sizes])

    @property
    def supports_prefetch(self):
        """Whether the wrapped dataset advertises prefetch support."""
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        """Forward the prefetch request to the wrapped dataset."""
        self.dataset.prefetch(indices)
|