File size: 979 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
from typing import Callable, Any, List, Optional, Union, TYPE_CHECKING
from ding.data.buffer import BufferedData
if TYPE_CHECKING:
from ding.data.buffer.buffer import Buffer
def sample_range_view(buffer_: 'Buffer', start: Optional[int] = None, end: Optional[int] = None) -> Callable:
    """
    Overview:
        The middleware that places restrictions on the range of indices during sampling.
        Every ``sample`` action passing through it is forwarded with ``sample_range`` set to
        ``slice(start, end)``; every other action is forwarded unchanged.
    Arguments:
        - buffer_ (:obj:`Buffer`): The buffer whose ``size`` attribute is used to resolve negative indices.
        - start (:obj:`int`): The starting index. A negative value is converted to ``buffer_.size + start``.
        - end (:obj:`int`): One above the ending index. A negative value is converted to ``buffer_.size + end``.
    Returns:
        - middleware (:obj:`Callable`): A function with signature ``(action, chain, *args, **kwargs)`` that \
            calls ``chain`` and returns its result.
    Raises:
        - AssertionError: If both ``start`` and ``end`` are ``None``.
    """
    # Kept as an assert (rather than raising another exception type) so callers see the same error as before.
    assert start is not None or end is not None, "at least one of start and end must be given"

    # Negative indices are resolved once, here, against buffer_.size as it is when the middleware is built.
    if start is not None and start < 0:
        start = buffer_.size + start
    if end is not None and end < 0:
        end = buffer_.size + end
    sample_range = slice(start, end)

    def _sample_range_view(action: str, chain: Callable, *args, **kwargs) -> Any:
        if action == "sample":
            # Keep the caller's keyword arguments (they were previously dropped on this path) and
            # enforce this middleware's range, overriding any ``sample_range`` the caller supplied.
            return chain(*args, **{**kwargs, "sample_range": sample_range})
        return chain(*args, **kwargs)

    return _sample_range_view
|