import numpy as np
import torch

from . import FairseqDataset, data_utils


def collate(samples, pad_idx, eos_idx, fixed_pad_length=None, pad_to_bsz=None):
    if len(samples) == 0:
        return {}

    def merge(key, is_list=False):
        # Pad a field into a 2D tensor. When the target is a list (multiple
        # prediction tasks per sample), collate each sub-target separately.
        if is_list:
            res = []
            for i in range(len(samples[0][key])):
                res.append(
                    data_utils.collate_tokens(
                        [s[key][i] for s in samples],
                        pad_idx,
                        eos_idx,
                        left_pad=False,
                        pad_to_length=fixed_pad_length,
                        pad_to_bsz=pad_to_bsz,
                    )
                )
            return res
        else:
            return data_utils.collate_tokens(
                [s[key] for s in samples],
                pad_idx,
                eos_idx,
                left_pad=False,
                pad_to_length=fixed_pad_length,
                pad_to_bsz=pad_to_bsz,
            )

    src_tokens = merge("source")
    if samples[0]["target"] is not None:
        is_target_list = isinstance(samples[0]["target"], list)
        target = merge("target", is_target_list)
    else:
        target = src_tokens

    return {
        "id": torch.LongTensor([s["id"] for s in samples]),
        "nsentences": len(samples),
        "ntokens": sum(len(s["source"]) for s in samples),
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": torch.LongTensor([s["source"].numel() for s in samples]),
        },
        "target": target,
    }
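
# A minimal sketch of what ``collate`` returns, assuming illustrative index
# values pad_idx=1 and eos_idx=2 (these are not fixed by fairseq):
#
#     samples = [
#         {"id": 0, "source": torch.LongTensor([4, 5, 2]), "target": None},
#         {"id": 1, "source": torch.LongTensor([7, 2]), "target": None},
#     ]
#     batch = collate(samples, pad_idx=1, eos_idx=2)
#     batch["net_input"]["src_tokens"]
#     # tensor([[4, 5, 2],
#     #         [7, 2, 1]])   # right-padded with pad_idx, since left_pad=False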


class MonolingualDataset(FairseqDataset):
    """
    A wrapper around torch.utils.data.Dataset for monolingual data.

    Args:
        dataset (torch.utils.data.Dataset): dataset to wrap
        sizes (List[int]): sentence lengths
        src_vocab (~fairseq.data.Dictionary): source vocabulary
        tgt_vocab (~fairseq.data.Dictionary, optional): target vocabulary
            (defaults to ``src_vocab``)
        shuffle (bool, optional): shuffle the elements before batching
            (default: False)
    """

    def __init__(
        self,
        dataset,
        sizes,
        src_vocab,
        tgt_vocab=None,
        add_eos_for_other_targets=False,
        shuffle=False,
        targets=None,
        add_bos_token=False,
        fixed_pad_length=None,
        pad_to_bsz=None,
        src_lang_idx=None,
        tgt_lang_idx=None,
    ):
        self.dataset = dataset
        self.sizes = np.array(sizes)
        self.vocab = src_vocab
        self.tgt_vocab = tgt_vocab or src_vocab
        self.add_eos_for_other_targets = add_eos_for_other_targets
        self.shuffle = shuffle
        self.add_bos_token = add_bos_token
        self.fixed_pad_length = fixed_pad_length
        self.pad_to_bsz = pad_to_bsz
        self.src_lang_idx = src_lang_idx
        self.tgt_lang_idx = tgt_lang_idx

        assert targets is None or all(
            t in {"self", "future", "past"} for t in targets
        ), "targets must be None or a subset of {'self', 'future', 'past'}"
        if targets is not None and len(targets) == 0:
            targets = None
        self.targets = targets
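
    # ``targets`` selects what the model should predict. For example, a
    # left-to-right LM would typically pass targets=["future"]; passing
    # several entries (e.g. ["self", "future"]) makes each sample's target a
    # list with one tensor per entry, which ``collate`` pads separately.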

    def __getitem__(self, index):
        if self.targets is not None:
            # *future_target* is the original sentence
            # *source* is shifted right by 1 (maybe left-padded with eos)
            # *past_target* is shifted right by 2 (left-padded as needed)
            #
            # Left-to-right language models should condition on *source* and
            # predict *future_target*; right-to-left models should condition
            # on *source* and predict *past_target*.
            source, future_target, past_target = self.dataset[index]
            source, target = self._make_source_target(
                source, future_target, past_target
            )
        else:
            source = self.dataset[index]
            target = None
        source, target = self._maybe_add_bos(source, target)
        return {"id": index, "source": source, "target": target}

    def __len__(self):
        return len(self.dataset)

    def _make_source_target(self, source, future_target, past_target):
        if self.targets is not None:
            target = []

            if (
                self.add_eos_for_other_targets
                and (("self" in self.targets) or ("past" in self.targets))
                and source[-1] != self.vocab.eos()
            ):
                # append eos at the end of source
                source = torch.cat([source, source.new([self.vocab.eos()])])

                if "future" in self.targets:
                    # keep *future_target* the same length as the extended source
                    future_target = torch.cat(
                        [future_target, future_target.new([self.vocab.pad()])]
                    )
                if "past" in self.targets:
                    # first token is before the start of sentence, which is only
                    # used in "none" break mode when add_eos_for_other_targets
                    # is False
                    past_target = torch.cat(
                        [
                            past_target.new([self.vocab.pad()]),
                            past_target[1:],
                            source[-2, None],
                        ]
                    )

            for t in self.targets:
                if t == "self":
                    target.append(source)
                elif t == "future":
                    target.append(future_target)
                elif t == "past":
                    target.append(past_target)
                else:
                    raise Exception("invalid target " + t)

            if len(target) == 1:
                target = target[0]
        else:
            target = future_target

        return source, self._filter_vocab(target)

    def _maybe_add_bos(self, source, target):
        if self.add_bos_token:
            source = torch.cat([source.new([self.vocab.bos()]), source])
            if target is not None:
                target = torch.cat([target.new([self.tgt_vocab.bos()]), target])
        return source, target

    def num_tokens_vec(self, indices):
        """Return the number of tokens for a set of positions defined by indices.
        This value is used to enforce ``--max-tokens`` during batching."""
        return self.sizes[indices]

    def _filter_vocab(self, target):
        if len(self.tgt_vocab) != len(self.vocab):
            # map any target token that falls outside the target vocabulary
            # to <unk>
            def _filter(target):
                mask = target.ge(len(self.tgt_vocab))
                if mask.any():
                    target[mask] = self.tgt_vocab.unk()
                return target

            if isinstance(target, list):
                return [_filter(t) for t in target]
            return _filter(target)
        return target
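
    # A sketch of ``_filter_vocab``'s effect, assuming len(self.tgt_vocab) == 10
    # (an illustrative size, not taken from the source):
    #
    #     t = torch.LongTensor([3, 12, 7])
    #     # index 12 is out of range, so it is replaced in place by
    #     # self.tgt_vocab.unk() -> tensor([3, unk, 7])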

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch with the following keys:

                - `id` (LongTensor): example IDs in the original input order
                - `nsentences` (int): number of sentences in the batch
                - `ntokens` (int): total number of tokens in the batch
                - `net_input` (dict): the input to the Model, containing keys:

                  - `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
                    the source sentence of shape `(bsz, src_len)`. Padding will
                    appear on the right.

                - `target` (LongTensor): a padded 2D Tensor of tokens in the
                  target sentence of shape `(bsz, tgt_len)`. Padding will appear
                  on the right.
        """
        return collate(
            samples,
            self.vocab.pad(),
            self.vocab.eos(),
            self.fixed_pad_length,
            self.pad_to_bsz,
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        order.append(self.sizes)
        return np.lexsort(order)
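
    # ``np.lexsort`` treats the *last* key as primary, so the result is ordered
    # by sentence length, with the shuffled (or sequential) order breaking
    # ties. A toy illustration with assumed sizes [3, 1, 3, 2]:
    #
    #     order = [np.arange(4), np.array([3, 1, 3, 2])]
    #     np.lexsort(order)  # -> array([1, 3, 0, 2])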

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)
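
# A minimal usage sketch (illustrative, not part of fairseq's API surface).
# The list-of-tensors stand-in below plays the role of the token dataset that
# fairseq's language-modeling task normally provides:
#
#     from fairseq.data import Dictionary
#
#     vocab = Dictionary()
#     toks = torch.LongTensor(
#         [vocab.add_symbol("hello"), vocab.add_symbol("world"), vocab.eos()]
#     )
#     dataset = MonolingualDataset([toks], sizes=[len(toks)], src_vocab=vocab)
#     batch = dataset.collater([dataset[0]])
#     batch["net_input"]["src_tokens"]  # LongTensor of shape (1, 3)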