TomatoCocotree
上传
6a62ffb
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from torch.utils.data.dataloader import default_collate
from . import FairseqDataset
class BaseWrapperDataset(FairseqDataset):
def __init__(self, dataset):
super().__init__()
self.dataset = dataset
def __getitem__(self, index):
return self.dataset[index]
def __len__(self):
return len(self.dataset)
def collater(self, samples):
if hasattr(self.dataset, "collater"):
return self.dataset.collater(samples)
else:
return default_collate(samples)
@property
def sizes(self):
return self.dataset.sizes
def num_tokens(self, index):
return self.dataset.num_tokens(index)
def size(self, index):
return self.dataset.size(index)
def ordered_indices(self):
return self.dataset.ordered_indices()
@property
def supports_prefetch(self):
return getattr(self.dataset, "supports_prefetch", False)
def attr(self, attr: str, index: int):
return self.dataset.attr(attr, index)
def prefetch(self, indices):
self.dataset.prefetch(indices)
def get_batch_shapes(self):
return self.dataset.get_batch_shapes()
def batch_by_size(
self,
indices,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
):
return self.dataset.batch_by_size(
indices,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
)
def filter_indices_by_size(self, indices, max_sizes):
return self.dataset.filter_indices_by_size(indices, max_sizes)
@property
def can_reuse_epoch_itr_across_epochs(self):
return self.dataset.can_reuse_epoch_itr_across_epochs
def set_epoch(self, epoch):
super().set_epoch(epoch)
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)