# File: gomoku/DI-engine/ding/data/buffer/deque_buffer_wrapper.py
# Last commit: 079c32c "init space" (author, per page header: zjowowen)
import os
from typing import Optional
import copy
from easydict import EasyDict
import numpy as np
import hickle
from ding.data.buffer import DequeBuffer
from ding.data.buffer.middleware import use_time_check, PriorityExperienceReplay
from ding.utils import BUFFER_REGISTRY
@BUFFER_REGISTRY.register('deque')
class DequeBufferWrapper(object):
    """Config-driven facade over ``DequeBuffer``.

    The wrapper builds a ``DequeBuffer``, optionally installs the
    ``use_time_check`` and ``PriorityExperienceReplay`` middleware according
    to the config, and exposes ``sample`` / ``push`` / ``update`` / ``count``
    plus save/load helpers. Sampling statistics are periodically written to
    ``tb_logger`` when one is supplied.
    """

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """Return a deep copy of ``cls.config`` as an ``EasyDict``.

        The copy carries an extra ``cfg_type`` field set to the class name
        suffixed with ``'Dict'``.
        """
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    config = dict(
        # Capacity handed to ``DequeBuffer(size=...)``.
        replay_buffer_size=10000,
        # When not infinite, ``use_time_check`` is installed with this limit.
        max_use=float("inf"),
        # Minimum ``train_iter`` gap between two rounds of scalar logging.
        train_iter_per_log=100,
        # Whether to install the ``PriorityExperienceReplay`` middleware.
        priority=False,
        # Forwarded as ``IS_weight``; also makes ``sample`` attach an 'IS' field.
        priority_IS_weight=False,
        # The next three values are forwarded verbatim to the priority middleware.
        priority_power_factor=0.6,
        IS_weight_power_factor=0.4,
        IS_weight_anneal_train_iter=int(1e5),
        # Upper clamp applied to every priority written back in ``update``.
        priority_max_limit=1000,
    )

    def __init__(
            self,
            cfg: EasyDict,
            tb_logger: Optional[object] = None,
            exp_name: str = 'default_experiement',
            instance_name: str = 'buffer'
    ) -> None:
        """Build the underlying buffer and install the configured middleware.

        Args:
            cfg: Buffer config; the fields read here and in the other methods
                are the keys of ``DequeBufferWrapper.config``.
            tb_logger: Optional logger exposing
                ``add_scalar(tag, value, step)`` (presumably a TensorBoard
                ``SummaryWriter`` -- confirm with callers). When ``None``,
                no statistics are logged.
            exp_name: Accepted for interface compatibility; not used by this
                class.
            instance_name: Prefix used to build the scalar tag namespace
                (``'<instance_name>_iter'``).
        """
        self.cfg = cfg
        self.priority_max_limit = cfg.priority_max_limit
        self.name = '{}_iter'.format(instance_name)
        self.tb_logger = tb_logger
        self.buffer = DequeBuffer(size=cfg.replay_buffer_size)
        # -1 means "never logged yet", which forces a log on the first sample.
        self.last_log_train_iter = -1

        # use_count middleware
        if self.cfg.max_use != float("inf"):
            self.buffer.use(use_time_check(self.buffer, max_use=self.cfg.max_use))
        # priority middleware
        if self.cfg.priority:
            self.buffer.use(
                PriorityExperienceReplay(
                    self.buffer,
                    IS_weight=self.cfg.priority_IS_weight,
                    priority_power_factor=self.cfg.priority_power_factor,
                    IS_weight_power_factor=self.cfg.IS_weight_power_factor,
                    IS_weight_anneal_train_iter=self.cfg.IS_weight_anneal_train_iter
                )
            )
        # Index/meta of the most recent sampled batch, consumed by ``update``.
        self.last_sample_index = None
        self.last_sample_meta = None

    def sample(self, size: int, train_iter: int = 0):
        """Sample up to ``size`` items from the buffer.

        Args:
            size: Requested batch size (the buffer is asked to tolerate an
                insufficient number of items).
            train_iter: Current training iteration; used as the logging step
                and to throttle logging to once per ``train_iter_per_log``.

        Returns:
            A list with the ``data`` of every sampled item, or ``None`` when
            the buffer returned nothing. When ``priority_IS_weight`` is set,
            each data item additionally receives an ``'IS'`` entry taken from
            the item's ``meta['priority_IS']``.
        """
        output = self.buffer.sample(size=size, ignore_insufficient=True)
        if len(output) == 0:
            return None

        # Remember every sampled batch (not only the ones that happen to be
        # logged) so that ``update`` always writes priorities back to the
        # batch that was actually trained on.
        if self.cfg.priority:
            self.last_sample_index = [o.index for o in output]
            self.last_sample_meta = [o.meta for o in output]

        log_due = (
            self.last_log_train_iter == -1
            or train_iter - self.last_log_train_iter >= self.cfg.train_iter_per_log
        )
        # ``tb_logger`` is optional; skip logging instead of crashing on None.
        if log_due and self.tb_logger is not None:
            self._log_sample_stats(output, train_iter)
            self.last_log_train_iter = train_iter

        data = [o.data for o in output]
        if self.cfg.priority_IS_weight:
            for item, o in zip(data, output):
                item['IS'] = o.meta['priority_IS']
        return data

    def _log_sample_stats(self, output, train_iter: int) -> None:
        """Write use-count, priority and buffer-size scalars for one batch."""
        meta = [o.meta for o in output]
        if self.cfg.max_use != float("inf"):
            use_count_avg = np.mean([m['use_count'] for m in meta])
            self.tb_logger.add_scalar('{}/use_count_avg'.format(self.name), use_count_avg, train_iter)
        if self.cfg.priority:
            priority_list = [m['priority'] for m in meta]
            priority_avg = np.mean(priority_list)
            priority_max = np.max(priority_list)
            self.tb_logger.add_scalar('{}/priority_avg'.format(self.name), priority_avg, train_iter)
            self.tb_logger.add_scalar('{}/priority_max'.format(self.name), priority_max, train_iter)
        self.tb_logger.add_scalar('{}/buffer_data_count'.format(self.name), self.buffer.count(), train_iter)

    def push(self, data, cur_collector_envstep: int = -1) -> None:
        """Push every element of ``data`` into the buffer.

        Args:
            data: Iterable of items to store. When priority is enabled and an
                item contains a ``'priority'`` key, that key is popped from
                the item (mutating it) and stored as the item's initial
                priority in its meta instead.
            cur_collector_envstep: Accepted for interface compatibility; not
                used by this class.
        """
        for d in data:
            meta = {}
            if self.cfg.priority and 'priority' in d:
                init_priority = d.pop('priority')
                meta['priority'] = init_priority
            self.buffer.push(d, meta=meta)

    def update(self, meta: dict) -> None:
        """Write new priorities back for the most recently sampled batch.

        Does nothing when priority is disabled or when no sampled batch is
        pending. Each new priority is clamped to ``priority_max_limit``.

        Args:
            meta: Mapping whose ``'priority'`` entry is an iterable of new
                priorities, paired positionally with the last sampled batch.
        """
        if not self.cfg.priority:
            return
        if self.last_sample_index is None:
            return
        new_meta = self.last_sample_meta
        for m, p in zip(new_meta, meta['priority']):
            m['priority'] = min(self.priority_max_limit, p)
        for idx, m in zip(self.last_sample_index, new_meta):
            self.buffer.update(idx, data=None, meta=m)
        # Consume the pending batch so a second ``update`` is a no-op.
        self.last_sample_index = None
        self.last_sample_meta = None

    def count(self) -> int:
        """Return the underlying buffer's ``count()``."""
        return self.buffer.count()

    def save_data(self, file_name):
        """Delegate to the underlying buffer's ``save_data(file_name)``."""
        self.buffer.save_data(file_name)

    def load_data(self, file_name: str):
        """Delegate to the underlying buffer's ``load_data(file_name)``."""
        self.buffer.load_data(file_name)