import math

import numpy as np
import torch

from . import FairseqDataset, data_utils


def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    assert input_feeding
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=None,  # use the eos of each sample (its last token) instead of a fixed vocab.eos()
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )

    # sort the batch by descending source length
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)

        if input_feeding:
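            # create a shifted version of the targets for feeding the
            # previous output token(s) into the next decoder step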
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": id,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        "nsentences": len(samples),  # number of samples in the batch
        "sort_order": sort_order,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens

    return batch


class DenoisingDataset(FairseqDataset):
    """
    A wrapper around TokenBlockDataset that applies BART-style denoising noise.

    Args:
        dataset (TokenBlockDataset): dataset to wrap
        sizes (List[int]): sentence lengths
        vocab (~fairseq.data.Dictionary): vocabulary
        mask_idx (int): dictionary index used for masked token
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        shuffle (bool, optional): shuffle the elements before batching.
            Default: ``True``
        seed: Seed for random number generator for reproducibility.
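
    Example (illustrative sketch only, not part of the fairseq docs; assumes a
    built ``vocab`` containing a ``<mask>`` symbol and a ``TokenBlockDataset``
    named ``token_dataset``)::

        dataset = DenoisingDataset(
            token_dataset,
            token_dataset.sizes,
            vocab,
            mask_idx=vocab.index("<mask>"),
            mask_whole_words=None,
            shuffle=True,
            seed=1,
            mask=0.3,
            mask_random=0.1,
            insert=0.0,
            rotate=0.0,
            permute_sentences=1.0,
            bpe="gpt2",
            replace_length=1,
            mask_length="span-poisson",
            poisson_lambda=3.5,
        )
        batch = dataset.collater([dataset[i] for i in range(8)])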
""" |
|
|
|
    def __init__(
        self,
        dataset,
        sizes,
        vocab,
        mask_idx,
        mask_whole_words,
        shuffle,
        seed,
        mask,
        mask_random,
        insert,
        rotate,
        permute_sentences,
        bpe,
        replace_length,
        mask_length,
        poisson_lambda,
        eos=None,
        item_transform_func=None,
    ):
        self.dataset = dataset
        self.sizes = sizes
        self.vocab = vocab
        self.shuffle = shuffle
        self.seed = seed
        self.mask_idx = mask_idx
        self.mask_whole_word = mask_whole_words
        self.mask_ratio = mask
        self.random_ratio = mask_random
        self.insert_ratio = insert
        self.rotate_ratio = rotate
        self.permute_sentence_ratio = permute_sentences
        self.eos = eos if eos is not None else vocab.eos()
        self.item_transform_func = item_transform_func

        if bpe != "gpt2":
            self.full_stop_index = self.vocab.eos()
        else:
            assert bpe == "gpt2"
            self.full_stop_index = self.vocab.index("13")  # "13" is the GPT-2 BPE id for "."

        self.replace_length = replace_length
        if self.replace_length not in [-1, 0, 1]:
            raise ValueError(f"invalid arg: replace_length={self.replace_length}")
        if mask_length not in ["subword", "word", "span-poisson"]:
            raise ValueError(f"invalid arg: mask-length={mask_length}")
        if mask_length == "subword" and replace_length not in [0, 1]:
            raise ValueError("if using subwords, use replace-length=1 or 0")

        self.mask_span_distribution = None
        if mask_length == "span-poisson":
            _lambda = poisson_lambda
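            # build a truncated Poisson(lambda) pmf over span lengths:
            # p(k) = exp(-lambda) * lambda^k / k! for k = 0..127, stopping
            # early once the tail probability becomes negligible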
            lambda_to_the_k = 1
            e_to_the_minus_lambda = math.exp(-_lambda)
            k_factorial = 1
            ps = []
            for k in range(0, 128):
                ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
                lambda_to_the_k *= _lambda
                k_factorial *= k + 1
                if ps[-1] < 0.0000001:
                    break
            ps = torch.FloatTensor(ps)
            self.mask_span_distribution = torch.distributions.Categorical(ps)

        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True

    def set_epoch(self, epoch, **unused):
        self.epoch = epoch

    def __getitem__(self, index):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            tokens = self.dataset[index]
            assert tokens[-1] == self.eos
            source, target = tokens, tokens.clone()

            if self.permute_sentence_ratio > 0.0:
                source = self.permute_sentences(source, self.permute_sentence_ratio)

            if self.mask_ratio > 0:
                source = self.add_whole_word_mask(source, self.mask_ratio)

            if self.insert_ratio > 0:
                source = self.add_insertion_noise(source, self.insert_ratio)

            if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
                source = self.add_rolling_noise(source)

        # optional user-provided transform applied after the seeded noise
        if self.item_transform_func is not None:
            source, target = self.item_transform_func(source, target)

        assert (source >= 0).all()
        assert (source[1:-1] >= 1).all()
        assert (source <= len(self.vocab)).all()
        assert source[0] == self.vocab.bos()
        assert source[-1] == self.eos
        return {
            "id": index,
            "source": source,
            "target": target,
        }

    def __len__(self):
        return len(self.dataset)

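    # Sentence permutation noise: shuffle a fraction ``p`` of the sentences
    # in the block (spans delimited by ``full_stop_index``), leaving the
    # leading <bos> and trailing <eos> in place.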
    def permute_sentences(self, source, p=1.0):
        full_stops = source == self.full_stop_index
        # pretend the block ends with a full stop so the last span counts
        # as a sentence
        full_stops[-2] = 1

        # tokens that are full stops, where the previous token is not
        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
        result = source.clone()

        num_sentences = sentence_ends.size(0)
        num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
        substitutions = torch.randperm(num_sentences)[:num_to_permute]
        ordering = torch.arange(0, num_sentences)
        ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]

        # ignore <bos> at the start
        index = 1
        for i in ordering:
            sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
            result[index : index + sentence.size(0)] = sentence
            index += sentence.size(0)
        return result

    def word_starts(self, source):
        if self.mask_whole_word is not None:
            is_word_start = self.mask_whole_word.gather(0, source)
        else:
            # without a whole-word mask, treat every token as its own word
            is_word_start = torch.ones(source.size())
        is_word_start[0] = 0
        is_word_start[-1] = 0
        return is_word_start

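    # Whole-word / span masking noise: mask roughly a fraction ``p`` of the
    # word starts, either as single words or as Poisson-length spans.
    # Depending on ``replace_length``, masked spans are deleted (0), collapsed
    # to a single ``mask_idx`` (1), or replaced token-by-token (-1); a
    # ``random_ratio`` fraction of the masks use a random vocabulary token
    # instead of ``mask_idx``.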
    def add_whole_word_mask(self, source, p):
        is_word_start = self.word_starts(source)
        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
        num_inserts = 0
        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # make sure we have enough span lengths to cover the masking budget
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # trim to the masking budget
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # handle 0-length spans (pure insertions) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(source, num_inserts / source.size(0))

            assert (lengths > 0).all()
        else:
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero(as_tuple=False)
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        # acts as a very long word length, so spans never run past the end of the doc
        is_word_start[-1] = 255
        if self.replace_length == 0:
            to_keep[indices] = 0
        else:
            # keep the index, but replace it with [MASK] (or a random token)
            source[indices] = self.mask_idx
            source[indices[mask_random]] = torch.randint(
                1, len(self.vocab), size=(mask_random.sum(),)
            )

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if self.replace_length != -1:
                    # delete the token
                    to_keep[indices] = 0
                else:
                    # keep the index, but replace it with [MASK] (or a random token)
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.vocab), size=(mask_random.sum(),)
                    )
        else:
            # a bit faster when all span lengths are 1
            while indices.size(0) > 0:
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if self.replace_length != -1:
                    # delete the token
                    to_keep[indices] = 0
                else:
                    # keep the index, but replace it with [MASK] (or a random token)
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.vocab), size=(mask_random.sum(),)
                    )

                assert source_length - 1 not in indices

        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source

    def add_permuted_noise(self, tokens, p):
        # token permutation noise: shuffle a fraction ``p`` of the tokens,
        # never touching the first or last token
        num_words = len(tokens)
        num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
        substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
        tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
        return tokens

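    # Rotation noise: rotate the token sequence around a random offset,
    # keeping the leading <bos> and trailing <eos> in place.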
    def add_rolling_noise(self, tokens):
        offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
        tokens = torch.cat(
            (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
            dim=0,
        )
        return tokens

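    # Insertion noise: insert ``ceil(len(tokens) * p)`` new tokens at random
    # interior positions; a ``random_ratio`` fraction are random vocabulary
    # tokens, the rest are ``mask_idx``.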
    def add_insertion_noise(self, tokens, p):
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.mask_idx
        result[noise_indices[:num_random]] = torch.randint(
            low=1, high=len(self.vocab), size=(num_random,)
        )

        result[~noise_mask] = tokens

        assert (result >= 0).all()
        return result

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch of data
        """
        return collate(
            samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        return indices[np.argsort(self.sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        # prefetch the wrapped dataset (this class has no separate src/tgt datasets)
        self.dataset.prefetch(indices)

    @property
    def supports_prefetch(self):
        return (
            hasattr(self.dataset, "supports_prefetch")
            and self.dataset.supports_prefetch
        )