zjowowen's picture
init space
079c32c
from typing import Union, Dict, Any, List
from abc import ABC, abstractmethod
import copy
from easydict import EasyDict
from ding.utils import import_module, BUFFER_REGISTRY
class IBuffer(ABC):
    """
    Overview:
        Abstract interface that every buffer implementation has to fulfil.
    Interfaces:
        default_config, push, update, sample, clear, count, save_data, load_data, \
        state_dict, load_state_dict
    """

    @classmethod
    def default_config(cls) -> EasyDict:
        """
        Overview:
            Build the default config of this buffer class from a deep copy of the class \
            attribute ``config``, so the class-level value is never mutated by callers.
            NOTE(review): ``config`` is not defined on ``IBuffer`` in this view; concrete \
            subclasses are presumably expected to provide it.
        Returns:
            - default_config (:obj:`EasyDict`): The copied config with an extra ``cfg_type`` \
                field set to ``<class name> + 'Dict'``.
        """
        base_cfg = copy.deepcopy(cls.config)
        default_cfg = EasyDict(base_cfg)
        default_cfg.cfg_type = '{}Dict'.format(cls.__name__)
        return default_cfg

    @abstractmethod
    def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None:
        """
        Overview:
            Push data into the buffer.
        Arguments:
            - data (:obj:`Union[List[Any], Any]`): The data to be pushed into the buffer. Either a \
                single item (``Any``) or several items (``List[Any]``).
            - cur_collector_envstep (:obj:`int`): Collector's current env step.
        """
        raise NotImplementedError

    @abstractmethod
    def update(self, info: Dict[str, list]) -> None:
        """
        Overview:
            Update the info of stored data, e.g. priority.
        Arguments:
            - info (:obj:`Dict[str, list]`): Info dict. Its keys depend on the specific buffer type.
        """
        raise NotImplementedError

    @abstractmethod
    def sample(self, batch_size: int, cur_learner_iter: int) -> list:
        """
        Overview:
            Sample ``batch_size`` pieces of data from the buffer.
        Arguments:
            - batch_size (:obj:`int`): The number of data items to sample.
            - cur_learner_iter (:obj:`int`): Learner's current iteration.
        Returns:
            - sampled_data (:obj:`list`): A list of data with length ``batch_size``.
        """
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """
        Overview:
            Remove all data from the buffer and reset the related variables.
        """
        raise NotImplementedError

    @abstractmethod
    def count(self) -> int:
        """
        Overview:
            Count the valid data currently held by the buffer.
        Returns:
            - count (:obj:`int`): Number of valid data items.
        """
        raise NotImplementedError

    @abstractmethod
    def save_data(self, file_name: str):
        """
        Overview:
            Save the buffer data into a file.
        Arguments:
            - file_name (:obj:`str`): Name of the file the buffer data is written to.
        """
        raise NotImplementedError

    @abstractmethod
    def load_data(self, file_name: str):
        """
        Overview:
            Load buffer data from a file.
        Arguments:
            - file_name (:obj:`str`): Name of the file the buffer data is read from.
        """
        raise NotImplementedError

    @abstractmethod
    def state_dict(self) -> Dict[str, Any]:
        """
        Overview:
            Provide a state dict that records the current buffer.
        Returns:
            - state_dict (:obj:`Dict[str, Any]`): A dict containing all important values of the \
                buffer, from which the buffer can be reproduced.
        """
        raise NotImplementedError

    @abstractmethod
    def load_state_dict(self, _state_dict: Dict[str, Any]) -> None:
        """
        Overview:
            Load a state dict in order to reproduce the buffer.
        Arguments:
            - _state_dict (:obj:`Dict[str, Any]`): A dict containing all important values of the buffer.
        """
        raise NotImplementedError
def create_buffer(cfg: EasyDict, *args, **kwargs) -> IBuffer:
    """
    Overview:
        Create a buffer according to ``cfg`` and the other arguments, by looking up ``cfg.type`` \
        in ``BUFFER_REGISTRY``.
    Arguments:
        - cfg (:obj:`EasyDict`): Buffer config.
        - args, kwargs: Extra arguments forwarded to ``BUFFER_REGISTRY.build`` after ``cfg``.
    Returns:
        - buffer (:obj:`IBuffer`): The object returned by ``BUFFER_REGISTRY.build``.
    ArgumentsKeys:
        - necessary: `type`
        - optional: `import_names`
    """
    # Import the modules named in the config first (defaults to an empty list).
    extra_modules = cfg.get('import_names', [])
    import_module(extra_modules)

    buffer_type = cfg.type
    if buffer_type == 'naive':
        # Drop ``tb_logger`` for the 'naive' type; presumably that buffer's constructor
        # does not accept it -- confirm against the naive buffer implementation.
        kwargs.pop('tb_logger', None)
    return BUFFER_REGISTRY.build(buffer_type, cfg, *args, **kwargs)
def get_buffer_cls(cfg: EasyDict) -> type:
    """
    Overview:
        Get a buffer class according to ``cfg``, by looking up ``cfg.type`` in ``BUFFER_REGISTRY``.
    Arguments:
        - cfg (:obj:`EasyDict`): Buffer config.
    Returns:
        - buffer_cls (:obj:`type`): Whatever ``BUFFER_REGISTRY.get`` returns for ``cfg.type``.
    ArgumentsKeys:
        - necessary: `type`
        - optional: `import_names`
    """
    # Import the modules named in the config before the registry lookup
    # (defaults to an empty list).
    extra_modules = cfg.get('import_names', [])
    import_module(extra_modules)

    buffer_type = cfg.type
    return BUFFER_REGISTRY.get(buffer_type)