File size: 1,768 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
from typing import Optional, Callable, List, Any, Iterable
import torch
def example_get_data_fn() -> Any:
    """
    Overview:
        Template for a data-fetching function that reads data from a file or
        other middleware. This placeholder performs no work.
    Returns:
        - data (:obj:`Any`): the loaded data; the placeholder returns ``None``.

    .. note::
        Intended to be a staticmethod or static function; all operations are on CPU.
    """
    # A real implementation is expected to follow these steps:
    # 1. read data from file or other middleware
    # 2. data post-processing (e.g.: normalization, to tensor)
    # 3. return data
    return None
class IDataLoader:
    """
    Overview:
        Base class of data loader. It implements the iterator protocol on top of
        ``_get_data``, which subclasses must override.
    Interfaces:
        ``__next__``, ``__iter__``, ``_get_data``, ``close``

    .. note::
        ``__next__`` reads ``self._batch_size`` and calls ``self._collate_fn``; neither is
        defined in this base class, so a subclass is expected to set both (e.g. in its
        own ``__init__``) before iterating.
    """

    def __next__(self, batch_size: Optional[int] = None) -> torch.Tensor:
        """
        Overview:
            Get one batch of data.
        Arguments:
            - batch_size (:obj:`Optional[int]`): batch size for this iteration; if it is
              ``None``, ``self._batch_size`` is used instead.
        Returns:
            - batch (:obj:`torch.Tensor`): result of ``self._collate_fn`` applied to the
              data returned by ``self._get_data``.
        """
        # Fall back to the loader-wide default when no per-call size is given.
        if batch_size is None:
            batch_size = self._batch_size
        data = self._get_data(batch_size)
        return self._collate_fn(data)

    def __iter__(self) -> Iterable:
        """
        Overview:
            Get the data iterator; the loader is its own iterator.
        Returns:
            - iterator (:obj:`Iterable`): this loader instance.
        """
        return self

    def _get_data(self, batch_size: Optional[int] = None) -> List[torch.Tensor]:
        """
        Overview:
            Get the raw samples for one batch. Must be overridden by subclasses.
        Arguments:
            - batch_size (:obj:`Optional[int]`): batch size for this iteration; if it is
              ``None``, the default batch size should be used.
        Returns:
            - data (:obj:`List[torch.Tensor]`): the samples that make up the batch.
        Raises:
            - NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def close(self) -> None:
        """
        Overview:
            Close the data loader. The base implementation does nothing; subclasses
            override it to release the resources they hold.
        """
        # release resource
        pass
|