|
from typing import Callable, Any, List, Optional, Union, TYPE_CHECKING |
|
from ding.data.buffer import BufferedData |
|
if TYPE_CHECKING: |
|
from ding.data.buffer.buffer import Buffer |
|
|
|
|
|
def sample_range_view(buffer_: 'Buffer', start: Optional[int] = None, end: Optional[int] = None) -> Callable:
    """
    Overview:
        The middleware that places restrictions on the range of indices during sampling.
        Negative ``start`` / ``end`` values are converted once, when the middleware is created, by adding
        ``buffer_.size`` to them (e.g. ``start=-3`` with ``buffer_.size == 10`` becomes ``7``).
    Arguments:
        - buffer_ (:obj:`Buffer`): The buffer this middleware is attached to; only its ``size`` is read.
        - start (:obj:`Optional[int]`): The starting index, or ``None`` for no lower bound.
        - end (:obj:`Optional[int]`): One above the ending index, or ``None`` for no upper bound.
    Returns:
        - middleware (:obj:`Callable`): A function ``(action, chain, *args, **kwargs)`` that, for the
          ``"sample"`` action, forwards the call with ``sample_range=slice(start, end)`` added to the
          caller's keyword arguments; every other action is forwarded unchanged.
    """
    # At least one bound is required, otherwise the middleware would restrict nothing.
    assert start is not None or end is not None
    # Negative bounds are resolved against the size read here, once; later changes to
    # ``buffer_.size`` are not reflected. NOTE(review): confirm ``size`` is the intended
    # reference (capacity vs. current number of stored items).
    if start is not None and start < 0:
        start = buffer_.size + start
    if end is not None and end < 0:
        end = buffer_.size + end
    sample_range = slice(start, end)

    def _sample_range_view(action: str, chain: Callable, *args, **kwargs) -> Any:
        if action == "sample":
            # Keep the caller's other keyword arguments instead of dropping them; the
            # configured range overrides any caller-supplied ``sample_range``.
            kwargs["sample_range"] = sample_range
        return chain(*args, **kwargs)

    return _sample_range_view
|
|