import bisect

import numpy as np
from torch.utils.data.dataloader import default_collate

from . import FairseqDataset


class ConcatDataset(FairseqDataset):
    @staticmethod
    def cumsum(sequence, sample_ratios):
        r, s = [], 0
        for e, ratio in zip(sequence, sample_ratios):
            curr_len = int(ratio * len(e))
            r.append(curr_len + s)
            s += curr_len
        return r
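
    # Worked example for cumsum() (illustrative, not from the original source):
    #
    #     >>> ConcatDataset.cumsum([range(3), range(5)], [2, 1])
    #     [6, 11]
    #
    # The first dataset is virtually repeated twice (6 examples), the second
    # kept once (5 more), so the concatenated length is 11.
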
    def __init__(self, datasets, sample_ratios=1):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, "datasets should not be an empty iterable"
        self.datasets = list(datasets)
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(self.datasets)
        self.sample_ratios = sample_ratios
        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
        self.real_sizes = [len(d) for d in self.datasets]
    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx][sample_idx]
    def _get_dataset_and_sample_index(self, idx: int):
        # Find the constituent dataset whose cumulative-size range contains idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        # Wrap around the dataset's real size, since sample ratios > 1 repeat it.
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return dataset_idx, sample_idx
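
    # Worked example for _get_dataset_and_sample_index() (illustrative, not
    # from the original source): with real_sizes == [3, 5],
    # sample_ratios == [2, 1] and therefore cumulative_sizes == [6, 11]:
    #   idx=4 -> dataset 0, sample 4 % 3 == 1  (second pass over dataset 0)
    #   idx=7 -> dataset 1, sample 7 - 6 == 1
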
    def collater(self, samples, **extra_args):
        # Assumes all constituent datasets share the same collater
        # implementation; the first dataset's collater (or default_collate)
        # is used for every batch.
        if hasattr(self.datasets[0], "collater"):
            return self.datasets[0].collater(samples, **extra_args)
        else:
            return default_collate(samples, **extra_args)
    def size(self, idx: int):
        """
        Return an example's size as a float or tuple.
        """
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx].size(sample_idx)

    def num_tokens(self, index: int):
        return np.max(self.size(index))
    def attr(self, attr: str, index: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
        return getattr(self.datasets[dataset_idx], attr, None)
    @property
    def sizes(self):
        _dataset_sizes = []
        for ds, sr in zip(self.datasets, self.sample_ratios):
            if isinstance(ds.sizes, np.ndarray):
                _dataset_sizes.append(np.tile(ds.sizes, sr))
            else:
                # Only underlying datasets with a single size array are supported.
                assert isinstance(ds.sizes, list)
                _dataset_sizes.append(np.tile(ds.sizes[0], sr))
        return np.concatenate(_dataset_sizes)
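
    # Illustrative note for sizes (not in the original source): for 1-D
    # per-example sizes, np.tile repeats the size array once per sample ratio,
    # matching the virtual upsampling done in cumsum(), e.g.
    #   ds.sizes == np.array([5, 3, 4]), sr == 2  ->  [5, 3, 4, 5, 3, 4]
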
    @property
    def supports_prefetch(self):
        return all(d.supports_prefetch for d in self.datasets)
    def ordered_indices(self):
        """
        Returns indices sorted by length so that less padding is needed.
        """
        if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
            # 2-D sizes (e.g. source/target lengths from concatenated
            # language-pair datasets): sort by source length, using target
            # length as a secondary key via a stable mergesort.
            indices = np.arange(len(self))
            sizes = self.sizes
            tgt_sizes = (
                sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
            )
            src_sizes = (
                sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
            )
            if tgt_sizes is not None:
                indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(src_sizes[indices], kind="mergesort")]
        else:
            return np.argsort(self.sizes)
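
    # Worked example for ordered_indices() (illustrative, not from the
    # original source): with
    #   src_sizes == [5, 3, 4] and tgt_sizes == [2, 2, 1]
    # the stable double argsort returns [1, 2, 0], i.e. examples ordered by
    # source length (3, 4, 5) with target length breaking ties.
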
    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cumulative_sizes, self.datasets):
            real_size = len(ds)
            if getattr(ds, "supports_prefetch", False):
                # Map global indices in [frm, to) back to each dataset's own range.
                ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
            frm = to
    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)
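

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It uses plain Python lists as stand-in datasets, which is enough to exercise
# __len__/__getitem__ and the sample_ratios upsampling; real fairseq code
# passes FairseqDataset instances instead. The function is never called here.
# ---------------------------------------------------------------------------
def _concat_dataset_usage_sketch():
    small = ["a", "b", "c"]        # 3 examples
    large = ["w", "x", "y", "z"]   # 4 examples
    # Repeat the small dataset twice per epoch so it is not drowned out.
    concat = ConcatDataset([small, large], sample_ratios=[2, 1])
    assert len(concat) == 2 * len(small) + len(large)  # 6 + 4 == 10
    assert concat[3] == "a"  # second pass over the small dataset wraps around
    assert concat[6] == "w"  # indices 6..9 fall into the large dataset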