zjowowen's picture
init space
079c32c
from typing import Callable, Any, List, Optional, Union, TYPE_CHECKING
from collections import defaultdict
from ding.data.buffer import BufferedData
if TYPE_CHECKING:
from ding.data.buffer.buffer import Buffer
def use_time_check(buffer_: 'Buffer', max_use: Union[int, float] = float("inf")) -> Callable:
    """
    Overview:
        This middleware aims to check the usage times of data in buffer. If the usage times of a data is
        greater than or equal to max_use, this data will be removed from buffer as soon as possible.
    Arguments:
        - buffer_ (:obj:`Buffer`): The buffer this middleware is attached to; data that reaches the usage \
            limit is removed from it via ``buffer_.delete``.
        - max_use (:obj:`Union[int, float]`): The max reused (resampled) count for any individual object. \
            Defaults to ``float("inf")``, i.e. data is never removed because of its usage count.
    Returns:
        - middleware (:obj:`Callable`): A middleware ``(action, chain, *args, **kwargs)`` that inspects the \
            result of ``"sample"`` calls and passes every other action straight through to ``chain``.
    """
    # Number of times each buffered item has been sampled, keyed by ``BufferedData.index``.
    use_count = defaultdict(int)

    def _need_delete(item: BufferedData) -> bool:
        # Count one more use of ``item``, expose the running count through the
        # item's meta, and report whether the item has reached ``max_use``.
        idx = item.index
        use_count[idx] += 1
        item.meta['use_count'] = use_count[idx]
        return use_count[idx] >= max_use

    def _check_use_count(sampled_data: List[BufferedData]) -> None:
        # Every item is passed through ``_need_delete`` (its counting side effect
        # must happen for each occurrence). ``dict.fromkeys`` then de-duplicates
        # the indices while keeping first-seen order: with replacement sampling
        # the same index can occur several times in one batch, and deleting it
        # (or its counter) more than once would fail.
        delete_indices = list(dict.fromkeys(item.index for item in sampled_data if _need_delete(item)))
        buffer_.delete(delete_indices)
        for index in delete_indices:
            # ``pop`` with a default tolerates an index whose counter is already gone.
            use_count.pop(index, None)

    def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
        # Run the rest of the middleware chain first, then account for usage of
        # whatever it returned. The result is returned unchanged apart from the
        # ``use_count`` written into each item's meta.
        sampled_data = chain(*args, **kwargs)
        if len(sampled_data) == 0:
            return sampled_data
        if isinstance(sampled_data[0], BufferedData):
            # Flat result: a single list of BufferedData.
            _check_use_count(sampled_data)
        else:
            # Grouped result: a list of lists of BufferedData, handled group by group.
            for grouped_data in sampled_data:
                _check_use_count(grouped_data)
        return sampled_data

    def _use_time_check(action: str, chain: Callable, *args, **kwargs) -> Any:
        # Only ``sample`` is intercepted; all other actions are forwarded untouched.
        if action == "sample":
            return sample(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _use_time_check