import os
from collections import deque
from typing import TYPE_CHECKING, Callable, List, Union, Tuple, Dict, Optional

from easydict import EasyDict
from ditk import logging
import torch

from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type
from ding.data.buffer.middleware import PriorityExperienceReplay
from ding.framework import task
from ding.utils import get_rank

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext, OfflineRLContext


def data_pusher(cfg: EasyDict, buffer_: Buffer, group_by_env: Optional[bool] = None):
    """
    Overview:
        Push episodes or trajectories into the buffer.
    Arguments:
        - cfg (:obj:`EasyDict`): Config. It is not read by this middleware.
        - buffer_ (:obj:`Buffer`): Buffer to push the data in.
        - group_by_env (:obj:`Optional[bool]`): If truthy, every transition is pushed together with the meta \
            ``{'env': t.env_data_id.item()}``; otherwise transitions are pushed without meta.
    Returns:
        - middleware (:obj:`Callable`): The push function, or ``task.void()`` when the task router is active \
            and the current process does not have the learner role.
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()

    def _push(ctx: "OnlineRLContext"):
        """
        Overview:
            In ctx, either `ctx.trajectories` or `ctx.episodes` should not be None. \
            The attribute that was consumed is reset to None afterwards.
        Input of ctx:
            - trajectories (:obj:`List[Dict]`): Trajectories.
            - episodes (:obj:`List[Dict]`): Episodes.
        """
        if ctx.trajectories is not None:  # each data in buffer is a transition
            if group_by_env:
                for t in ctx.trajectories:
                    buffer_.push(t, {'env': t.env_data_id.item()})
            else:
                for t in ctx.trajectories:
                    buffer_.push(t)
            ctx.trajectories = None
        elif ctx.episodes is not None:  # each data in buffer is an episode
            for t in ctx.episodes:
                buffer_.push(t)
            ctx.episodes = None
        else:
            raise RuntimeError("Either ctx.trajectories or ctx.episodes should be not None.")

    return _push


def buffer_saver(cfg: EasyDict, buffer_: Buffer, every_envstep: int = 1000, replace: bool = False):
    """
    Overview:
        Save current buffer data.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.exp_name`.
        - buffer_ (:obj:`Buffer`): Buffer whose data is saved.
        - every_envstep (:obj:`int`): Minimum number of env steps between two saves.
        - replace (:obj:`bool`): Whether to always write to the same file (``data_latest.hkl``) instead of \
            one file per save (``data_envstep_<env_step>.hkl``).
    """
    # Start one interval "in the past" so that the very first call already triggers a save.
    buffer_saver_env_counter = -every_envstep

    def _save(ctx: "OnlineRLContext"):
        """
        Overview:
            In ctx, `ctx.env_step` should not be None.
        Input of ctx:
            - env_step (:obj:`int`): env step.
        """
        nonlocal buffer_saver_env_counter
        if ctx.env_step is None:
            raise RuntimeError("buffer_saver only supports collecting data by step rather than episode.")
        if ctx.env_step < every_envstep + buffer_saver_env_counter:
            return  # the save interval has not elapsed yet
        buffer_saver_env_counter = ctx.env_step
        file_name = "data_latest.hkl" if replace else "data_envstep_{}.hkl".format(ctx.env_step)
        buffer_.save_data(os.path.join(cfg.exp_name, "replaybuffer", file_name))

    return _save


def offpolicy_data_fetcher(
        cfg: EasyDict,
        buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
        data_shortage_warning: bool = False,
) -> Callable:
    """
    Overview:
        The return function is a generator which mainly fetches a batch of data from a buffer, \
        a list of buffers, or a dict of buffers.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size` \
            and `cfg.policy.collect.unroll_len`.
        - buffer_ (:obj:`Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]]`): \
            The buffer where the data is fetched from. \
            ``Buffer`` type means a buffer.\
            ``List[Tuple[Buffer, float]]`` type means a list of tuple. In each tuple there is a buffer and a float. \
            The float defines, how many batch_size is the size of the data \
            which is sampled from the corresponding buffer.\
            ``Dict[str, Buffer]`` type means a dict in which the value of each element is a buffer. \
            For each key-value pair of dict, batch_size of data will be sampled from the corresponding buffer \
            and assigned to the same key of `ctx.train_data`.
        - data_shortage_warning (:obj:`bool`): Whether to output warning when data shortage occurs in fetching.
    """

    def _ensure_sampled(sampled):
        # Treat a missing sample result like a data shortage. ValueError is one of the exceptions the
        # fetcher catches, and unlike ``assert`` this check is not stripped under ``python -O``.
        if sampled is None:
            raise ValueError("buffer returned no data")
        return sampled

    def _fetch(ctx: "OnlineRLContext"):
        """
        Overview:
            Sample one batch into `ctx.train_data`, yield, and afterwards write priorities back to the buffer. \
            If sampling raises ``ValueError`` or ``AssertionError``, `ctx.train_data` is set to None and the \
            generator finishes without yielding.
        Input of ctx:
            - train_output (:obj:`Union[Dict, List[Dict], Deque[Dict]]`): This attribute should exist \
                if `buffer_` is of type Buffer and if `buffer_` use the middleware `PriorityExperienceReplay`. \
                The meta data `priority` of the sampled data in the `buffer_` will be updated \
                to the `priority` attribute of `ctx.train_output` if `ctx.train_output` is a dict, \
                or the `priority` attribute of `ctx.train_output`'s popped element \
                if `ctx.train_output` is a list or deque of dicts.
        Output of ctx:
            - train_data (:obj:`Union[List[Dict], Dict[str, List[Dict]]]`): The fetched data. \
                ``List[Dict]`` type means a list of data.
                    `train_data` is of this type if the type of `buffer_` is Buffer or List.
                ``Dict[str, List[Dict]]]`` type means a dict, in which the value of each key-value pair
                    is a list of data. `train_data` is of this type if the type of `buffer_` is Dict.
        Raises:
            - TypeError: `buffer_` is neither a Buffer, a list nor a dict.
            - ValueError: `buffer_` is a list or dict and `cfg.policy.collect.unroll_len` is not 1.
        """
        unroll_len = cfg.policy.collect.unroll_len
        batch_size = cfg.policy.learn.batch_size

        # Validate the configuration before entering the ``try`` below; otherwise a misconfiguration
        # would be caught there and silently reported as a data shortage on every call.
        if not isinstance(buffer_, Buffer):
            if not isinstance(buffer_, (list, dict)):
                raise TypeError("not support buffer argument type: {}".format(type(buffer_)))
            if unroll_len != 1:
                raise ValueError(
                    "unroll_len must be 1 when fetching from a list or dict of buffers, got {}".format(unroll_len)
                )

        try:
            if isinstance(buffer_, Buffer):
                if unroll_len > 1:
                    buffered_data = _ensure_sampled(
                        buffer_.sample(batch_size, groupby="env", unroll_len=unroll_len, replace=True)
                    )
                    ctx.train_data = [[t.data for t in d] for d in buffered_data]  # B, unroll_len
                else:
                    buffered_data = _ensure_sampled(buffer_.sample(batch_size))
                    ctx.train_data = [d.data for d in buffered_data]
            elif isinstance(buffer_, list):  # like sqil, r2d3
                buffered_data = []
                for buffer_elem, p in buffer_:
                    buffered_data.extend(_ensure_sampled(buffer_elem.sample(int(batch_size * p))))
                ctx.train_data = [d.data for d in buffered_data]
            else:  # dict, like ppg_offpolicy
                buffered_data = {k: _ensure_sampled(v.sample(batch_size)) for k, v in buffer_.items()}
                ctx.train_data = {k: [d.data for d in v] for k, v in buffered_data.items()}
        except (ValueError, AssertionError):
            if data_shortage_warning:
                # You can modify data collect config to avoid this warning, e.g. increasing n_sample, n_episode.
                # Fetcher will skip this attempt.
                logging.warning(
                    "Replay buffer's data is not enough to support training, so skip this training to wait more data."
                )
            ctx.train_data = None
            return

        yield

        if isinstance(buffer_, Buffer) and any(isinstance(m, PriorityExperienceReplay) for m in buffer_._middleware):
            # NOTE(review): with unroll_len > 1 every element of buffered_data is a list of samples, so the
            # attribute reads below look like they would not work - confirm whether that combination is used.
            index = [d.index for d in buffered_data]
            meta = [d.meta for d in buffered_data]
            # such as priority
            if isinstance(ctx.train_output, (list, deque)):
                # Only the most recently appended output is consumed.
                priority = ctx.train_output.pop()['priority']
            else:
                priority = ctx.train_output['priority']
            for idx, m, p in zip(index, meta, priority):
                m['priority'] = p
                buffer_.update(index=idx, data=None, meta=m)

    return _fetch


def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable:
    """
    Overview:
        Build a fetcher that is fed by a background thread. The thread repeatedly takes `batch_size` \
        consecutive items of `dataset` (the start index advances by one per batch and wraps around), \
        stacks them field by field with ``torch.stack``, moves them to the target device and puts the \
        result into a bounded queue. The thread is started lazily on the first call of the returned function.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size` \
            and `cfg.policy.cuda`.
        - dataset (:obj:`Dataset`): The dataset to read from; it is accessed through ``len()`` and indexing.
    """
    from threading import Thread
    from queue import Queue
    import time

    # A CUDA stream can only be created when CUDA is actually used; CPU runs work without one.
    stream = torch.cuda.Stream() if cfg.policy.cuda else None

    def producer(queue, dataset, batch_size, device):
        torch.set_num_threads(4)
        # Number of valid start indices: the last full window starts at len(dataset) - batch_size.
        num_starts = len(dataset) - batch_size + 1
        if num_starts <= 0:
            logging.warning('batch_size is too large!!!!')
            return  # no full batch can ever be built; the consumer detects the stopped thread

        def _produce_forever():
            start_idx = 0
            while True:
                if queue.full():
                    time.sleep(0.1)
                    continue
                data = [dataset[idx] for idx in range(start_idx, start_idx + batch_size)]
                # Transpose from "list of samples" to "list of fields" before stacking.
                data = [[i[j] for i in data] for j in range(len(data[0]))]
                data = [torch.stack(x).to(device) for x in data]
                queue.put(data)
                start_idx = (start_idx + 1) % num_starts

        if stream is None:
            _produce_forever()
        else:
            with torch.cuda.stream(stream):
                _produce_forever()

    queue = Queue(maxsize=50)

    device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
    # Daemon thread: the producer loops forever and must not keep the interpreter alive at exit.
    producer_thread = Thread(
        target=producer,
        args=(queue, dataset, cfg.policy.learn.batch_size, device),
        name='cuda_fetcher_producer',
        daemon=True,
    )
    producer_started = False

    def _fetch(ctx: "OfflineRLContext"):
        """
        Output of ctx:
            - train_data (:obj:`List[Tensor]`): One stacked tensor per field of a dataset item.
        Raises:
            - RuntimeError: The producer thread has stopped and the queue is empty.
        """
        nonlocal producer_started
        if not producer_started:
            time.sleep(5)  # NOTE(review): the reason for this start-up delay is not visible here - confirm
            producer_thread.start()
            producer_started = True
        while queue.empty():
            if not producer_thread.is_alive():
                # Without this check a dead producer would make this loop spin forever.
                raise RuntimeError("The producer thread of offline_data_fetcher_from_mem has stopped.")
            time.sleep(0.001)
        ctx.train_data = queue.get()

    return _fetch


def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable:
    """
    Overview:
        The outer function transforms a Pytorch `Dataset` to `DataLoader`. \
        The return function fetches one batch of data from that `DataLoader` each time it is called.\
        Please refer to the link https://pytorch.org/tutorials/beginner/basics/data_tutorial.html \
        and https://pytorch.org/docs/stable/data.html for more details.
    Arguments:
        - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size`.
        - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data.
    """

    def _new_iterator():
        # collate_fn is executed in policy now
        return iter(
            DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)
        )

    dataloader = _new_iterator()

    def _fetch(ctx: "OfflineRLContext"):
        """
        Overview:
            Every time this function is called, the fetched data will be assigned to ctx.train_data. \
            After the dataloader is empty, the attribute `ctx.train_epoch` will be incremented by 1 \
            and a new shuffled pass over the dataset is started.
        Input of ctx:
            - train_epoch (:obj:`int`): Number of `train_epoch`.
            - trained_env_step (:obj:`int`): Incremented by the length of the fetched batch.
        Output of ctx:
            - train_data (:obj:`List[Tensor]`): The fetched data batch.
        """
        nonlocal dataloader
        try:
            ctx.train_data = next(dataloader)  # noqa
        except StopIteration:
            ctx.train_epoch += 1
            del dataloader  # drop the exhausted iterator before the next one is built
            dataloader = _new_iterator()
            ctx.train_data = next(dataloader)
        # TODO apply data update (e.g. priority) in offline setting when necessary
        ctx.trained_env_step += len(ctx.train_data)

    return _fetch


def offline_data_saver(data_path: str, data_type: str = 'hdf5') -> Callable:
    """
    Overview:
        Save the expert data of offline RL in a directory.
    Arguments:
        - data_path (:obj:`str`): File path where the expert data will be written into, which is usually \
            './expert.pkl'.
        - data_type (:obj:`str`): Define the type of the saved data. \
            The type of saved data is pkl if `data_type == 'naive'`. \
            The type of saved data is hdf5 if `data_type == 'hdf5'`.
    """

    def _save(ctx: "OnlineRLContext"):
        """
        Overview:
            Write `ctx.trajectories` with ``offline_data_save_type`` and reset the attribute to None.
        Input of ctx:
            - trajectories (:obj:`List[Tensor]`): The expert data to be saved.
        """
        offline_data_save_type(ctx.trajectories, data_path, data_type)
        ctx.trajectories = None

    return _save


def sqil_data_pusher(cfg: EasyDict, buffer_: Buffer, expert: bool) -> Callable:
    """
    Overview:
        Push trajectories into the buffer in sqil learning pipeline.
    Arguments:
        - cfg (:obj:`EasyDict`): Config. It is not read by this middleware.
        - buffer_ (:obj:`Buffer`): Buffer to push the data in.
        - expert (:obj:`bool`): Whether the pushed data is expert data or not. \
            In each element of the pushed data, the reward will be set to 1 if this attribute is `True`, otherwise 0.
    """
    # ``expert`` is fixed for the lifetime of the middleware, so choose the fill function once.
    make_reward = torch.ones_like if expert else torch.zeros_like

    def _pusher(ctx: "OnlineRLContext"):
        """
        Overview:
            Overwrite the reward of every trajectory element, push it, and reset `ctx.trajectories` to None.
        Input of ctx:
            - trajectories (:obj:`List[Dict]`): The trajectories to be pushed.
        """
        for t in ctx.trajectories:
            t.reward = make_reward(t.reward)
            buffer_.push(t)
        ctx.trajectories = None

    return _pusher