# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Union

from torch.utils.data import DataLoader


class BaseLoop(metaclass=ABCMeta):
    """Base class for loops driven by a runner.

    Subclasses inheriting from ``BaseLoop`` must override the :meth:`run`
    method. ``BaseLoop`` itself cannot be instantiated because :meth:`run`
    is abstract.

    Args:
        runner (Runner): A reference to the runner.
        dataloader (DataLoader or dict): Either a ready-made dataloader,
            which is stored unchanged, or a config dict. A dict is turned
            into a dataloader via ``runner.build_dataloader``, using
            ``runner.seed`` and the ``diff_rank_seed`` entry of the runner's
            randomness config (``False`` if absent).
    """

    def __init__(self, runner, dataloader: Union[DataLoader, Dict]) -> None:
        self._runner = runner
        if isinstance(dataloader, dict):
            # Whether different ranks should use different seeds; falls
            # back to ``False`` when the runner's randomness config does
            # not set it.
            diff_rank_seed = runner._randomness_cfg.get(
                'diff_rank_seed', False)
            self.dataloader = runner.build_dataloader(
                dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed)
        else:
            # Non-dict input is assumed to already be a usable dataloader
            # and is stored as-is.
            self.dataloader = dataloader

    @property
    def runner(self):
        """The runner passed to the constructor."""
        return self._runner

    @abstractmethod
    def run(self) -> Any:
        """Execute the loop. Must be implemented by subclasses."""