# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch.utils.data
from fairseq.data import data_utils

logger = logging.getLogger(__name__)


class EpochListening:
    """Mixin for objects that want to be told when the epoch advances."""

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        """
        Tell callers whether one :class:`fairseq.data.EpochBatchIterator`
        may be kept and reused for this dataset from epoch to epoch.

        Return ``False`` when sample sizes can differ between epochs, since
        batches may then have to be rebuilt every epoch. Datasets that rely
        on ``set_epoch`` should consider returning ``False``.
        """
        return True

    def set_epoch(self, epoch):
        """Hook that receives the new epoch number when an epoch begins.

        The default implementation does nothing.
        """


class FairseqDataset(torch.utils.data.Dataset, EpochListening):
    """Base dataset exposing the hooks and helpers used for batching."""

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def collater(self, samples):
        """Combine a list of samples into a single mini-batch.

        Args:
            samples (List[dict]): the samples to merge

        Returns:
            dict: a mini-batch that can be forwarded through a Model
        """
        raise NotImplementedError

    def num_tokens(self, index):
        """Return how many tokens the sample at *index* contains.

        Batching uses this value to enforce ``--max-tokens``."""
        raise NotImplementedError

    def num_tokens_vec(self, indices):
        """Return the token counts for all positions listed in *indices*.

        Batching uses these values to enforce ``--max-tokens``."""
        raise NotImplementedError

    def size(self, index):
        """Return the size of an example as a float or a tuple.

        Used when filtering a dataset with ``--max-positions``."""
        raise NotImplementedError

    def ordered_indices(self):
        """Return the sample indices in the order batches should be built.

        The default ordering is simply ``0 .. len(self) - 1``."""
        total = len(self)
        return np.arange(total, dtype=np.int64)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def attr(self, attr: str, index: int):
        """Look up attribute *attr* on this dataset, or ``None`` if absent.

        *index* is accepted but not used by this base implementation.
        """
        return getattr(self, attr, None)

    def prefetch(self, indices):
        """Prefetch whatever data this epoch will need."""
        raise NotImplementedError

    def get_batch_shapes(self):
        """
        Return the list of permitted batch shapes, e.g.::

            [(8, 512), (16, 256), (32, 128)]

        Each tuple is ``(batch size, max length)``. The batch size may be
        ``None``, in which case the largest batch size allowed by
        ``--max-tokens`` is inferred. The max length is expressed in the
        units returned by :func:`fairseq.data.FairseqDataset.num_tokens`.

        :func:`fairseq.data.FairseqDataset.batch_by_size` uses the result to
        restrict batch shapes, which helps on TPUs by avoiding too many
        dynamic shapes (and recompilations). The default of ``None`` imposes
        no restriction.
        """
        return None

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        """
        Split an ordered set of indices into batches that respect
        *max_tokens*, *max_sentences* and *required_batch_size_multiple*.

        The actual batching is delegated to ``data_utils.batch_by_size``;
        this method only prepares its arguments.
        """
        # Function-scope import kept from the original implementation; it
        # rebinds the same module that is imported at the top of the file.
        from fairseq.data import data_utils

        def resolve_batch_size(batch_size, length):
            # A batch size of None means "as many as --max-tokens permits".
            if batch_size is None:
                assert max_tokens is not None, "Must specify --max-tokens"
                batch_size = max_tokens // length
            if max_sentences is not None:
                # NOTE(review): when max_sentences is given, the result is
                # not rounded to required_batch_size_multiple -- confirm
                # that this is intended.
                return min(batch_size, max_sentences)
            needs_rounding = (
                batch_size >= required_batch_size_multiple
                and batch_size % required_batch_size_multiple != 0
            )
            if needs_rounding:
                batch_size -= batch_size % required_batch_size_multiple
            return batch_size

        shapes = self.get_batch_shapes()
        if shapes is not None:
            shapes = np.array(
                [
                    [resolve_batch_size(batch_size, length), length]
                    for (batch_size, length) in shapes
                ]
            )

        # Use the vectorised token counts when the subclass provides them;
        # otherwise the callee falls back to the per-sample num_tokens_fn.
        try:
            token_counts = self.num_tokens_vec(indices).astype("int64")
        except NotImplementedError:
            token_counts = None

        return data_utils.batch_by_size(
            indices,
            num_tokens_fn=self.num_tokens,
            num_tokens_vec=token_counts,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
            fixed_shapes=shapes,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        """
        Drop the sample indices whose size exceeds *max_sizes*.

        WARNING: don't update, override method in child classes

        Args:
            indices (np.array): original array of sample indices
            max_sizes (int or list[int] or tuple[int]): max sample size,
                can be defined separately for src and tgt (then list or tuple)

        Returns:
            np.array: filtered sample array
            list: list of removed indices
        """
        # Fast paths: a scalar limit combined with a precomputed ``sizes``
        # attribute allows filtering with a single vectorised comparison.
        if isinstance(max_sizes, (float, int)) and hasattr(self, "sizes"):
            if isinstance(self.sizes, np.ndarray):
                sample_sizes = self.sizes[indices]
                ignored = indices[sample_sizes > max_sizes].tolist()
                kept = indices[sample_sizes <= max_sizes]
                return kept, ignored
            if isinstance(self.sizes, list) and len(self.sizes) == 1:
                sample_sizes = self.sizes[0][indices]
                ignored = indices[sample_sizes > max_sizes].tolist()
                kept = indices[sample_sizes <= max_sizes]
                return kept, ignored

        # General path: defer to the helper, which is handed ``self.size``.
        kept, ignored = data_utils._filter_by_size_dynamic(
            indices, self.size, max_sizes
        )
        return kept, ignored

    @property
    def supports_fetch_outside_dataloader(self):
        """Whether this dataset supports fetching outside the workers of the dataloader."""
        return True


class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening):
    """
    Base class for datasets that have to be read sequentially, usually
    because the data is being streamed or otherwise can't be manipulated
    on a single machine.
    """

    def __iter__(self):
        raise NotImplementedError