import numpy as np
from time import sleep, time
from dataclasses import fields
from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union
from ditk import logging
from ding.framework import task
from ding.data import StorageLoader, Storage, ModelLoader

if TYPE_CHECKING:
    from ding.framework.context import Context
    from torch.nn import Module


class ContextExchanger:

    def __init__(self, skip_n_iter: int = 1, storage_loader: Optional[StorageLoader] = None) -> None:
        """
        Overview:
            Exchange context between processes. Supported properties: trajectories, episodes, \
            env_step, env_episode, train_iter.
        Arguments:
            - skip_n_iter (:obj:`int`): For collectors, it may be necessary to skip waiting \
                for the first n iterations in order to collect data for the learner to learn. \
                This parameter has no effect on the learner.
            - storage_loader (:obj:`Optional[StorageLoader]`): Turn data into a storage class to reduce \
                the network overhead.
        """
        if not task.router.is_active:
            raise RuntimeError("ContextExchanger should be used in parallel mode!")
        self._state = {}
        self._local_state = {}  # Only saves local state, will not be sent to remote nodes
        if task.has_role(task.role.COLLECTOR):
            self._local_state['env_step'] = 0
            self._local_state['env_episode'] = 0
        self._event_name = "context_exchanger_{role}"
        self._skip_n_iter = skip_n_iter
        self._storage_loader = storage_loader
        for role in task.role:  # Only subscribe to events of other roles
            if not task.has_role(role):
                task.on(self._event_name.format(role=role), self.put)
        if storage_loader:
            task.once("finish", lambda _: storage_loader.shutdown())

    def __new__(cls, *args, **kwargs):
        if not task.router.is_active:
            return task.void()

        if len(task.roles) == 0:
            logging.warning("The task does not have any roles defined, the ContextExchanger will not work.")
            return task.void()

        if len(task.roles) > 1:
            logging.warning(
                "Using multiple roles in one exchanger may lead to unexpected results, please check your code."
            )

        return super(ContextExchanger, cls).__new__(cls)

    def __call__(self, ctx: "Context"):
        self.merge(ctx)
        yield
        payload = self.fetch(ctx)
        if payload:
            if self._storage_loader and task.has_role(task.role.COLLECTOR):
                payload = self._storage_loader.save(payload)
            for role in task.roles:
                task.emit(self._event_name.format(role=role), payload, only_remote=True)

    def __del__(self):
        if self._storage_loader:
            self._storage_loader.shutdown()

    def put(self, payload: Union[Dict, Storage]):
        """
        Overview:
            Store attributes received from remote processes; called as the callback of the \
            subscribed events. Each attribute should have a standalone put handler, which is \
            named `_put_{key}`.
        """

        def callback(payload: Dict):
            for key, item in payload.items():
                fn_name = "_put_{}".format(key)
                if hasattr(self, fn_name):
                    getattr(self, fn_name)(item)
                else:
                    logging.warning("Receive unexpected key ({}) in context exchanger".format(key))

        if isinstance(payload, Storage):
            assert self._storage_loader is not None, "Storage loader is not defined when data is a storage object."
            self._storage_loader.load(payload, callback)
        else:
            callback(payload)

    def fetch(self, ctx: "Context") -> Dict[str, Any]:
        """
        Overview:
            Fetch attributes from ctx before emitting them on the event bus. Each attribute \
            should have a standalone fetch handler, which is named `_fetch_{key}`.
        """
        payload = {}
        for field in fields(ctx):
            key, item = field.name, getattr(ctx, field.name)
            fn_name = "_fetch_{}".format(key)
            if hasattr(self, fn_name):
                value = getattr(self, fn_name)(item)
                if value is not None:
                    payload[key] = value
        return payload

    def merge(self, ctx: "Context"):
        if task.has_role(task.role.LEARNER):
            # Learner should always wait for trajs.
            # TODO: Automatically wait based on properties, not roles.
            while len(self._state) == 0:
                sleep(0.01)
        elif ctx.total_step >= self._skip_n_iter:
            start = time()
            while len(self._state) == 0:
                if time() - start > 60:
                    logging.warning("Timeout when waiting for new context! Node id: {}".format(task.router.node_id))
                    break
                sleep(0.01)

        for k, v in self._state.items():
            if not task.has_role(task.role.COLLECTOR) and k.startswith('increment_'):
                # Incremental properties (e.g. increment_env_step) are accumulated onto the
                # local counterpart instead of overwriting it.
                pure_k = k.split('increment_')[-1]
                setattr(ctx, pure_k, getattr(ctx, pure_k) + v)
            else:
                setattr(ctx, k, v)
        self._state = {}

    # Handlers for each attribute of the context
    def _put_trajectories(self, traj: List[Any]):
        if not task.has_role(task.role.LEARNER):
            return
        if "trajectories" not in self._state:
            self._state["trajectories"] = []
        self._state["trajectories"].extend(traj)

    def _fetch_trajectories(self, traj: List[Any]):
        if task.has_role(task.role.COLLECTOR):
            return traj

    def _put_episodes(self, episodes: List[Any]):
        if not task.has_role(task.role.LEARNER):
            return
        if "episodes" not in self._state:
            self._state["episodes"] = []
        self._state["episodes"].extend(episodes)

    def _fetch_episodes(self, episodes: List[Any]):
        if task.has_role(task.role.COLLECTOR):
            return episodes

    def _put_trajectory_end_idx(self, trajectory_end_idx: List[str]):
        if not task.has_role(task.role.LEARNER):
            return
        if "trajectory_end_idx" not in self._state:
            self._state["trajectory_end_idx"] = []
        self._state["trajectory_end_idx"].extend(trajectory_end_idx)

    def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[str]):
        if task.has_role(task.role.COLLECTOR):
            return trajectory_end_idx

    def _put_env_step(self, increment_env_step: int):
        if not task.has_role(task.role.COLLECTOR):
            if 'increment_env_step' not in self._state:
                self._state['increment_env_step'] = 0
            self._state['increment_env_step'] += increment_env_step

    def _fetch_env_step(self, env_step: int):
        if task.has_role(task.role.COLLECTOR):
            increment_env_step = env_step - self._local_state['env_step']
            self._local_state['env_step'] = env_step
            return increment_env_step

    def _put_env_episode(self, increment_env_episode: int):
        if not task.has_role(task.role.COLLECTOR):
            if 'increment_env_episode' not in self._state:
                self._state['increment_env_episode'] = 0
            self._state['increment_env_episode'] += increment_env_episode

    def _fetch_env_episode(self, env_episode: int):
        if task.has_role(task.role.COLLECTOR):
            increment_env_episode = env_episode - self._local_state['env_episode']
            self._local_state['env_episode'] = env_episode
            return increment_env_episode

    def _put_train_iter(self, train_iter: int):
        if not task.has_role(task.role.LEARNER):
            self._state["train_iter"] = train_iter

    def _fetch_train_iter(self, train_iter: int):
        if task.has_role(task.role.LEARNER):
            return train_iter

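# A minimal usage sketch (assumptions: a parallel pipeline launched through ding's
# ``Parallel`` runner; ``OnlineRLContext`` and the collector middleware name below are
# placeholders for whatever the pipeline actually uses):
#
#     from ding.framework import task, OnlineRLContext
#
#     with task.start(ctx=OnlineRLContext()):
#         task.add_role(task.role.COLLECTOR)  # task.role.LEARNER on the other node
#         task.use(ContextExchanger(skip_n_iter=1))
#         task.use(my_collector_middleware)  # hypothetical middleware
#         task.run()
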
class ModelExchanger:

    def __init__(self, model: "Module", model_loader: Optional[ModelLoader] = None) -> None:
        """
        Overview:
            Exchange a model between processes: only the learner will send the model, \
            every other process will only receive it. If you are using a shared model \
            on a single host, there is no need to use this middleware.
        Arguments:
            - model (:obj:`torch.nn.Module`): Pytorch module.
            - model_loader (:obj:`ModelLoader`): Encode the model in a subprocess.
        """
        self._model = model
        self._model_loader = model_loader
        self._event_name = "model_exchanger"
        self._state_dict_cache: Optional[Union[object, Storage]] = None
        self._is_learner = task.has_role(task.role.LEARNER)
        if not self._is_learner:
            task.on(self._event_name, self._cache_state_dict)
        if model_loader:
            task.once("finish", lambda _: model_loader.shutdown())

    def _cache_state_dict(self, state_dict: Union[object, Storage]):
        self._state_dict_cache = state_dict

    def __new__(cls, *args, **kwargs):
        if not task.router.is_active:
            return task.void()

        if len(task.roles) == 0:
            logging.warning("The task does not have any roles defined, the ModelExchanger will not work.")
            return task.void()

        if len(task.roles) > 1:
            logging.warning(
                "Using multiple roles in one exchanger may lead to unexpected results, please check your code."
            )

        return super(ModelExchanger, cls).__new__(cls)

    def __call__(self, ctx: "Context") -> Any:
        if self._model_loader:
            self._model_loader.start()

        if not self._is_learner:
            if ctx.total_step != 0:  # Skip the first iteration
                self._update_model()
        else:
            yield
            self._send_model()

    def _update_model(self):
        start = time()
        while True:
            if task.finish:
                return
            if time() - start > 60:
                logging.warning("Timeout when waiting for new model! Node id: {}".format(task.router.node_id))
                break
            if self._state_dict_cache is None:
                sleep(0.01)
            else:
                if isinstance(self._state_dict_cache, Storage) and self._model_loader is not None:
                    try:
                        self._model.load_state_dict(self._model_loader.load(self._state_dict_cache))
                        self._state_dict_cache = None
                        break
                    except FileNotFoundError:
                        logging.warning(
                            "Model file has been deleted on node {}, maybe you can increase the ttl.".format(
                                task.router.node_id
                            )
                        )
                        self._state_dict_cache = None
                        continue
                else:
                    self._model.load_state_dict(self._state_dict_cache)
                    self._state_dict_cache = None
                    break

    def _send_model(self):
        if self._model_loader:
            self._model_loader.save(self._send_callback)
        else:
            task.emit(self._event_name, self._model.state_dict(), only_remote=True)

    def _send_callback(self, storage: Storage):
        if task.running:
            task.emit(self._event_name, storage, only_remote=True)

    def __del__(self):
        if self._model_loader:
            self._model_loader.shutdown()

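# A minimal usage sketch (assumption: same pipeline style as above; ``policy`` is a
# hypothetical placeholder for the object owning the torch module being synchronized).
# Placing the exchanger before the trainer/collector middleware lets a receiver load
# the freshest state dict before each iteration's work:
#
#     with task.start(ctx=OnlineRLContext()):
#         task.add_role(task.role.LEARNER)  # the learner sends, other roles receive
#         task.use(ModelExchanger(policy.model))
#         task.use(my_trainer_middleware)  # hypothetical middleware
#         task.run()
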
""" self._model = model self._model_loader = model_loader self._event_name = "model_exchanger" self._state_dict_cache: Optional[Union[object, Storage]] = None self._is_learner = task.has_role(task.role.LEARNER) if not self._is_learner: task.on(self._event_name, self._cache_state_dict) if model_loader: task.once("finish", lambda _: model_loader.shutdown()) def _cache_state_dict(self, state_dict: Union[object, Storage]): self._state_dict_cache = state_dict def __new__(cls, *args, **kwargs): if not task.router.is_active: return task.void() if len(task.roles) == 0: logging.warning("The task does not have any roles defined, the ModelExchanger will not work.") return task.void() if len(task.roles) > 1: logging.warning( "Use multiple roles in one exchanger may lead to unexpected result, please check your code." ) return super(ModelExchanger, cls).__new__(cls) def __call__(self, ctx: "Context") -> Any: if self._model_loader: self._model_loader.start() if not self._is_learner: if ctx.total_step != 0: # Skip first iteration self._update_model() else: yield self._send_model() def _update_model(self): start = time() while True: if task.finish: return if time() - start > 60: logging.warning("Timeout when waiting for new model! Node id: {}".format(task.router.node_id)) break if self._state_dict_cache is None: sleep(0.01) else: if isinstance(self._state_dict_cache, Storage) and self._model_loader is not None: try: self._model.load_state_dict(self._model_loader.load(self._state_dict_cache)) self._state_dict_cache = None break except FileNotFoundError as e: logging.warning( "Model file has been deleted on node {}, maybe you can increase the ttl.".format( task.router.node_id ) ) self._state_dict_cache = None continue else: self._model.load_state_dict(self._state_dict_cache) self._state_dict_cache = None break def _send_model(self): if self._model_loader: self._model_loader.save(self._send_callback) else: task.emit(self._event_name, self._model.state_dict(), only_remote=True) def _send_callback(self, storage: Storage): if task.running: task.emit(self._event_name, storage, only_remote=True) def __del__(self): if self._model_loader: self._model_loader.shutdown() class PeriodicalModelExchanger: def __init__( self, model: "Module", mode: str, period: int = 1, delay_toleration: float = np.inf, stale_toleration: int = 1, event_name: str = "model_exchanger", model_loader: Optional[ModelLoader] = None ) -> None: """ Overview: Exchange model between processes, set the mode to "send" or "receive" to specify the role of the process. If you are using a shared model on a single host, there is no need to use this middleware. Arguments: - model (:obj:`torch.nn.Module`): Pytorch module. - mode (:obj:`str`): "send" or "receive". - period (:obj:`int`): The period of model exchange. - delay_toleration (:obj:`float`): The permitted time interval for receiving model after being sent. - stale_toleration (:obj:`int`): The permitted number of iterations for receiving model after being sent. - event_name (:obj:`str`): The event name for model exchange. - model_loader (:obj:`ModelLoader`): ModelLoader for this PeriodicalModelExchanger to use. 
""" self._model = model self._model_loader = model_loader self._event_name = event_name self._period = period self._mode = mode if self._mode == "receive": self._id_counter = -1 self._model_id = -1 else: self._id_counter = 0 self._stale_toleration = stale_toleration self._model_stale = stale_toleration self._delay_toleration = delay_toleration self._state_dict_cache: Optional[Union[object, Storage]] = None if self._mode == "receive": task.on(self._event_name, self._cache_state_dict) if model_loader: task.once("finish", lambda _: model_loader.shutdown()) def _cache_state_dict(self, msg: Dict[str, Any]): if msg['id'] % self._period == 0: self._state_dict_cache = msg['model'] self._id_counter = msg['id'] self._time = msg['time'] def __new__(cls, *args, **kwargs): return super(PeriodicalModelExchanger, cls).__new__(cls) def __call__(self, ctx: "Context") -> Any: if self._model_loader: self._model_loader.start() if self._mode == "receive": if ctx.total_step != 0: # Skip first iteration self._update_model() elif self._mode == "send": yield if self._id_counter % self._period == 0: self._send_model(id=self._id_counter) self._id_counter += 1 else: raise NotImplementedError def _update_model(self): start = time() while True: if task.finish: return if time() - start > 60: logging.warning("Timeout when waiting for new model! Node id: {}".format(task.router.node_id)) self._model_stale += 1 break if self._state_dict_cache is None: if self._model_stale < self._stale_toleration and time() - self._time < self._delay_toleration: self._model_stale += 1 break else: sleep(0.01) else: if self._id_counter > self._model_id and time() - self._time < self._delay_toleration: if isinstance(self._state_dict_cache, Storage) and self._model_loader is not None: try: self._model.load_state_dict(self._model_loader.load(self._state_dict_cache)) self._state_dict_cache = None self._model_id = self._id_counter self._model_stale = 1 break except FileNotFoundError as e: logging.warning( "Model file has been deleted on node {}, maybe you can increase the ttl.".format( task.router.node_id ) ) self._state_dict_cache = None continue else: self._model.load_state_dict(self._state_dict_cache) self._state_dict_cache = None self._model_id = self._id_counter self._model_stale = 1 break else: self._model_stale += 1 def _send_model(self, id: int): if self._model_loader: self._model_loader.save(self._send_callback) else: task.emit(self._event_name, {'id': id, 'model': self._model.state_dict(), 'time': time()}, only_remote=True) def _send_callback(self, storage: Storage): if task.running: task.emit(self._event_name, storage, only_remote=True) def __del__(self): if self._model_loader: self._model_loader.shutdown()