File size: 7,617 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
from abc import abstractmethod, ABC
from typing import Any, List, Optional, Union, Callable
import copy
from dataclasses import dataclass
from functools import wraps
from ding.utils import fastcopy


def apply_middleware(func_name: str):
    """
    Overview:
        Decorator factory that routes a buffer method through the buffer's middleware chain.
        Each middleware in ``buffer._middleware`` is invoked as ``mw(func_name, chain, *args, **kwargs)``;
        after the last middleware, the decorated method itself is called.
    Arguments:
        - func_name (:obj:`str`): Name handed to every middleware so it can tell which method is running.
    Returns:
        - decorator (:obj:`Callable`): A decorator for buffer methods.
    """

    def decorator(base_func: Callable):

        @wraps(base_func)
        def handler(buffer, *args, **kwargs):
            """
            Overview:
                Run the middleware chain for one call. Every middleware receives ``chain``, which runs the rest
                of the chain (and finally ``base_func``) with whatever arguments it is given. A middleware can
                therefore rewrite the arguments for later stages, inspect or replace their return value, or
                skip them entirely by not calling ``chain``.
            """

            def run_chain(remaining, *call_args, **call_kwargs):
                # Guard: nothing left in the chain -> call the real buffer method.
                if not remaining:
                    return base_func(buffer, *call_args, **call_kwargs)

                def next_step(*next_args, **next_kwargs):
                    # The tail is sliced lazily, only when this stage hands control onward.
                    return run_chain(remaining[1:], *next_args, **next_kwargs)

                current = remaining[0]
                return current(func_name, next_step, *call_args, **call_kwargs)

            return run_chain(buffer._middleware, *args, **kwargs)

        return handler

    return decorator


@dataclass
class BufferedData:
    """
    Overview:
        A single item held by a buffer: the payload together with its index and meta information.
        This is the type returned by ``Buffer.push``, ``Buffer.sample`` and ``Buffer.get``.
    """
    # The payload that was pushed into the buffer.
    data: Any
    # Index of this item in the buffer (the key taken by ``Buffer.update`` / ``Buffer.delete``).
    index: str
    # Meta information, e.g. priority, count, staleness.
    meta: dict


# Register a BufferedData dispatcher on fastcopy from this module, to avoid circular references.
def _copy_buffereddata(d: BufferedData) -> BufferedData:
    """
    Overview:
        ``fastcopy`` handler for ``BufferedData``: ``data`` and ``meta`` are copied with ``fastcopy.copy``,
        while ``index`` is reused as-is.
    Arguments:
        - d (:obj:`BufferedData`): The item to copy.
    Returns:
        - copied (:obj:`BufferedData`): A new ``BufferedData`` built from the copies.
    """
    copied_data = fastcopy.copy(d.data)
    copied_meta = fastcopy.copy(d.meta)
    return BufferedData(index=d.index, data=copied_data, meta=copied_meta)


fastcopy.dispatch[BufferedData] = _copy_buffereddata


