from typing import Any
from typing import Optional
from typing import Sequence
from typing import Union

import numpy as np
from torch.utils.data import DataLoader
from typeguard import check_argument_types

from espnet2.iterators.abs_iter_factory import AbsIterFactory
from espnet2.samplers.abs_sampler import AbsSampler


class RawSampler(AbsSampler):
    """Sampler that wraps a pre-built list of batches."""

    def __init__(self, batches):
        self.batches = batches

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        return iter(self.batches)

    def generate(self, seed):
        # The batch list is fixed, so the seed is ignored; a copy is returned
        # so that callers may shuffle it in place without side effects.
        return list(self.batches)


class SequenceIterFactory(AbsIterFactory):
    """Build an iterator for each epoch.

    This class simply creates a pytorch DataLoader except for the following points:
    - The random seed is derived from the epoch number, which guarantees
      reproducibility when resuming from the middle of a training run
      (see the usage sketch at the bottom of this module).
    - The number of iterations per epoch can be restricted, which controls
      the interval between training and evaluation.
    """

    def __init__(
        self,
        dataset,
        batches: Union[AbsSampler, Sequence[Sequence[Any]]],
        num_iters_per_epoch: Optional[int] = None,
        seed: int = 0,
        shuffle: bool = False,
        num_workers: int = 0,
        collate_fn=None,
        pin_memory: bool = False,
    ):
        assert check_argument_types()

        # Wrap a plain list of batches so that downstream code can rely on
        # the AbsSampler interface uniformly.
        if not isinstance(batches, AbsSampler):
            self.sampler = RawSampler(batches)
        else:
            self.sampler = batches

        self.dataset = dataset
        self.num_iters_per_epoch = num_iters_per_epoch
        self.shuffle = shuffle
        self.seed = seed
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.pin_memory = pin_memory
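
    # Worked example of the windowing in build_iter (illustrative numbers,
    # not from the original source): with N = 10 batches and
    # num_iters_per_epoch = 4, epoch 3 gives divmod(4 * 3, 10) = (1, 2),
    # so the window wraps around a corpus boundary: the last 2 batches of
    # the pass seeded with real_epoch - 1 = 0, followed by the first 2
    # batches of the pass seeded with real_epoch = 1.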

    def build_iter(self, epoch: int, shuffle: Optional[bool] = None) -> DataLoader:
        if shuffle is None:
            shuffle = self.shuffle

        if self.num_iters_per_epoch is not None:
            N = len(self.sampler)
            # The corpus holds more batches than one epoch consumes: slice a
            # num_iters_per_epoch-sized window that advances every epoch.
            if self.num_iters_per_epoch < N:
                real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N)

                # The window lies entirely within one pass over the corpus.
                if offset >= self.num_iters_per_epoch:
                    current_batches = self.sampler.generate(real_epoch + self.seed)
                    if shuffle:
                        np.random.RandomState(real_epoch + self.seed).shuffle(
                            current_batches
                        )
                    batches = current_batches[
                        offset - self.num_iters_per_epoch : offset
                    ]
                # The window straddles two passes: take the tail of the
                # previous pass and the head of the current one.
                else:
                    prev_batches = self.sampler.generate(real_epoch - 1 + self.seed)
                    current_batches = self.sampler.generate(real_epoch + self.seed)
                    if shuffle:
                        np.random.RandomState(real_epoch - 1 + self.seed).shuffle(
                            prev_batches
                        )
                        np.random.RandomState(real_epoch + self.seed).shuffle(
                            current_batches
                        )
                    batches = (
                        prev_batches[offset - self.num_iters_per_epoch :]
                        + current_batches[:offset]
                    )

            # The corpus holds fewer batches than one epoch consumes:
            # concatenate successive passes (each with its own seed) until
            # the epoch is filled.
            else:
                _epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N)
                _remain = self.num_iters_per_epoch
                batches = []
                current_batches = self.sampler.generate(_epoch + self.seed)
                if shuffle:
                    np.random.RandomState(_epoch + self.seed).shuffle(current_batches)
                while _remain > 0:
                    _batches = current_batches[_cursor : _cursor + _remain]
                    batches += _batches
                    if _cursor + _remain >= N:
                        # The current pass is exhausted; start the next one.
                        _epoch += 1
                        _cursor = 0
                        current_batches = self.sampler.generate(_epoch + self.seed)
                        if shuffle:
                            np.random.RandomState(_epoch + self.seed).shuffle(
                                current_batches
                            )
                    else:
                        _cursor = _cursor + _remain
                    _remain -= len(_batches)

                assert len(batches) == self.num_iters_per_epoch

        # No restriction on the epoch length: one epoch is one full pass.
        else:
            batches = self.sampler.generate(epoch + self.seed)
            if shuffle:
                np.random.RandomState(epoch + self.seed).shuffle(batches)

        # Pass collate_fn only when it is given, so that DataLoader falls
        # back to its default collation otherwise.
        if self.collate_fn is not None:
            kwargs = dict(collate_fn=self.collate_fn)
        else:
            kwargs = {}

        return DataLoader(
            dataset=self.dataset,
            batch_sampler=batches,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            **kwargs,
        )
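

# ---------------------------------------------------------------------------
# Usage sketch (editorial addition, not part of the original module). It
# shows that rebuilding the loader for the same epoch reproduces the same
# batch order, which is what makes resuming from a checkpoint deterministic.
# The toy dataset and hand-written index batches are hypothetical stand-ins
# for an ESPnet dataset and batch sampler.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    class _ToyDataset(torch.utils.data.Dataset):
        def __init__(self):
            self.data = [torch.tensor([float(i)]) for i in range(8)]

        def __len__(self):
            return len(self.data)

        def __getitem__(self, index):
            return self.data[index]

    # Four pre-built index batches; N = 4 > num_iters_per_epoch = 3, so each
    # epoch takes a sliding 3-batch window over shuffled passes of the corpus.
    toy_batches = [[0, 1], [2, 3], [4, 5], [6, 7]]

    factory = SequenceIterFactory(
        dataset=_ToyDataset(),
        batches=toy_batches,
        num_iters_per_epoch=3,
        seed=0,
        shuffle=True,
    )

    first = [batch.tolist() for batch in factory.build_iter(epoch=1)]
    again = [batch.tolist() for batch in factory.build_iter(epoch=1)]
    assert first == again  # same epoch + same seed -> identical batch order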