File size: 1,558 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from . import FairseqDataset
class ConcatSentencesDataset(FairseqDataset):
    """Join several equally long datasets item by item.

    Item ``i`` of this dataset is ``torch.cat`` applied to item ``i`` of every
    wrapped dataset, in the order the datasets were passed to the constructor.

    NOTE(review): items are presumably 1-D token tensors, so the join happens
    along the token axis -- confirm against the wrapped datasets.
    """

    def __init__(self, *datasets):
        """Store the wrapped datasets after validating them.

        Args:
            *datasets: one or more datasets, all with the same ``len()``.

        Raises:
            ValueError: if no dataset is given.
            AssertionError: if the datasets differ in length.
        """
        super().__init__()
        # `all()` over an empty sequence is True, so without this guard an
        # empty wrapper would be built and only fail later with an IndexError
        # on `self.datasets[0]`.
        if not datasets:
            raise ValueError("at least one dataset is required")
        self.datasets = datasets
        # Kept as an assert so existing callers still see AssertionError.
        assert all(
            len(ds) == len(datasets[0]) for ds in datasets
        ), "datasets must have the same length"

    def __getitem__(self, index):
        """Return the concatenation of ``ds[index]`` over all wrapped datasets."""
        return torch.cat([ds[index] for ds in self.datasets])

    def __len__(self):
        """Return the common length, read from the first wrapped dataset."""
        return len(self.datasets[0])

    def collater(self, samples):
        """Batch ``samples`` using the first wrapped dataset's collater."""
        return self.datasets[0].collater(samples)

    @property
    def sizes(self):
        """Sum of the ``sizes`` attributes of all wrapped datasets.

        NOTE(review): ``sum`` starts from the int ``0``, so this relies on each
        ``ds.sizes`` supporting ``0 + sizes`` (e.g. a numpy array, which adds
        element-wise) -- confirm for the datasets used.
        """
        return sum(ds.sizes for ds in self.datasets)

    def num_tokens(self, index):
        """Return the sum of ``num_tokens(index)`` over all wrapped datasets."""
        return sum(ds.num_tokens(index) for ds in self.datasets)

    def size(self, index):
        """Return the sum of ``size(index)`` over all wrapped datasets."""
        return sum(ds.size(index) for ds in self.datasets)

    def ordered_indices(self):
        """Return the index ordering produced by the first wrapped dataset."""
        return self.datasets[0].ordered_indices()

    @property
    def supports_prefetch(self):
        """True if at least one wrapped dataset reports ``supports_prefetch``."""
        return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)

    def prefetch(self, indices):
        """Forward ``prefetch(indices)`` to each wrapped dataset that supports it."""
        for ds in self.datasets:
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)

    def set_epoch(self, epoch):
        """Propagate ``epoch`` to the base class and to each wrapped dataset.

        Wrapped datasets without a ``set_epoch`` attribute are skipped.
        """
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)
|