conex / espnet2 /iterators /multiple_iter_factory.py
tobiasc's picture
Initial commit
ad16788
import logging
from typing import Callable
from typing import Collection
from typing import Iterator
import numpy as np
from typeguard import check_argument_types
from espnet2.iterators.abs_iter_factory import AbsIterFactory
class MultipleIterFactory(AbsIterFactory):
    """Iterator factory that chains the iterators of several sub-factories.

    Each entry of ``build_funcs`` is a zero-argument callable that returns an
    ``AbsIterFactory``. The sub-factories are not created up front: each
    callable is invoked only when iteration reaches it, and its iterator is
    exhausted before the next callable is invoked.

    Args:
        build_funcs: Callables that each create one ``AbsIterFactory``.
            They are copied into a list at construction time.
        seed: Offset added to the epoch number to seed the order shuffling.
        shuffle: Default shuffle setting, used when ``build_iter`` is called
            without an explicit ``shuffle`` value.
    """

    def __init__(
        self,
        build_funcs: Collection[Callable[[], AbsIterFactory]],
        seed: int = 0,
        shuffle: bool = False,
    ):
        assert check_argument_types()
        self.shuffle = shuffle
        self.seed = seed
        self.build_funcs = list(build_funcs)

    def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator:
        """Yield every item of every sub-factory's iterator for ``epoch``.

        Args:
            epoch: Epoch number. It is forwarded to each sub-factory and,
                added to ``self.seed``, seeds the shuffling of the build order.
            shuffle: Whether to shuffle. ``None`` means "use ``self.shuffle``".
                The resolved value both decides whether the order of the
                builders is shuffled and is forwarded to every sub-factory's
                ``build_iter``.

        Yields:
            The items produced by each sub-factory's ``build_iter``, one
            sub-factory after another.
        """
        do_shuffle = self.shuffle if shuffle is None else shuffle

        # Work on a copy so that shuffling never reorders self.build_funcs.
        builders = list(self.build_funcs)
        if do_shuffle:
            # Seeding with epoch + seed makes the order reproducible per epoch.
            rng = np.random.RandomState(epoch + self.seed)
            rng.shuffle(builders)

        for index, make_factory in enumerate(builders):
            logging.info(f"Building {index}th iter-factory...")
            # Built lazily: only one sub-factory is created at a time.
            sub_factory = make_factory()
            assert isinstance(sub_factory, AbsIterFactory), type(sub_factory)
            yield from sub_factory.build_iter(epoch, do_shuffle)