import logging
from collections import OrderedDict
from typing import Dict, Optional, Sequence

import numpy as np

from . import FairseqDataset, LanguagePairDataset

logger = logging.getLogger(__name__)


class RoundRobinZipDatasets(FairseqDataset):
    """Zip multiple :class:`~fairseq.data.FairseqDataset` instances together.

    Shorter datasets are repeated in a round-robin fashion to match the length
    of the longest one.

    Args:
        datasets (Dict[str, ~fairseq.data.FairseqDataset]): a dictionary of
            :class:`~fairseq.data.FairseqDataset` instances.
        eval_key (str, optional): a key used at evaluation time that causes
            this instance to pass-through batches from *datasets[eval_key]*.
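
    Example (a minimal sketch; ``en_fr`` and ``en_de`` stand in for any two
    :class:`~fairseq.data.FairseqDataset` instances)::

        zipped = RoundRobinZipDatasets(
            OrderedDict([("en-fr", en_fr), ("en-de", en_de)])
        )
        zipped.ordered_indices()  # must be called before indexing
        first = zipped[0]  # OrderedDict with one example per sub-dataset key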
    """

    def __init__(self, datasets, eval_key=None):
        super().__init__()
        if isinstance(datasets, dict):
            datasets = OrderedDict(datasets)
        assert isinstance(datasets, OrderedDict)
        assert datasets, "Can't make a RoundRobinZipDatasets out of nothing"
        for dataset in datasets.values():
            assert isinstance(dataset, FairseqDataset)

        self.datasets = datasets
        self.eval_key = eval_key

        self.longest_dataset_key = max(datasets, key=lambda k: len(datasets[k]))
        self.longest_dataset = datasets[self.longest_dataset_key]
        # Populated lazily by ordered_indices(); maps each dataset key to that
        # sub-dataset's ordered indices.
        self._ordered_indices: Optional[Dict[str, Sequence[int]]] = None

    def _map_index(self, key, index):
        assert (
            self._ordered_indices is not None
        ), "Must call RoundRobinZipDatasets.ordered_indices() first"
        o = self._ordered_indices[key]
        return o[index % len(o)]

    def __getitem__(self, index):
        if self.eval_key is None:
            return OrderedDict(
                [
                    (key, dataset[self._map_index(key, index)])
                    for key, dataset in self.datasets.items()
                ]
            )
        else:
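            # at evaluation time, pass through examples from datasets[eval_key] only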
            return self.datasets[self.eval_key][self._map_index(self.eval_key, index)]

    def __len__(self):
        if self._ordered_indices is not None:
            return len(self._ordered_indices[self.longest_dataset_key])
        return len(self.longest_dataset)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch."""
        if len(samples) == 0:
            return None
        if self.eval_key is None:
            return OrderedDict(
                [
                    (key, dataset.collater([sample[key] for sample in samples]))
                    for key, dataset in self.datasets.items()
                ]
            )
        else:
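            # at evaluation time, collate using only the eval_key sub-dataset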
            return self.datasets[self.eval_key].collater(samples)

    def num_tokens(self, index):
        """Return an example's length (number of tokens), used for batching."""
        return max(
            dataset.num_tokens(self._map_index(key, index))
            for key, dataset in self.datasets.items()
        )

    def size(self, index):
        """Return an example's size, keyed by sub-dataset name. This value is
        used when filtering a dataset with ``--max-positions``."""
        return {
            key: dataset.size(self._map_index(key, index))
            for key, dataset in self.datasets.items()
        }

    def ordered_indices(self):
        """Ordered indices for batching."""
        if self._ordered_indices is None:
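            # Cache each sub-dataset's own ordering (for example, a
            # length-sorted or shuffled order) so that _map_index() can
            # translate our flat indices into it.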
            self._ordered_indices = OrderedDict(
                [
                    (key, dataset.ordered_indices())
                    for key, dataset in self.datasets.items()
                ]
            )
        return np.arange(len(self))

    def filter_indices_by_size(self, indices, max_positions=None):
        """
        Filter each sub-dataset independently, then update the round robin to work
        on the filtered sub-datasets.
        """
        def _deep_until_language_pair(dataset):
            if isinstance(dataset, LanguagePairDataset):
                return dataset
            if hasattr(dataset, "tgt_dataset"):
                return _deep_until_language_pair(dataset.tgt_dataset)
            if hasattr(dataset, "dataset"):
                return _deep_until_language_pair(dataset.dataset)
            raise Exception(f"Don't know how to unwrap this dataset: {dataset}")

        if not isinstance(max_positions, dict):
            max_positions = {k: max_positions for k in self.datasets.keys()}
        ignored_some = False
        for key, dataset in self.datasets.items():
            dataset = _deep_until_language_pair(dataset)
            self._ordered_indices[key], ignored = dataset.filter_indices_by_size(
                self._ordered_indices[key], max_positions[key]
            )
            if len(ignored) > 0:
                ignored_some = True
                logger.warning(
                    f"{len(ignored)} samples from {key} have invalid sizes and will be skipped, "
                    f"max_positions={max_positions[key]}, first few sample ids={ignored[:10]}"
                )
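        # _ordered_indices was modified in place, so the precise set of ignored
        # indices can no longer be returned; use a dummy non-empty list to
        # signal that some samples were filtered out.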
        return (np.arange(len(self)), [0] if ignored_some else [])

    @property
    def supports_prefetch(self):
        return all(
            getattr(dataset, "supports_prefetch", False)
            for dataset in self.datasets.values()
        )

    def prefetch(self, indices):
        for key, dataset in self.datasets.items():
            dataset.prefetch([self._map_index(key, index) for index in indices])