from collections import OrderedDict

import torch
from torch.utils.data.dataloader import default_collate

from . import FairseqDataset
|
|
def _flatten(dico, prefix=None):
    """Flatten a nested dictionary."""
    new_dico = OrderedDict()
    if isinstance(dico, dict):
        prefix = prefix + "." if prefix is not None else ""
        for k, v in dico.items():
            if v is None:
                # skip keys with no dataset attached
                continue
            new_dico.update(_flatten(v, prefix + k))
    elif isinstance(dico, list):
        for i, v in enumerate(dico):
            # list items are keyed by position, e.g. "prefix.[0]"
            new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]"))
    else:
        # leaf value: map the accumulated dotted key to the value itself
        new_dico = OrderedDict({prefix: dico})
    return new_dico
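# Illustrative sketch of _flatten's output (ds_a / ds_b are hypothetical datasets):
# _flatten({"net_input": {"src_tokens": ds_a, "prev_output_tokens": None}, "target": [ds_b]})
# returns OrderedDict([("net_input.src_tokens", ds_a), ("target.[0]", ds_b)]);
# None values are dropped, nested keys are joined with ".", and list items become ".[i]".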
|
|
def _unflatten(dico):
    """Unflatten a flattened dictionary into a nested dictionary."""
    new_dico = OrderedDict()
    for full_k, v in dico.items():
        full_k = full_k.split(".")
        node = new_dico
        for k in full_k[:-1]:
            # bracketed components come from list indices produced by _flatten
            if k.startswith("[") and k.endswith("]"):
                k = int(k[1:-1])
            if k not in node:
                node[k] = OrderedDict()
            node = node[k]
        node[full_k[-1]] = v
    return new_dico
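# Illustrative sketch (hypothetical values): _unflatten reverses the key joining done by
# _flatten, e.g. {"net_input.src_tokens": t, "net_input.src_lengths": l, "target": y}
# becomes {"net_input": {"src_tokens": t, "src_lengths": l}, "target": y}. Bracketed
# components produced from list indices become dictionary keys rather than Python lists.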
|
|
class NestedDictionaryDataset(FairseqDataset):
    """Combines several datasets, keyed by a (possibly nested) dictionary."""

    def __init__(self, defn, sizes=None):
        super().__init__()
        self.defn = _flatten(defn)
        self.sizes = [sizes] if not isinstance(sizes, (list, tuple)) else sizes

        first = None
        for v in self.defn.values():
            if not isinstance(v, (FairseqDataset, torch.utils.data.Dataset)):
                raise ValueError("Expected Dataset but found: {}".format(v.__class__))
            # all sub-datasets must have the same length
            first = first or v
            if len(v) > 0:
                assert len(v) == len(first), "dataset lengths must match"

        self._len = len(first)
|
    def __getitem__(self, index):
        return OrderedDict((k, ds[index]) for k, ds in self.defn.items())

    def __len__(self):
        return self._len
|
    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        if len(samples) == 0:
            return {}
        sample = OrderedDict()
        for k, ds in self.defn.items():
            try:
                sample[k] = ds.collater([s[k] for s in samples])
            except NotImplementedError:
                # fall back to PyTorch's default collation for datasets without a collater
                sample[k] = default_collate([s[k] for s in samples])
        return _unflatten(sample)
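    # Sketch of the result (hypothetical keys): if the flattened definition contained
    # "net_input.src_tokens" and "target", the collated mini-batch is re-nested as
    # {"net_input": {"src_tokens": <batch>}, "target": <batch>}, where each column was
    # built by the corresponding sub-dataset's collater (or by default_collate).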
|
    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return max(s[index] for s in self.sizes)
|
    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        if len(self.sizes) == 1:
            return self.sizes[0][index]
        else:
            # return a tuple rather than a generator so the sizes can be compared/indexed
            return tuple(s[index] for s in self.sizes)
|
    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return any(ds.supports_prefetch for ds in self.defn.values())
|
    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        for ds in self.defn.values():
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)
|
    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values())
|
    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.defn.values():
            ds.set_epoch(epoch)
|
|
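# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module). It assumes two
# existing FairseqDataset instances of equal length, here called `src_dataset`
# and `tgt_dataset`, plus fairseq.data.IdDataset; substitute your task's datasets.
#
#   dataset = NestedDictionaryDataset(
#       {
#           "id": IdDataset(),
#           "net_input": {"src_tokens": src_dataset},
#           "target": tgt_dataset,
#       },
#       sizes=[src_dataset.sizes],
#   )
#   item = dataset[0]                                 # OrderedDict with flattened keys
#   batch = dataset.collater([dataset[i] for i in range(8)])
#   tokens = batch["net_input"]["src_tokens"]         # collated by src_dataset.collater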