File size: 13,769 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 |
import os
from typing import TYPE_CHECKING, Callable, List, Union, Tuple, Dict, Optional

from easydict import EasyDict
from ditk import logging
import torch

from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type
from ding.data.buffer.middleware import PriorityExperienceReplay
from ding.framework import task
from ding.utils import get_rank

if TYPE_CHECKING:
    # Only needed for the string annotations on the ``ctx`` parameters below.
    from ding.framework import OnlineRLContext, OfflineRLContext
def data_pusher(cfg: EasyDict, buffer_: Buffer, group_by_env: Optional[bool] = None):
    """
    Overview:
        Push episodes or trajectories into the buffer.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - buffer_ (:obj:`Buffer`): Buffer to push the data in.
        - group_by_env (:obj:`Optional[bool]`): If truthy, every trajectory element is pushed together with \
            the meta ``{'env': t.env_data_id.item()}``; otherwise it is pushed without meta.
    Returns:
        - middleware (:obj:`Callable`): The push function, or ``task.void()`` when the task router is active \
            and the current process does not have the LEARNER role.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    def _push(ctx: "OnlineRLContext"):
        """
        Overview:
            In ctx, either `ctx.trajectories` or `ctx.episodes` should not be None.
        Input of ctx:
            - trajectories (:obj:`List[Dict]`): Trajectories.
            - episodes (:obj:`List[Dict]`): Episodes.
        Raises:
            - RuntimeError: If both `ctx.trajectories` and `ctx.episodes` are None.
        """
        if ctx.trajectories is not None:  # each data in buffer is a transition
            if group_by_env:
                for t in ctx.trajectories:
                    buffer_.push(t, {'env': t.env_data_id.item()})
            else:
                for t in ctx.trajectories:
                    buffer_.push(t)
            # The data now lives in the buffer; clear it from the context.
            ctx.trajectories = None
        elif ctx.episodes is not None:  # each data in buffer is a episode
            for t in ctx.episodes:
                buffer_.push(t)
            ctx.episodes = None
        else:
            raise RuntimeError("Either ctx.trajectories or ctx.episodes should be not None.")

    return _push
def buffer_saver(cfg: EasyDict, buffer_: Buffer, every_envstep: int = 1000, replace: bool = False):
    """
    Overview:
        Save current buffer data.
    Arguments:
        - cfg (:obj:`EasyDict`): Config. ``cfg.exp_name`` is used as the root of the output path.
        - buffer_ (:obj:`Buffer`): Buffer whose data is saved through ``buffer_.save_data``.
        - every_envstep (:obj:`int`): Save at every ``every_envstep`` env steps.
        - replace (:obj:`bool`): Whether to replace the last file. If True the data is always written to \
            ``data_latest.hkl``; otherwise to ``data_envstep_{env_step}.hkl``.
    Returns:
        - middleware (:obj:`Callable`): The save function that takes the context as its only argument.
    """
    # Starting at -every_envstep makes the very first call (env_step >= 0) trigger a save.
    buffer_saver_env_counter = -every_envstep

    def _save(ctx: "OnlineRLContext"):
        """
        Overview:
            In ctx, `ctx.env_step` should not be None.
        Input of ctx:
            - env_step (:obj:`int`): env step.
        Raises:
            - RuntimeError: If `ctx.env_step` is None.
        """
        nonlocal buffer_saver_env_counter
        if ctx.env_step is None:
            raise RuntimeError("buffer_saver only supports collecting data by step rather than episode.")
        if ctx.env_step < every_envstep + buffer_saver_env_counter:
            # Not enough env steps since the last save.
            return
        buffer_saver_env_counter = ctx.env_step
        if replace:
            file_name = "data_latest.hkl"
        else:
            file_name = "data_envstep_{}.hkl".format(ctx.env_step)
        buffer_.save_data(os.path.join(cfg.exp_name, "replaybuffer", file_name))

    return _save
def offpolicy_data_fetcher(
    cfg: EasyDict,
    buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
    data_shortage_warning: bool = False,
) -> Callable:
    """
    Overview:
        The return function is a generator which mainly fetches a batch of data from a buffer, \
        a list of buffers, or a dict of buffers.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: \
            `cfg.policy.learn.batch_size` and `cfg.policy.collect.unroll_len`.
        - buffer_ (:obj:`Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]]`): \
            The buffer where the data is fetched from. \
            ``Buffer`` type means a buffer.\
            ``List[Tuple[Buffer, float]]`` type means a list of tuple. In each tuple there is a buffer and a float. \
            The float defines, how many batch_size is the size of the data \
            which is sampled from the corresponding buffer.\
            ``Dict[str, Buffer]`` type means a dict in which the value of each element is a buffer. \
            For each key-value pair of dict, batch_size of data will be sampled from the corresponding buffer \
            and assigned to the same key of `ctx.train_data`.
        - data_shortage_warning (:obj:`bool`): Whether to output warning when data shortage occurs in fetching.
    Returns:
        - middleware (:obj:`Callable`): The fetch generator function that takes the context as its only argument.
    """

    def _fetch(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - train_output (:obj:`Union[Dict, Deque[Dict]]`): This attribute should exist \
                if `buffer_` is of type Buffer and if `buffer_` use the middleware `PriorityExperienceReplay`. \
                The meta data `priority` of the sampled data in the `buffer_` will be updated \
                to the `priority` attribute of `ctx.train_output` if `ctx.train_output` is a dict, \
                or the `priority` attribute of `ctx.train_output`'s popped element \
                if `ctx.train_output` is a deque of dicts.
        Output of ctx:
            - train_data (:obj:`Union[List[Dict], Dict[str, List[Dict]]]`): The fetched data. \
                ``List[Dict]`` type means a list of data.
                    `train_data` is of this type if the type of `buffer_` is Buffer or List.
                ``Dict[str, List[Dict]]]`` type means a dict, in which the value of each key-value pair
                    is a list of data. `train_data` is of this type if the type of `buffer_` is Dict.
                It is set to None when sampling fails with ``ValueError`` or ``AssertionError``.
        """
        try:
            batch_size = cfg.policy.learn.batch_size
            unroll_len = cfg.policy.collect.unroll_len
            if isinstance(buffer_, Buffer):
                if unroll_len > 1:
                    buffered_data = buffer_.sample(batch_size, groupby="env", unroll_len=unroll_len, replace=True)
                    ctx.train_data = [[t.data for t in d] for d in buffered_data]  # B, unroll_len
                else:
                    buffered_data = buffer_.sample(batch_size)
                    ctx.train_data = [d.data for d in buffered_data]
            elif isinstance(buffer_, List):  # like sqil, r2d3
                # NOTE(review): this assertion sits inside the try block, so an unsupported unroll_len
                # is reported as a data shortage rather than as a configuration error.
                assert unroll_len == 1, "not support"
                buffered_data = []
                for buffer_elem, p in buffer_:
                    data_elem = buffer_elem.sample(int(batch_size * p))
                    assert data_elem is not None
                    # Flatten while collecting; avoids the quadratic ``sum(list_of_lists, [])``.
                    buffered_data.extend(data_elem)
                ctx.train_data = [d.data for d in buffered_data]
            elif isinstance(buffer_, Dict):  # like ppg_offpolicy
                assert unroll_len == 1, "not support"
                buffered_data = {k: v.sample(batch_size) for k, v in buffer_.items()}
                ctx.train_data = {k: [d.data for d in v] for k, v in buffered_data.items()}
            else:
                raise TypeError("not support buffer argument type: {}".format(type(buffer_)))

            assert buffered_data is not None
        except (ValueError, AssertionError):
            if data_shortage_warning:
                # You can modify data collect config to avoid this warning, e.g. increasing n_sample, n_episode.
                # Fetcher will skip this attempt.
                logging.warning(
                    "Replay buffer's data is not enough to support training, so skip this training to wait more data."
                )
            ctx.train_data = None
            # Nothing was sampled, so there is nothing to update after training either.
            return

        # Hand control back so the following middleware can train on ctx.train_data and
        # fill ctx.train_output before the priorities are written back below.
        yield

        if isinstance(buffer_, Buffer):
            if any(isinstance(m, PriorityExperienceReplay) for m in buffer_._middleware):
                index = [d.index for d in buffered_data]
                meta = [d.meta for d in buffered_data]
                # such as priority
                if isinstance(ctx.train_output, List):
                    priority = ctx.train_output.pop()['priority']
                else:
                    priority = ctx.train_output['priority']
                for idx, m, p in zip(index, meta, priority):
                    m['priority'] = p
                    buffer_.update(index=idx, data=None, meta=m)

    return _fetch
