|
import logging |
|
from typing import Callable |
|
from typing import Collection |
|
from typing import Iterator |
|
|
|
import numpy as np |
|
from typeguard import check_argument_types |
|
|
|
from espnet2.iterators.abs_iter_factory import AbsIterFactory |
|
|
|
|
|
class MultipleIterFactory(AbsIterFactory):
    """Iter-factory that chains the iterators of several other iter-factories.

    Each entry of ``build_funcs`` is a zero-argument callable that returns an
    ``AbsIterFactory``.  The callables are not invoked up front: each one is
    called only when iteration reaches it, and everything its factory yields
    is passed through before the next callable is invoked.
    """

    def __init__(
        self,
        build_funcs: Collection[Callable[[], AbsIterFactory]],
        seed: int = 0,
        shuffle: bool = False,
    ):
        """Store the factory builders and the default shuffling settings.

        Args:
            build_funcs: Callables that each build one ``AbsIterFactory``.
                They are copied into a new list.
            seed: Offset added to the epoch number to seed the shuffle of
                the builder order.
            shuffle: Default shuffle flag used when ``build_iter`` is called
                without an explicit one.
        """
        assert check_argument_types()
        self.shuffle = shuffle
        self.seed = seed
        self.build_funcs = list(build_funcs)

    def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator:
        """Yield the items of every sub-factory's iterator, one after another.

        Args:
            epoch: Epoch number; forwarded to each sub-factory and combined
                with ``self.seed`` to seed the ordering of the builders.
            shuffle: Whether to shuffle. ``None`` falls back to the value
                given at construction time.

        Yields:
            Whatever each sub-factory's ``build_iter(epoch, shuffle)`` yields.
        """
        do_shuffle = self.shuffle if shuffle is None else shuffle

        # Work on a copy so that shuffling never reorders self.build_funcs.
        ordered_funcs = [*self.build_funcs]
        if do_shuffle:
            # Seeded from epoch + seed, so the order is reproducible for a
            # given epoch while differing between epochs.
            rng = np.random.RandomState(epoch + self.seed)
            rng.shuffle(ordered_funcs)

        for index, make_factory in enumerate(ordered_funcs):
            logging.info(f"Building {index}th iter-factory...")
            sub_factory = make_factory()
            assert isinstance(sub_factory, AbsIterFactory), type(sub_factory)
            # The resolved shuffle flag is handed down to the sub-factory too.
            yield from sub_factory.build_iter(epoch, do_shuffle)
|
|