File size: 7,077 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
from typing import Callable, Any, List, Dict, Optional, Union, TYPE_CHECKING
import copy
import numpy as np
import torch
from ding.utils import SumSegmentTree, MinSegmentTree
from ding.data.buffer.buffer import BufferedData
if TYPE_CHECKING:
from ding.data.buffer.buffer import Buffer
class PriorityExperienceReplay:
    """
    Overview:
        The middleware that implements priority experience replay (PER).
        Priorities are kept in segment trees so that sampling proportional to priority
        and querying the minimum priority (for importance sampling) are both O(log N).
    """

    def __init__(
            self,
            buffer: 'Buffer',
            IS_weight: bool = True,
            priority_power_factor: float = 0.6,
            IS_weight_power_factor: float = 0.4,
            IS_weight_anneal_train_iter: int = int(1e5),
    ) -> None:
        """
        Arguments:
            - buffer (:obj:`Buffer`): The buffer to use PER.
            - IS_weight (:obj:`bool`): Whether use importance sampling or not.
            - priority_power_factor (:obj:`float`): The factor that adjusts the sensitivity between\
                the sampling probability and the priority level.
            - IS_weight_power_factor (:obj:`float`): The factor that adjusts the sensitivity between\
                the sample rarity and sampling probability in importance sampling.
            - IS_weight_anneal_train_iter (:obj:`int`): The number of train iterations over which\
                ``IS_weight_power_factor`` is annealed towards 1.
        """
        self.buffer = buffer
        # Maps a segment-tree slot (pivot) to the wrapped buffer's own index.
        self.buffer_idx = {}
        self.buffer_size = buffer.size
        self.IS_weight = IS_weight
        self.priority_power_factor = priority_power_factor
        self.IS_weight_power_factor = IS_weight_power_factor
        self.IS_weight_anneal_train_iter = IS_weight_anneal_train_iter
        # Max priority till now, it's used to initialize data's priority if "priority" is not passed in with the data.
        self.max_priority = 1.0
        # Capacity needs to be the power of 2 for the segment trees.
        capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size))))
        self.sum_tree = SumSegmentTree(capacity)
        if self.IS_weight:
            self.min_tree = MinSegmentTree(capacity)
            # Per-sample linear annealing step, so that IS_weight_power_factor reaches 1
            # after IS_weight_anneal_train_iter sample calls.
            self.delta_anneal = (1 - self.IS_weight_power_factor) / self.IS_weight_anneal_train_iter
        self.pivot = 0

    def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> 'BufferedData':
        """
        Overview:
            Push ``data`` through the wrapped buffer, assigning it a priority.
            If no priority is supplied (in ``meta`` or inside ``data``), the max priority
            seen so far is used so that fresh data is likely to be sampled at least once.
        """
        if meta is None:
            if 'priority' in data:
                meta = {'priority': data.pop('priority')}
            else:
                meta = {'priority': self.max_priority}
        else:
            if 'priority' not in meta:
                meta['priority'] = self.max_priority
        meta['priority_idx'] = self.pivot
        self._update_tree(meta['priority'], self.pivot)
        buffered = chain(data, meta=meta, *args, **kwargs)
        index = buffered.index
        self.buffer_idx[self.pivot] = index
        # The pivot wraps around, reusing the oldest slot once the buffer is full.
        self.pivot = (self.pivot + 1) % self.buffer_size
        return buffered

    def sample(self, chain: Callable, size: int, *args,
               **kwargs) -> Union[List['BufferedData'], List[List['BufferedData']]]:
        """
        Overview:
            Sample ``size`` items with probability proportional to priority (stratified:
            one draw per equal-width interval of the cumulative priority mass). When
            ``IS_weight`` is enabled, attach the normalized importance-sampling weight to
            each item's meta/data, then anneal ``IS_weight_power_factor`` one step towards 1.
        """
        # Divide [0, 1) into size intervals on average.
        intervals = np.array([i * 1.0 / size for i in range(size)])
        # Uniformly sample within each interval.
        mass = intervals + np.random.uniform(size=(size, )) * 1. / size
        # Rescale to [0, S), where S is the sum of all datas' priority (root value of sum tree).
        mass *= self.sum_tree.reduce()
        indices = [self.sum_tree.find_prefixsum_idx(m) for m in mass]
        indices = [self.buffer_idx[i] for i in indices]
        # Sample with indices.
        data = chain(indices=indices, *args, **kwargs)
        if self.IS_weight:
            # Calculate max weight for normalizing IS.
            sum_tree_root = self.sum_tree.reduce()
            p_min = self.min_tree.reduce() / sum_tree_root
            buffer_count = self.buffer.count()
            max_weight = (buffer_count * p_min) ** (-self.IS_weight_power_factor)
            for i in range(len(data)):
                meta = data[i].meta
                priority_idx = meta['priority_idx']
                p_sample = self.sum_tree[priority_idx] / sum_tree_root
                weight = (buffer_count * p_sample) ** (-self.IS_weight_power_factor)
                meta['priority_IS'] = weight / max_weight
                data[i].data['priority_IS'] = torch.as_tensor([meta['priority_IS']]).float()  # for compatibility
            self.IS_weight_power_factor = min(1.0, self.IS_weight_power_factor + self.delta_anneal)
        return data

    def update(self, chain: Callable, index: str, data: Any, meta: Any, *args, **kwargs) -> None:
        """
        Overview:
            Update an existing item via the wrapped buffer; if that succeeds, refresh
            its priority in the segment trees and track the running max priority.
        """
        update_flag = chain(index, data, meta, *args, **kwargs)
        if update_flag:  # when update succeeds
            assert meta is not None, "Please indicate dict-type meta in priority update"
            new_priority, idx = meta['priority'], meta['priority_idx']
            assert new_priority >= 0, "new_priority should be no less than 0, but found {}".format(new_priority)
            new_priority += 1e-5  # Add epsilon to avoid priority == 0
            self._update_tree(new_priority, idx)
            self.max_priority = max(self.max_priority, new_priority)

    def delete(self, chain: Callable, index: str, *args, **kwargs) -> None:
        """
        Overview:
            Reset the priority entries of every currently stored item, then delegate
            deletion of ``index`` to the wrapped buffer.
            NOTE(review): this clears tree entries for ALL items in storage, not only
            ``index`` — confirm this coarse reset is the intended semantics upstream.
        """
        for item in self.buffer.storage:
            priority_idx = item.meta['priority_idx']
            self.sum_tree[priority_idx] = self.sum_tree.neutral_element
            # min_tree only exists when importance sampling is enabled (see __init__).
            if self.IS_weight:
                self.min_tree[priority_idx] = self.min_tree.neutral_element
            self.buffer_idx.pop(priority_idx)
        return chain(index, *args, **kwargs)

    def clear(self, chain: Callable) -> None:
        """
        Overview:
            Reset all PER state (priority trees, index mapping, pivot, max priority),
            then clear the wrapped buffer.
        """
        self.max_priority = 1.0
        capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size))))
        self.sum_tree = SumSegmentTree(capacity)
        if self.IS_weight:
            self.min_tree = MinSegmentTree(capacity)
        self.buffer_idx = {}
        self.pivot = 0
        chain()

    def _update_tree(self, priority: float, idx: int) -> None:
        # Store priority ** priority_power_factor, so sampling probability is
        # proportional to the exponentiated priority (proportional PER variant).
        weight = priority ** self.priority_power_factor
        self.sum_tree[idx] = weight
        if self.IS_weight:
            self.min_tree[idx] = weight

    def state_dict(self) -> Dict:
        """
        Overview:
            Return a snapshot of the PER state. Keys match the attribute names so that
            ``load_state_dict`` restores the same attributes. (The previous keys
            'sumtree'/'mintree' referenced non-existent attributes and raised AttributeError.)
        """
        return {
            'max_priority': self.max_priority,
            'IS_weight_power_factor': self.IS_weight_power_factor,
            'sum_tree': self.sum_tree,
            'min_tree': self.min_tree,
            'buffer_idx': self.buffer_idx,
        }

    def load_state_dict(self, _state_dict: Dict, deepcopy: bool = False) -> None:
        """
        Overview:
            Restore attributes from ``_state_dict``, optionally deep-copying each value.
        """
        for k, v in _state_dict.items():
            setattr(self, k, copy.deepcopy(v) if deepcopy else v)

    def __call__(self, action: str, chain: Callable, *args, **kwargs) -> Any:
        """
        Overview:
            Middleware entry point: intercept the buffer actions PER cares about and
            pass any other action straight through to ``chain``.
        """
        if action in ["push", "sample", "update", "delete", "clear"]:
            return getattr(self, action)(chain, *args, **kwargs)
        return chain(*args, **kwargs)
|