from torch.utils.data.dataloader import default_collate

from . import FairseqDataset


class BaseWrapperDataset(FairseqDataset):
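    """Wrapper that transparently delegates the FairseqDataset API to an
    underlying ``dataset``; subclasses override only the methods whose
    behavior they need to change.
    """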

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        # Prefer the wrapped dataset's own collater; fall back to PyTorch's
        # default_collate for plain tensor/ndarray/scalar samples.
        if hasattr(self.dataset, "collater"):
            return self.dataset.collater(samples)
        else:
            return default_collate(samples)

    @property
    def sizes(self):
        return self.dataset.sizes

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        # Wrapped datasets that don't define the flag are treated as not
        # supporting prefetch.
        return getattr(self.dataset, "supports_prefetch", False)

    def attr(self, attr: str, index: int):
        return self.dataset.attr(attr, index)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)

    def get_batch_shapes(self):
        return self.dataset.get_batch_shapes()

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        return self.dataset.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        return self.dataset.filter_indices_by_size(indices, max_sizes)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return self.dataset.can_reuse_epoch_itr_across_epochs

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        # Propagate the epoch to the wrapped dataset if it tracks epochs.
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
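

# Usage sketch (illustrative assumption, not part of this module): a
# hypothetical subclass that transforms each item on access while
# inheriting the rest of the FairseqDataset API unchanged.
#
#     class TransformedDataset(BaseWrapperDataset):
#         def __init__(self, dataset, transform):
#             super().__init__(dataset)
#             self.transform = transform
#
#         def __getitem__(self, index):
#             return self.transform(self.dataset[index])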