def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable:
    """
    Overview:
        Fetch batches of an in-memory dataset through a background producer thread. The producer builds \
        batches from consecutive dataset indices (a window that advances by one sample per batch), stacks \
        every field with ``torch.stack``, moves it to the target device and puts it into a bounded queue. \
        The returned function pops one batch from that queue into ``ctx.train_data``.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: \
            `cfg.policy.learn.batch_size` and `cfg.policy.cuda`.
        - dataset (:obj:`Dataset`): The dataset; it must support ``len()`` and integer indexing, and every \
            item must be a sequence of tensors.
    Returns:
        - middleware (:obj:`Callable`): The fetch function that takes the context as its only argument.
    """
    from contextlib import nullcontext
    from threading import Thread
    from queue import Queue
    import time

    # Only create a CUDA stream when CUDA is requested, so the 'cpu' path works without CUDA.
    stream = torch.cuda.Stream() if cfg.policy.cuda else None

    def producer(queue, dataset, batch_size, device):
        dataset_len = len(dataset)
        if dataset_len < batch_size:
            logging.warning('batch_size is too large!!!!')
        # Number of valid window start positions. At least one, so that a dataset which is not
        # larger than batch_size still yields (possibly shorter) batches instead of killing the thread.
        num_starts = max(dataset_len - batch_size + 1, 1)
        stream_ctx = torch.cuda.stream(stream) if stream is not None else nullcontext()
        start_idx = 0
        with stream_ctx:
            while True:
                if queue.full():
                    # Back off while the consumer catches up.
                    # NOTE(review): sleep interval is a reconstruction - confirm the intended value.
                    time.sleep(0.1)
                    continue
                stop_idx = min(start_idx + batch_size, dataset_len)
                data = [dataset[idx] for idx in range(start_idx, stop_idx)]
                # Transpose from per-sample tuples to per-field lists, then stack every field.
                data = [[i[j] for i in data] for j in range(len(data[0]))]
                data = [torch.stack(x).to(device) for x in data]
                queue.put(data)
                # Wrap around once every window position has been used.
                start_idx = (start_idx + 1) % num_starts

    queue = Queue(maxsize=50)
    device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
    producer_thread = Thread(
        target=producer, args=(queue, dataset, cfg.policy.learn.batch_size, device), name='cuda_fetcher_producer'
    )

    def _fetch(ctx: "OfflineRLContext"):
        """
        Output of ctx:
            - train_data (:obj:`List[Tensor]`): One batch taken from the producer queue.
        Raises:
            - RuntimeError: If the producer thread has terminated and no batch is left in the queue.
        """
        # ``ident`` is None until the thread has been started; a Thread can only be started once.
        if producer_thread.ident is None:
            producer_thread.start()
        while queue.empty():
            if not producer_thread.is_alive():
                raise RuntimeError("The producer thread of offline_data_fetcher_from_mem has terminated.")
            time.sleep(0.001)
        ctx.train_data = queue.get()

    return _fetch
