import os
import itertools
import random
import uuid
from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
from collections import Counter
from collections import defaultdict, deque, OrderedDict

from ditk import logging
import hickle

from ding.data.buffer import Buffer, apply_middleware, BufferedData
from ding.utils import fastcopy
from ding.torch_utils import get_null_data


class BufferIndex():
    """
    Overview:
        Save index string and offset in key value pair. Offsets are consecutive counters assigned
        in insertion order; ``get`` translates an offset into the current position of the record
        inside a bounded container that evicts its oldest element once ``maxlen`` is exceeded.
    """

    def __init__(self, maxlen: int, *args, **kwargs):
        """
        Arguments:
            - maxlen (:obj:`int`): The maximum number of keys kept in the index.
            - args, kwargs: Forwarded to ``OrderedDict`` to pre-populate the index. The values \
                are expected to be the consecutive offsets ``0..n-1`` in insertion order.
        """
        self.maxlen = maxlen
        self.__map = OrderedDict(*args, **kwargs)
        self._last_key = next(reversed(self.__map)) if len(self) > 0 else None
        # Total number of keys ever appended (evicted keys are still counted).
        self._cumlen = len(self.__map)

    def get(self, key: str) -> int:
        """
        Overview:
            Return the current position of ``key``. Raises ``KeyError`` if the key is unknown.
        """
        value = self.__map[key]
        # Once more than ``maxlen`` keys have been appended, the oldest ones have been evicted,
        # so every remaining offset is shifted left by the number of evicted keys.
        value = value % self._cumlen + min(0, (self.maxlen - self._cumlen))
        return value

    def __len__(self) -> int:
        return len(self.__map)

    def has(self, key: str) -> bool:
        """
        Overview:
            Whether ``key`` is currently stored in the index.
        """
        return key in self.__map

    def append(self, key: str):
        """
        Overview:
            Register ``key`` with the next offset and evict the oldest key if the index is full.
        """
        # Compare against None explicitly so that a falsy key (e.g. an empty string) used as
        # the previous key does not reset the offset counter to 0.
        self.__map[key] = self.__map[self._last_key] + 1 if self._last_key is not None else 0
        self._last_key = key
        self._cumlen += 1
        if len(self) > self.maxlen:
            self.__map.popitem(last=False)

    def clear(self):
        """
        Overview:
            Remove all keys and reset the offset counter.
        """
        self.__map = OrderedDict()
        self._last_key = None
        self._cumlen = 0


