import os
import itertools
import random
import uuid
from ditk import logging
import hickle
from typing import Any, Iterable, List, Optional, Tuple, Union
from collections import Counter, defaultdict, deque, OrderedDict
from ding.data.buffer import Buffer, apply_middleware, BufferedData
from ding.utils import fastcopy
from ding.torch_utils import get_null_data
class BufferIndex():
"""
Overview:
Save index string and offset in key value pair.
"""
def __init__(self, maxlen: int, *args, **kwargs):
self.maxlen = maxlen
self.__map = OrderedDict(*args, **kwargs)
self._last_key = next(reversed(self.__map)) if len(self) > 0 else None
self._cumlen = len(self.__map)
def get(self, key: str) -> int:
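        # ``__map`` stores a monotonically increasing logical position per key; convert it
        # to the current physical offset in the deque by subtracting the number of
        # evicted records once the total number of appends (``_cumlen``) exceeds ``maxlen``.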
value = self.__map[key]
value = value % self._cumlen + min(0, (self.maxlen - self._cumlen))
return value
def __len__(self) -> int:
return len(self.__map)
def has(self, key: str) -> bool:
return key in self.__map
def append(self, key: str):
self.__map[key] = self.__map[self._last_key] + 1 if self._last_key else 0
self._last_key = key
self._cumlen += 1
if len(self) > self.maxlen:
self.__map.popitem(last=False)
def clear(self):
self.__map = OrderedDict()
self._last_key = None
self._cumlen = 0
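# An illustrative walk-through of BufferIndex (the keys are assumptions): with
# BufferIndex(maxlen=2), appending "a", "b", "c" evicts "a", after which
# get("b") == 0 and get("c") == 1, matching their offsets in a deque(maxlen=2).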
class DequeBuffer(Buffer):
"""
Overview:
A buffer implementation based on the deque structure.
"""
def __init__(self, size: int, sliced: bool = False) -> None:
"""
Overview:
The initialization method of DequeBuffer.
Arguments:
- size (:obj:`int`): The maximum number of objects that the buffer can hold.
            - sliced (:obj:`bool`): Whether to slice the data into pieces of length ``unroll_len``\
                when sampling by group. Default to ``False``.
"""
super().__init__(size=size)
self.storage = deque(maxlen=size)
self.indices = BufferIndex(maxlen=size)
self.sliced = sliced
        # Meta index is a dict whose values are deques aligned with the storage
self.meta_index = {}
@apply_middleware("push")
def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
"""
Overview:
            The method that inputs the object and the related meta information into the buffer.
Arguments:
- data (:obj:`Any`): The input object which can be in any format.
- meta (:obj:`Optional[dict]`): A dict that helps describe data, such as\
category, label, priority, etc. Default to ``None``.
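        Examples:
            >>> # A minimal usage sketch; the payload and the "priority" meta key are illustrative.
            >>> buffer = DequeBuffer(size=10)
            >>> buffered = buffer.push({"obs": 0}, meta={"priority": 1.0})
            >>> buffer.count()
            1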
"""
return self._push(data, meta)
@apply_middleware("sample")
def sample(
self,
size: Optional[int] = None,
indices: Optional[List[str]] = None,
replace: bool = False,
sample_range: Optional[slice] = None,
ignore_insufficient: bool = False,
groupby: Optional[str] = None,
unroll_len: Optional[int] = None
) -> Union[List[BufferedData], List[List[BufferedData]]]:
"""
Overview:
            The method that randomly samples data from the buffer or retrieves certain data by indices.
Arguments:
- size (:obj:`Optional[int]`): The number of objects to be obtained from the buffer.
If ``indices`` is not specified, the ``size`` is required to randomly sample the\
corresponding number of objects from the buffer.
- indices (:obj:`Optional[List[str]]`): Only used when you want to retrieve data by indices.
Default to ``None``.
- replace (:obj:`bool`): As the sampling process is carried out one by one, this parameter\
determines whether the previous samples will be put back into the buffer for subsequent\
sampling. Default to ``False``, it means that duplicate samples will not appear in one\
``sample`` call.
- sample_range (:obj:`Optional[slice]`): The indices range to sample data. Default to ``None``,\
it means no restrictions on the range of indices for the sampling process.
            - ignore_insufficient (:obj:`bool`): Whether to ignore the case that the sampled size is smaller\
                than the required size, logging a warning instead of raising ``ValueError``. Default to ``False``.
            - groupby (:obj:`Optional[str]`): If this parameter is activated, the method will return ``size``\
                groups of objects, where records are grouped by this meta key.
- unroll_len (:obj:`Optional[int]`): The unroll length of a trajectory, used only when the\
``groupby`` is activated.
Returns:
- sampled_data (Union[List[BufferedData], List[List[BufferedData]]]): The sampling result.
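        Examples:
            >>> # An illustrative sketch; the pushed values and the "episode" meta key are assumptions.
            >>> buffer = DequeBuffer(size=10)
            >>> for i in range(5):
            ...     _ = buffer.push(i, meta={"episode": i % 2})
            >>> len(buffer.sample(3))
            3
            >>> len(buffer.sample(2, groupby="episode"))  # two lists of records grouped by "episode"
            2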
"""
storage = self.storage
if sample_range:
storage = list(itertools.islice(self.storage, sample_range.start, sample_range.stop, sample_range.step))
# Size and indices
assert size or indices, "One of size and indices must not be empty."
if (size and indices) and (size != len(indices)):
raise AssertionError("Size and indices length must be equal.")
if not size:
size = len(indices)
# Indices and groupby
        assert not (indices and groupby), "Cannot use groupby and indices at the same time."
# Groupby and unroll_len
assert not unroll_len or (
unroll_len and groupby
), "Parameter unroll_len needs to be used in conjunction with groupby."
value_error = None
sampled_data = []
if indices:
indices_set = set(indices)
hashed_data = filter(lambda item: item.index in indices_set, storage)
hashed_data = map(lambda item: (item.index, item), hashed_data)
hashed_data = dict(hashed_data)
# Re-sample and return in indices order
sampled_data = [hashed_data[index] for index in indices]
elif groupby:
sampled_data = self._sample_by_group(
size=size, groupby=groupby, replace=replace, unroll_len=unroll_len, storage=storage, sliced=self.sliced
)
else:
if replace:
sampled_data = random.choices(storage, k=size)
else:
try:
sampled_data = random.sample(storage, k=size)
except ValueError as e:
value_error = e
if value_error or len(sampled_data) != size:
if ignore_insufficient:
logging.warning(
"Sample operation is ignored due to data insufficient, current buffer is {} while sample is {}".
format(self.count(), size)
)
else:
raise ValueError("There are less than {} records/groups in buffer({})".format(size, self.count()))
sampled_data = self._independence(sampled_data)
return sampled_data
@apply_middleware("update")
def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool:
"""
Overview:
            The method that updates the data and the related meta information at a certain index.
        Arguments:
            - index (:obj:`str`): The index of the record to be updated.
            - data (:obj:`Optional[Any]`): The data which is supposed to replace the old one. If you set it\
                to ``None``, nothing will happen to the old record.
            - meta (:obj:`Optional[dict]`): The new meta dict that replaces the old one. Keys tracked in\
                ``meta_index`` are refreshed accordingly.
        Returns:
            - success (:obj:`bool`): Whether the index exists in the buffer and the update is applied.
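        Examples:
            >>> # Illustrative: raise the priority of a previously pushed record.
            >>> buffer = DequeBuffer(size=10)
            >>> buffered = buffer.push({"obs": 0}, meta={"priority": 1.0})
            >>> buffer.update(buffered.index, meta={"priority": 2.0})
            True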
"""
if not self.indices.has(index):
return False
i = self.indices.get(index)
item = self.storage[i]
if data is not None:
item.data = data
if meta is not None:
item.meta = meta
for key in self.meta_index:
self.meta_index[key][i] = meta[key] if key in meta else None
return True
@apply_middleware("delete")
def delete(self, indices: Union[str, Iterable[str]]) -> None:
"""
Overview:
            The method that deletes the data and the related meta information at the given indices.
        Arguments:
            - indices (:obj:`Union[str, Iterable[str]]`): The index or indices of the data to be deleted\
                from the buffer.
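        Examples:
            >>> # Illustrative: remove a record by its uuid-based index.
            >>> buffer = DequeBuffer(size=10)
            >>> buffered = buffer.push({"obs": 0})
            >>> buffer.delete(buffered.index)
            >>> buffer.count()
            0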
"""
if isinstance(indices, str):
indices = [indices]
del_idx = []
for index in indices:
if self.indices.has(index):
del_idx.append(self.indices.get(index))
if len(del_idx) == 0:
return
del_idx = sorted(del_idx, reverse=True)
for idx in del_idx:
del self.storage[idx]
remain_indices = [item.index for item in self.storage]
        # Rebuild the index map from the records remaining in storage
        key_value_pairs = zip(remain_indices, range(len(remain_indices)))
self.indices = BufferIndex(self.storage.maxlen, key_value_pairs)
    def save_data(self, file_name: str):
        """
        Overview:
            Serialize the buffer storage, indices, and meta index into the given file via ``hickle``.
        """
        # Create the parent folder of the target file if it does not exist yet
        dir_name = os.path.dirname(file_name)
        if dir_name != "" and not os.path.exists(dir_name):
            os.makedirs(dir_name)
hickle.dump(
py_obj=(
self.storage,
self.indices,
self.meta_index,
), file_obj=file_name
)
    def load_data(self, file_name: str):
        """
        Overview:
            Restore the buffer storage, indices, and meta index from the given file via ``hickle``.
        """
        self.storage, self.indices, self.meta_index = hickle.load(file_name)
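    # An illustrative round-trip sketch (the path and the ".hkl" suffix are assumptions;
    # hickle serializes to HDF5, for which ".h5"/".hkl" suffixes are conventional):
    #     buffer.save_data("./data/buffer.hkl")
    #     restored = DequeBuffer(size=buffer.storage.maxlen)
    #     restored.load_data("./data/buffer.hkl")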
def count(self) -> int:
"""
Overview:
The method that returns the current length of the buffer.
"""
return len(self.storage)
def get(self, idx: int) -> BufferedData:
"""
Overview:
The method that returns the BufferedData object given a specific index.
"""
return self.storage[idx]
@apply_middleware("clear")
def clear(self) -> None:
"""
Overview:
            The method that clears all data, indices, and meta information in the buffer.
"""
self.storage.clear()
self.indices.clear()
self.meta_index = {}
def _push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
index = uuid.uuid1().hex
if meta is None:
meta = {}
buffered = BufferedData(data=data, index=index, meta=meta)
self.storage.append(buffered)
self.indices.append(index)
# Add meta index
for key in self.meta_index:
self.meta_index[key].append(meta[key] if key in meta else None)
return buffered
def _independence(
self, buffered_samples: Union[List[BufferedData], List[List[BufferedData]]]
) -> Union[List[BufferedData], List[List[BufferedData]]]:
"""
Overview:
            Make sure that each record in the sampled result is a distinct object. Note that this
            function is different from clone_object: a record that occurs only once is not copied,
            so modifying it may still change the underlying data in the buffer.
Arguments:
- buffered_samples (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`) Sampled data,
can be nested if groupby has been set.
"""
if len(buffered_samples) == 0:
return buffered_samples
occurred = defaultdict(int)
for i, buffered in enumerate(buffered_samples):
if isinstance(buffered, list):
sampled_list = buffered
# Loop over nested samples
for j, buffered in enumerate(sampled_list):
occurred[buffered.index] += 1
if occurred[buffered.index] > 1:
sampled_list[j] = fastcopy.copy(buffered)
elif isinstance(buffered, BufferedData):
occurred[buffered.index] += 1
if occurred[buffered.index] > 1:
buffered_samples[i] = fastcopy.copy(buffered)
else:
raise Exception("Get unexpected buffered type {}".format(type(buffered)))
return buffered_samples
def _sample_by_group(
self,
size: int,
groupby: str,
replace: bool = False,
unroll_len: Optional[int] = None,
storage: deque = None,
sliced: bool = False
) -> List[List[BufferedData]]:
"""
Overview:
            Sample by ``group`` instead of by individual records. The result is a collection of
            ``size`` lists, and the lengths of these lists may differ from each other.
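        Examples:
            >>> # An illustrative sketch; the "env" meta key and the sizes are assumptions. With
            >>> # unroll_len=3, each returned group holds 3 adjacent records from one trajectory.
            >>> groups = buffer._sample_by_group(size=2, groupby="env", unroll_len=3)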
"""
if storage is None:
storage = self.storage
if groupby not in self.meta_index:
self._create_index(groupby)
def filter_by_unroll_len():
"Filter groups by unroll len, ensure count of items in each group is greater than unroll_len."
group_count = Counter(self.meta_index[groupby])
group_names = []
for key, count in group_count.items():
if count >= unroll_len:
group_names.append(key)
return group_names
if unroll_len and unroll_len > 1:
group_names = filter_by_unroll_len()
if len(group_names) == 0:
return []
else:
group_names = list(set(self.meta_index[groupby]))
sampled_groups = []
if replace:
sampled_groups = random.choices(group_names, k=size)
else:
try:
sampled_groups = random.sample(group_names, k=size)
except ValueError:
raise ValueError("There are less than {} groups in buffer({} groups)".format(size, len(group_names)))
# Build dict like {"group name": [records]}
sampled_data = defaultdict(list)
for buffered in storage:
meta_value = buffered.meta[groupby] if groupby in buffered.meta else None
if meta_value in sampled_groups:
sampled_data[buffered.meta[groupby]].append(buffered)
final_sampled_data = []
for group in sampled_groups:
seq_data = sampled_data[group]
# Filter records by unroll_len
if unroll_len:
                    # Slice by unroll_len. If we don't do this, we are more likely to obtain
                    # duplicate data, and the training will easily crash.
if sliced:
start_indice = random.choice(range(max(1, len(seq_data))))
start_indice = start_indice // unroll_len
if start_indice == (len(seq_data) - 1) // unroll_len:
seq_data = seq_data[-unroll_len:]
else:
seq_data = seq_data[start_indice * unroll_len:start_indice * unroll_len + unroll_len]
else:
start_indice = random.choice(range(max(1, len(seq_data) - unroll_len)))
seq_data = seq_data[start_indice:start_indice + unroll_len]
final_sampled_data.append(seq_data)
return final_sampled_data
def _create_index(self, meta_key: str):
self.meta_index[meta_key] = deque(maxlen=self.storage.maxlen)
for data in self.storage:
self.meta_index[meta_key].append(data.meta[meta_key] if meta_key in data.meta else None)
def __iter__(self) -> deque:
return iter(self.storage)
def __copy__(self) -> "DequeBuffer":
buffer = type(self)(size=self.storage.maxlen)
buffer.storage = self.storage
buffer.meta_index = self.meta_index
buffer.indices = self.indices
return buffer
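

if __name__ == "__main__":
    # A minimal self-check sketch added for illustration; the payloads and the
    # "episode" meta key are assumptions, not part of the original module.
    demo = DequeBuffer(size=4)
    for i in range(6):  # pushing past ``size`` evicts the two oldest records
        demo.push({"obs": i}, meta={"episode": i % 2})
    assert demo.count() == 4
    assert len(demo.sample(2)) == 2
    assert len(demo.sample(2, groupby="episode")) == 2
    demo.clear()
    assert demo.count() == 0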