def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable:
    """
    Overview:
        The outer function transforms a Pytorch `Dataset` to `DataLoader`. \
        The return function fetches a batch of data from that `DataLoader` each time it is called. \
        Please refer to the PyTorch documentation of `torch.utils.data.Dataset` \
        and `torch.utils.data.DataLoader` for more details.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size`.
        - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data.
    Returns:
        - middleware (:obj:`Callable`): The fetch function that takes the context as its only argument.
    """
    # collate_fn is executed in policy now
    loader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)
    batch_iter = iter(loader)

    def _fetch(ctx: "OfflineRLContext"):
        """
        Overview:
            Every time this function is called, the fetched data will be assigned to ctx.train_data. \
            After the dataloader is empty, the attribute `ctx.train_epoch` will be incremented by 1.
        Input of ctx:
            - train_epoch (:obj:`int`): Number of `train_epoch`.
            - trained_env_step (:obj:`int`): Incremented by the length of the fetched batch.
        Output of ctx:
            - train_data (:obj:`List[Tensor]`): The fetched data batch.
        """
        nonlocal batch_iter
        try:
            ctx.train_data = next(batch_iter)  # noqa
        except StopIteration:
            ctx.train_epoch += 1
            # A fresh iterator over the same loader starts the next epoch;
            # there is no need to rebuild the DataLoader itself.
            batch_iter = iter(loader)
            ctx.train_data = next(batch_iter)
        # TODO apply data update (e.g. priority) in offline setting when necessary
        ctx.trained_env_step += len(ctx.train_data)

    return _fetch
def offline_data_saver(data_path: str, data_type: str = 'hdf5') -> Callable:
    """
    Overview:
        Save the expert data of offline RL in a directory.
    Arguments:
        - data_path (:obj:`str`): File path where the expert data will be written into, which is usually \
            './expert.pkl'.
        - data_type (:obj:`str`): Define the type of the saved data. \
            The type of saved data is pkl if `data_type == 'naive'`. \
            The type of saved data is hdf5 if `data_type == 'hdf5'`.
    Returns:
        - middleware (:obj:`Callable`): The save function that takes the context as its only argument.
    """

    def _save(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - trajectories (:obj:`List[Tensor]`): The expert data to be saved.
        """
        data = ctx.trajectories
        offline_data_save_type(data, data_path, data_type)
        # The data has been handed to the saver; clear it from the context.
        ctx.trajectories = None

    return _save
def sqil_data_pusher(cfg: EasyDict, buffer_: Buffer, expert: bool) -> Callable:
    """
    Overview:
        Push trajectories into the buffer in sqil learning pipeline.
    Arguments:
        - cfg (:obj:`EasyDict`): Config.
        - buffer_ (:obj:`Buffer`): Buffer to push the data in.
        - expert (:obj:`bool`): Whether the pushed data is expert data or not. \
            In each element of the pushed data, the reward will be set to 1 if this attribute is `True`, otherwise 0.
    Returns:
        - middleware (:obj:`Callable`): The push function that takes the context as its only argument.
    """
    # ``expert`` is fixed for the lifetime of the middleware, so pick the fill function once.
    fill_reward = torch.ones_like if expert else torch.zeros_like

    def _pusher(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - trajectories (:obj:`List[Dict]`): The trajectories to be pushed.
        """
        for t in ctx.trajectories:
            # Overwrite the reward in place with a tensor of the same shape and dtype.
            t.reward = fill_reward(t.reward)
            buffer_.push(t)
        ctx.trajectories = None

    return _pusher