class DequeBuffer(Buffer):
    """
    Overview:
        A buffer implementation based on the deque structure.
    """

    def __init__(self, size: int, sliced: bool = False) -> None:
        """
        Overview:
            The initialization method of DequeBuffer.
        Arguments:
            - size (:obj:`int`): The maximum number of objects that the buffer can hold.
            - sliced (:obj:`bool`): The flag whether slice data by unroll_len when sample by group
        """
        super().__init__(size=size)
        self.storage = deque(maxlen=size)
        self.indices = BufferIndex(maxlen=size)
        self.sliced = sliced
        # Meta index is a dict which uses deque as values
        self.meta_index = {}

    @apply_middleware("push")
    def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
        """
        Overview:
            The method that input the objects and the related meta information into the buffer.
        Arguments:
            - data (:obj:`Any`): The input object which can be in any format.
            - meta (:obj:`Optional[dict]`): A dict that helps describe data, such as \
                category, label, priority, etc. Default to ``None``.
        Returns:
            - buffered_data (:obj:`BufferedData`): The record that was stored in the buffer.
        """
        return self._push(data, meta)

    @apply_middleware("sample")
    def sample(
            self,
            size: Optional[int] = None,
            indices: Optional[List[str]] = None,
            replace: bool = False,
            sample_range: Optional[slice] = None,
            ignore_insufficient: bool = False,
            groupby: Optional[str] = None,
            unroll_len: Optional[int] = None
    ) -> Union[List[BufferedData], List[List[BufferedData]]]:
        """
        Overview:
            The method that randomly sample data from the buffer or retrieve certain data by indices.
        Arguments:
            - size (:obj:`Optional[int]`): The number of objects to be obtained from the buffer. \
                If ``indices`` is not specified, the ``size`` is required to randomly sample the \
                corresponding number of objects from the buffer.
            - indices (:obj:`Optional[List[str]]`): Only used when you want to retrieve data by indices. \
                Default to ``None``.
            - replace (:obj:`bool`): As the sampling process is carried out one by one, this parameter \
                determines whether the previous samples will be put back into the buffer for subsequent \
                sampling. Default to ``False``, it means that duplicate samples will not appear in one \
                ``sample`` call.
            - sample_range (:obj:`Optional[slice]`): The indices range to sample data. Default to ``None``, \
                it means no restrictions on the range of indices for the sampling process. The slice is \
                applied with ``itertools.islice``, which does not accept negative values.
            - ignore_insufficient (:obj:`bool`): If ``False`` (default), a ``ValueError`` is raised when \
                the sampled size is smaller than the required size. If ``True``, a warning is logged \
                and whatever was sampled is returned.
            - groupby (:obj:`Optional[str]`): If this parameter is activated, the method will return a \
                target size of object groups.
            - unroll_len (:obj:`Optional[int]`): The unroll length of a trajectory, used only when the \
                ``groupby`` is activated.
        Returns:
            - sampled_data (Union[List[BufferedData], List[List[BufferedData]]]): The sampling result.
        Raises:
            - AssertionError: If the combination of arguments is invalid.
            - ValueError: If fewer than ``size`` records/groups are available and \
                ``ignore_insufficient`` is ``False``.
        """
        storage = self.storage
        if sample_range:
            storage = list(itertools.islice(self.storage, sample_range.start, sample_range.stop, sample_range.step))

        # Size and indices
        assert size or indices, "One of size and indices must not be empty."
        if (size and indices) and (size != len(indices)):
            raise AssertionError("Size and indices length must be equal.")
        if not size:
            size = len(indices)
        # Indices and groupby
        assert not (indices and groupby), "Cannot use groupby and indicex at the same time."
        # Groupby and unroll_len
        assert not unroll_len or (
            unroll_len and groupby
        ), "Parameter unroll_len needs to be used in conjunction with groupby."

        value_error = None
        sampled_data = []
        if indices:
            indices_set = set(indices)
            hashed_data = {item.index: item for item in storage if item.index in indices_set}
            # Re-sample and return in indices order
            sampled_data = [hashed_data[index] for index in indices]
        elif groupby:
            sampled_data = self._sample_by_group(
                size=size, groupby=groupby, replace=replace, unroll_len=unroll_len, storage=storage, sliced=self.sliced
            )
        else:
            if replace:
                sampled_data = random.choices(storage, k=size)
            else:
                try:
                    sampled_data = random.sample(storage, k=size)
                except ValueError as e:
                    # Deferred so that ``ignore_insufficient`` can decide below whether to raise.
                    value_error = e

        if value_error or len(sampled_data) != size:
            if ignore_insufficient:
                logging.warning(
                    "Sample operation is ignored due to data insufficient, current buffer is {} while sample is {}".
                    format(self.count(), size)
                )
            else:
                raise ValueError("There are less than {} records/groups in buffer({})".format(size, self.count()))

        sampled_data = self._independence(sampled_data)

        return sampled_data

    @apply_middleware("update")
    def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool:
        """
        Overview:
            the method that update data and the related meta information with a certain index.
        Arguments:
            - index (:obj:`str`): The index string of the record to update.
            - data (:obj:`Any`): The data which is supposed to replace the old one. If you set it \
                to ``None``, nothing will happen to the old record.
            - meta (:obj:`Optional[dict]`): The new dict which replaces the old meta of the record. \
                If you set it to ``None``, the old meta is kept.
        Returns:
            - success (:obj:`bool`): ``False`` if ``index`` is not in the buffer, otherwise ``True``.
        """
        if not self.indices.has(index):
            return False
        i = self.indices.get(index)
        item = self.storage[i]
        if data is not None:
            item.data = data
        if meta is not None:
            item.meta = meta
            # Keep every existing meta index column consistent with the new meta.
            for key in self.meta_index:
                self.meta_index[key][i] = meta[key] if key in meta else None
        return True

    @apply_middleware("delete")
    def delete(self, indices: Union[str, Iterable[str]]) -> None:
        """
        Overview:
            The method that delete the data and related meta information by specific indices.
            Indices that are not in the buffer are ignored.
        Arguments:
            - indices (Union[str, Iterable[str]]): Where the data to be cleared in the buffer.
        """
        if isinstance(indices, str):
            indices = [indices]
        # A set makes sure a key passed twice does not delete two different records.
        positions = {self.indices.get(index) for index in indices if self.indices.has(index)}
        if not positions:
            return
        # Delete from the back so that the remaining positions stay valid.
        for pos in sorted(positions, reverse=True):
            del self.storage[pos]
            # Keep the meta index columns aligned with the storage.
            for meta_values in self.meta_index.values():
                if pos < len(meta_values):
                    del meta_values[pos]
        # Rebuild the index over ALL remaining records with fresh consecutive offsets.
        key_value_pairs = ((item.index, pos) for pos, item in enumerate(self.storage))
        self.indices = BufferIndex(self.storage.maxlen, key_value_pairs)

    def save_data(self, file_name: str):
        """
        Overview:
            Dump the storage, the indices and the meta index into ``file_name`` with hickle.
            If the folder for the specified file does not exist, it will be created.
        """
        dir_name = os.path.dirname(file_name)
        if dir_name != "":
            os.makedirs(dir_name, exist_ok=True)
        hickle.dump(
            py_obj=(
                self.storage,
                self.indices,
                self.meta_index,
            ), file_obj=file_name
        )

    def load_data(self, file_name: str):
        """
        Overview:
            Replace the storage, the indices and the meta index with the content of a file
            previously written by ``save_data``.
        """
        self.storage, self.indices, self.meta_index = hickle.load(file_name)

    def count(self) -> int:
        """
        Overview:
            The method that returns the current length of the buffer.
        """
        return len(self.storage)

    def get(self, idx: int) -> BufferedData:
        """
        Overview:
            The method that returns the BufferedData object given a specific index.
        """
        return self.storage[idx]

    @apply_middleware("clear")
    def clear(self) -> None:
        """
        Overview:
            The method that clear all data, indices, and the meta information in the buffer.
        """
        self.storage.clear()
        self.indices.clear()
        self.meta_index = {}

    def _push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
        """
        Overview:
            Wrap ``data`` and ``meta`` into a ``BufferedData`` with a newly generated index string,
            append it to the storage and register it in the indices and every meta index column.
        """
        index = uuid.uuid1().hex
        if meta is None:
            meta = {}
        buffered = BufferedData(data=data, index=index, meta=meta)
        self.storage.append(buffered)
        self.indices.append(index)
        # Add meta index
        for key in self.meta_index:
            self.meta_index[key].append(meta[key] if key in meta else None)

        return buffered

    def _independence(
        self, buffered_samples: Union[List[BufferedData], List[List[BufferedData]]]
    ) -> Union[List[BufferedData], List[List[BufferedData]]]:
        """
        Overview:
            Make sure that each record is different from each other, but remember that this function
            is different from clone_object. You may change the data in the buffer by modifying a record.
            Only the second and later occurrences of the same index are replaced by copies.
        Arguments:
            - buffered_samples (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`) Sampled data, \
                can be nested if groupby has been set.
        Returns:
            - buffered_samples: The same list object, modified in place.
        """
        if len(buffered_samples) == 0:
            return buffered_samples
        occurred = defaultdict(int)

        for i, buffered in enumerate(buffered_samples):
            if isinstance(buffered, list):
                sampled_list = buffered
                # Loop over nested samples
                for j, buffered in enumerate(sampled_list):
                    occurred[buffered.index] += 1
                    if occurred[buffered.index] > 1:
                        sampled_list[j] = fastcopy.copy(buffered)
            elif isinstance(buffered, BufferedData):
                occurred[buffered.index] += 1
                if occurred[buffered.index] > 1:
                    buffered_samples[i] = fastcopy.copy(buffered)
            else:
                raise Exception("Get unexpected buffered type {}".format(type(buffered)))
        return buffered_samples

    def _sample_by_group(
            self,
            size: int,
            groupby: str,
            replace: bool = False,
            unroll_len: Optional[int] = None,
            storage: deque = None,
            sliced: bool = False
    ) -> List[List[BufferedData]]:
        """
        Overview:
            Sampling by `group` instead of records, the result will be a collection
            of lists with a length of `size`, but the length of each list may be different from other lists.
        Arguments:
            - size (:obj:`int`): The number of groups to sample.
            - groupby (:obj:`str`): The meta key whose value defines the group of a record. Records \
                without this key belong to the group ``None``.
            - replace (:obj:`bool`): Whether the same group may be sampled more than once.
            - unroll_len (:obj:`Optional[int]`): If set, each returned group is cut to a window of at \
                most this many consecutive records.
            - storage (:obj:`deque`): The records to collect from, defaults to the whole buffer. Note \
                that the group names are always taken from the meta index of the whole buffer.
            - sliced (:obj:`bool`): Whether windows are aligned to multiples of ``unroll_len``.
        Returns:
            - sampled_data (:obj:`List[List[BufferedData]]`): One list of records per sampled group. \
                An empty list is returned if ``unroll_len > 1`` and no group is long enough.
        Raises:
            - ValueError: If ``replace`` is ``False`` and there are fewer than ``size`` groups.
        """
        if storage is None:
            storage = self.storage
        if groupby not in self.meta_index:
            self._create_index(groupby)

        def filter_by_unroll_len():
            "Filter groups by unroll len, ensure count of items in each group is greater than unroll_len."
            group_count = Counter(self.meta_index[groupby])
            return [key for key, count in group_count.items() if count >= unroll_len]

        if unroll_len and unroll_len > 1:
            group_names = filter_by_unroll_len()
            if len(group_names) == 0:
                return []
        else:
            group_names = list(set(self.meta_index[groupby]))

        sampled_groups = []
        if replace:
            sampled_groups = random.choices(group_names, k=size)
        else:
            try:
                sampled_groups = random.sample(group_names, k=size)
            except ValueError:
                raise ValueError("There are less than {} groups in buffer({} groups)".format(size, len(group_names)))

        # Build dict like {"group name": [records]}
        wanted_groups = set(sampled_groups)
        sampled_data = defaultdict(list)
        for buffered in storage:
            meta_value = buffered.meta[groupby] if groupby in buffered.meta else None
            if meta_value in wanted_groups:
                # Use ``meta_value`` as the key so that records without the groupby key are
                # collected under ``None`` instead of raising ``KeyError``.
                sampled_data[meta_value].append(buffered)

        final_sampled_data = []
        for group in sampled_groups:
            seq_data = sampled_data[group]
            # Filter records by unroll_len
            if unroll_len:
                if sliced:
                    # Cut on boundaries that are multiples of unroll_len. Without this alignment,
                    # overlapping windows make duplicated data more likely, which can hurt training.
                    start_indice = random.choice(range(max(1, len(seq_data))))
                    start_indice = start_indice // unroll_len
                    if start_indice == (len(seq_data) - 1) // unroll_len:
                        # The last chunk may be incomplete, so take the trailing window instead.
                        seq_data = seq_data[-unroll_len:]
                    else:
                        seq_data = seq_data[start_indice * unroll_len:start_indice * unroll_len + unroll_len]
                else:
                    # NOTE(review): the start is drawn from range(len - unroll_len), so the last
                    # full window (start == len - unroll_len) is never selected - confirm intent.
                    start_indice = random.choice(range(max(1, len(seq_data) - unroll_len)))
                    seq_data = seq_data[start_indice:start_indice + unroll_len]

            final_sampled_data.append(seq_data)

        return final_sampled_data

    def _create_index(self, meta_key: str):
        """
        Overview:
            Build the meta index column of ``meta_key`` from the records currently in the storage.
            Records without the key are indexed as ``None``.
        """
        self.meta_index[meta_key] = deque(maxlen=self.storage.maxlen)
        for data in self.storage:
            self.meta_index[meta_key].append(data.meta[meta_key] if meta_key in data.meta else None)

    def __iter__(self) -> Iterator[BufferedData]:
        return iter(self.storage)

    def __copy__(self) -> "DequeBuffer":
        """
        Overview:
            Create a new buffer object that shares the storage, the meta index and the indices
            with this buffer.
        """
        buffer = type(self)(size=self.storage.maxlen)
        # Set as an attribute (not a constructor argument) so subclasses with a different
        # ``__init__`` signature keep working.
        buffer.sliced = self.sliced
        buffer.storage = self.storage
        buffer.meta_index = self.meta_index
        buffer.indices = self.indices
        return buffer