from typing import Optional, Callable, List, Any, Iterable import torch def example_get_data_fn() -> Any: """ Overview: Get data from file or other middleware .. note:: staticmethod or static function, all the operation is on CPU """ # 1. read data from file or other middleware # 2. data post-processing(e.g.: normalization, to tensor) # 3. return data pass class IDataLoader: """ Overview: Base class of data loader Interfaces: ``__init__``, ``__next__``, ``__iter__``, ``_get_data``, ``close`` """ def __next__(self, batch_size: Optional[int] = None) -> torch.Tensor: """ Overview: Get one batch data Arguments: - batch_size (:obj:`Optional[int]`): sometimes, batch_size is specified by each iteration, \ if batch_size is None, use default batch_size value """ # get one batch train data if batch_size is None: batch_size = self._batch_size data = self._get_data(batch_size) return self._collate_fn(data) def __iter__(self) -> Iterable: """ Overview: Get data iterator """ return self def _get_data(self, batch_size: Optional[int] = None) -> List[torch.Tensor]: """ Overview: Get one batch data Arguments: - batch_size (:obj:`Optional[int]`): sometimes, batch_size is specified by each iteration, \ if batch_size is None, use default batch_size value """ raise NotImplementedError def close(self) -> None: """ Overview: Close data loader """ # release resource pass