zjowowen's picture
init space
079c32c
raw
history blame
4.46 kB
from typing import Union, Dict, Any, List
from abc import ABC, abstractmethod
import copy
from easydict import EasyDict
from ding.utils import import_module, BUFFER_REGISTRY
class IBuffer(ABC):
r"""
Overview:
Buffer interface
Interfaces:
default_config, push, update, sample, clear, count, state_dict, load_state_dict
"""
@classmethod
def default_config(cls) -> EasyDict:
r"""
Overview:
Default config of this buffer class.
Returns:
- default_config (:obj:`EasyDict`)
"""
cfg = EasyDict(copy.deepcopy(cls.config))
cfg.cfg_type = cls.__name__ + 'Dict'
return cfg
@abstractmethod
def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None:
r"""
Overview:
Push a data into buffer.
Arguments:
- data (:obj:`Union[List[Any], Any]`): The data which will be pushed into buffer. Can be one \
(in `Any` type), or many(int `List[Any]` type).
- cur_collector_envstep (:obj:`int`): Collector's current env step.
"""
raise NotImplementedError
@abstractmethod
def update(self, info: Dict[str, list]) -> None:
r"""
Overview:
Update data info, e.g. priority.
Arguments:
- info (:obj:`Dict[str, list]`): Info dict. Keys depends on the specific buffer type.
"""
raise NotImplementedError
@abstractmethod
def sample(self, batch_size: int, cur_learner_iter: int) -> list:
r"""
Overview:
Sample data with length ``batch_size``.
Arguments:
- size (:obj:`int`): The number of the data that will be sampled.
- cur_learner_iter (:obj:`int`): Learner's current iteration.
Returns:
- sampled_data (:obj:`list`): A list of data with length `batch_size`.
"""
raise NotImplementedError
@abstractmethod
def clear(self) -> None:
"""
Overview:
Clear all the data and reset the related variables.
"""
raise NotImplementedError
@abstractmethod
def count(self) -> int:
"""
Overview:
Count how many valid datas there are in the buffer.
Returns:
- count (:obj:`int`): Number of valid data.
"""
raise NotImplementedError
@abstractmethod
def save_data(self, file_name: str):
"""
Overview:
Save buffer data into a file.
Arguments:
- file_name (:obj:`str`): file name of buffer data
"""
raise NotImplementedError
@abstractmethod
def load_data(self, file_name: str):
"""
Overview:
Load buffer data from a file.
Arguments:
- file_name (:obj:`str`): file name of buffer data
"""
raise NotImplementedError
@abstractmethod
def state_dict(self) -> Dict[str, Any]:
"""
Overview:
Provide a state dict to keep a record of current buffer.
Returns:
- state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \
With the dict, one can easily reproduce the buffer.
"""
raise NotImplementedError
@abstractmethod
def load_state_dict(self, _state_dict: Dict[str, Any]) -> None:
"""
Overview:
Load state dict to reproduce the buffer.
Returns:
- state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer.
"""
raise NotImplementedError
def create_buffer(cfg: EasyDict, *args, **kwargs) -> IBuffer:
r"""
Overview:
Create a buffer according to cfg and other arguments.
Arguments:
- cfg (:obj:`EasyDict`): Buffer config.
ArgumentsKeys:
- necessary: `type`
"""
import_module(cfg.get('import_names', []))
if cfg.type == 'naive':
kwargs.pop('tb_logger', None)
return BUFFER_REGISTRY.build(cfg.type, cfg, *args, **kwargs)
def get_buffer_cls(cfg: EasyDict) -> type:
r"""
Overview:
Get a buffer class according to cfg.
Arguments:
- cfg (:obj:`EasyDict`): Buffer config.
ArgumentsKeys:
- necessary: `type`
"""
import_module(cfg.get('import_names', []))
return BUFFER_REGISTRY.get(cfg.type)