from typing import Union, Dict, Any, List from abc import ABC, abstractmethod import copy from easydict import EasyDict from ding.utils import import_module, BUFFER_REGISTRY class IBuffer(ABC): r""" Overview: Buffer interface Interfaces: default_config, push, update, sample, clear, count, state_dict, load_state_dict """ @classmethod def default_config(cls) -> EasyDict: r""" Overview: Default config of this buffer class. Returns: - default_config (:obj:`EasyDict`) """ cfg = EasyDict(copy.deepcopy(cls.config)) cfg.cfg_type = cls.__name__ + 'Dict' return cfg @abstractmethod def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None: r""" Overview: Push a data into buffer. Arguments: - data (:obj:`Union[List[Any], Any]`): The data which will be pushed into buffer. Can be one \ (in `Any` type), or many(int `List[Any]` type). - cur_collector_envstep (:obj:`int`): Collector's current env step. """ raise NotImplementedError @abstractmethod def update(self, info: Dict[str, list]) -> None: r""" Overview: Update data info, e.g. priority. Arguments: - info (:obj:`Dict[str, list]`): Info dict. Keys depends on the specific buffer type. """ raise NotImplementedError @abstractmethod def sample(self, batch_size: int, cur_learner_iter: int) -> list: r""" Overview: Sample data with length ``batch_size``. Arguments: - size (:obj:`int`): The number of the data that will be sampled. - cur_learner_iter (:obj:`int`): Learner's current iteration. Returns: - sampled_data (:obj:`list`): A list of data with length `batch_size`. """ raise NotImplementedError @abstractmethod def clear(self) -> None: """ Overview: Clear all the data and reset the related variables. """ raise NotImplementedError @abstractmethod def count(self) -> int: """ Overview: Count how many valid datas there are in the buffer. Returns: - count (:obj:`int`): Number of valid data. """ raise NotImplementedError @abstractmethod def save_data(self, file_name: str): """ Overview: Save buffer data into a file. Arguments: - file_name (:obj:`str`): file name of buffer data """ raise NotImplementedError @abstractmethod def load_data(self, file_name: str): """ Overview: Load buffer data from a file. Arguments: - file_name (:obj:`str`): file name of buffer data """ raise NotImplementedError @abstractmethod def state_dict(self) -> Dict[str, Any]: """ Overview: Provide a state dict to keep a record of current buffer. Returns: - state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \ With the dict, one can easily reproduce the buffer. """ raise NotImplementedError @abstractmethod def load_state_dict(self, _state_dict: Dict[str, Any]) -> None: """ Overview: Load state dict to reproduce the buffer. Returns: - state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. """ raise NotImplementedError def create_buffer(cfg: EasyDict, *args, **kwargs) -> IBuffer: r""" Overview: Create a buffer according to cfg and other arguments. Arguments: - cfg (:obj:`EasyDict`): Buffer config. ArgumentsKeys: - necessary: `type` """ import_module(cfg.get('import_names', [])) if cfg.type == 'naive': kwargs.pop('tb_logger', None) return BUFFER_REGISTRY.build(cfg.type, cfg, *args, **kwargs) def get_buffer_cls(cfg: EasyDict) -> type: r""" Overview: Get a buffer class according to cfg. Arguments: - cfg (:obj:`EasyDict`): Buffer config. ArgumentsKeys: - necessary: `type` """ import_module(cfg.get('import_names', [])) return BUFFER_REGISTRY.get(cfg.type)