class Buffer(ABC):
    """
    Overview:
        Buffer is an abstraction of device storage, third-party services or data structures,
        for example, a memory queue, sum-tree, redis, or di-store.
        Concrete subclasses implement the abstract storage operations; this base class holds the
        middleware list consumed by ``apply_middleware`` and the ``size`` given at construction.
    """

    def __init__(self, size: int) -> None:
        """
        Overview:
            Start with an empty middleware chain and record the buffer size.
        Arguments:
            - size (:obj:`int`): Stored as ``self.size``; presumably the capacity, enforcement is up to subclasses.
        """
        # Middleware registered through ``use``; ``apply_middleware`` runs them in registration order.
        self._middleware: List[Callable] = []
        self.size = size

    @abstractmethod
    def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
        """
        Overview:
            Push data and its meta information into the buffer.
        Arguments:
            - data (:obj:`Any`): The data which will be pushed into buffer.
            - meta (:obj:`Optional[dict]`): Meta information, e.g. priority, count, staleness.
        Returns:
            - buffered_data (:obj:`BufferedData`): The pushed data.
        """
        raise NotImplementedError

    @abstractmethod
    def sample(
            self,
            size: Optional[int] = None,
            indices: Optional[List[str]] = None,
            replace: bool = False,
            sample_range: Optional[slice] = None,
            ignore_insufficient: bool = False,
            groupby: Optional[str] = None,
            unroll_len: Optional[int] = None
    ) -> Union[List[BufferedData], List[List[BufferedData]]]:
        """
        Overview:
            Sample data with length ``size``.
        Arguments:
            - size (:obj:`Optional[int]`): The number of the data that will be sampled.
            - indices (:obj:`Optional[List[str]]`): Sample with multiple indices.
            - replace (:obj:`bool`): If replace is true, you may receive duplicated data from the buffer.
            - sample_range (:obj:`Optional[slice]`): Sample range slice.
            - ignore_insufficient (:obj:`bool`): If ignore_insufficient is true, sampling more than buffer size
                with no repetition will not cause an exception.
            - groupby (:obj:`Optional[str]`): Groupby key in meta, i.e. groupby="episode"
            - unroll_len (:obj:`Optional[int]`): Number of consecutive frames within a group.
        Returns:
            - sample_data (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`):
                A list of data with length ``size``, may be nested if groupby is set.
        """
        raise NotImplementedError

    @abstractmethod
    def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool:
        """
        Overview:
            Update data and meta by index.
        Arguments:
            - index (:obj:`str`): Index of data.
            - data (:obj:`Optional[Any]`): Pure data.
            - meta (:obj:`Optional[dict]`): Meta information.
        Returns:
            - success (:obj:`bool`): Success or not; if data with the index does not exist in buffer, return False.
        """
        raise NotImplementedError

    @abstractmethod
    def delete(self, index: str):
        """
        Overview:
            Delete one data sample by index.
        Arguments:
            - index (:obj:`str`): Index of the data to delete.
        """
        raise NotImplementedError

    @abstractmethod
    def save_data(self, file_name: str):
        """
        Overview:
            Save buffer data into a file.
        Arguments:
            - file_name (:obj:`str`): File name of buffer data.
        """
        raise NotImplementedError

    @abstractmethod
    def load_data(self, file_name: str):
        """
        Overview:
            Load buffer data from a file.
        Arguments:
            - file_name (:obj:`str`): File name of buffer data.
        """
        raise NotImplementedError

    @abstractmethod
    def count(self) -> int:
        """
        Overview:
            Number of items currently held by the buffer (also what ``len(buffer)`` returns).
        Returns:
            - count (:obj:`int`): Number of stored items.
        """
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """
        Overview:
            Clear the data held by the buffer. Implemented by concrete subclasses.
        """
        raise NotImplementedError

    @abstractmethod
    def get(self, idx: int) -> BufferedData:
        """
        Overview:
            Get item by subscript index.
        Arguments:
            - idx (:obj:`int`): Subscript index.
        Returns:
            - buffered_data (:obj:`BufferedData`): Item from buffer.
        """
        raise NotImplementedError

    def use(self, func: Callable) -> "Buffer":
        """
        Overview:
            Use algorithm middleware to modify the behavior of the buffer.
            Every middleware is a callable that ``apply_middleware`` invokes as
            ``func(func_name, chain, *args, **kwargs)``, where:
            1. ``func_name`` is the name given to ``apply_middleware`` for the buffer method being called \
               (e.g. ``push``, ``sample`` or ``clear``), so the middleware can decide which action to take.
            2. ``chain`` runs the next middleware, or the original buffer method after the last one. Call it with \
               the arguments the next stage should receive; it returns that stage's result.
            3. ``*args`` / ``**kwargs`` are the arguments of the current call.

            Whatever the middleware returns is passed back to the previous stage (ultimately to the caller).
            A middleware can stop the chain by returning without calling ``chain``.
            Middleware run in registration order: the first one registered is called first (outermost).
        Arguments:
            - func (:obj:`Callable`): The middleware handler.
        Returns:
            - buffer (:obj:`Buffer`): The instance self, so calls can be chained.
        """
        self._middleware.append(func)
        return self

    def view(self) -> "Buffer":
        r"""
        Overview:
            Create a view of this buffer via ``copy.copy(self)``, which delegates to ``__copy__``.
            Subclasses are expected to implement ``__copy__`` so that the view gets its own copy of every
            property except the storage, which is shared among all the buffer instances.
            The base class ``__copy__`` raises ``NotImplementedError``.
        Returns:
            - buffer (:obj:`Buffer`): The new buffer instance returned by ``__copy__``.
        """
        return copy.copy(self)

    def __copy__(self) -> "Buffer":
        """
        Overview:
            Hook used by ``view``; must be implemented by subclasses. The base implementation raises
            ``NotImplementedError``.
        """
        raise NotImplementedError

    def __len__(self) -> int:
        """Return ``self.count()``."""
        return self.count()

    def __getitem__(self, idx: int) -> BufferedData:
        """Return ``self.get(idx)``, enabling ``buffer[idx]``."""
        return self.get(idx)