import numpy as np
from time import sleep, time
from dataclasses import fields
from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union
from ditk import logging
from ding.framework import task
from ding.data import StorageLoader, Storage, ModelLoader
if TYPE_CHECKING:
    from ding.framework.context import Context
    from torch.nn import Module


class ContextExchanger:

    def __init__(self, skip_n_iter: int = 1, storage_loader: Optional[StorageLoader] = None) -> None:
        """
        Overview:
            Exchange context between processes. Supported properties:
            trajectories, episodes, env_step, env_episode, train_iter.
        Arguments:
            - skip_n_iter (:obj:`int`): For collectors, it may be necessary to skip waiting \
                during the first n iterations so that data can be collected for the learner \
                to train on. This parameter has no effect on the learner.
            - storage_loader (:obj:`Optional[StorageLoader]`): Convert data into a storage object \
                to reduce network overhead.
        """
        if not task.router.is_active:
            raise RuntimeError("ContextExchanger should be used in parallel mode!")
        self._state = {}
        self._local_state = {}  # Local-only bookkeeping, never sent to remote nodes
        if task.has_role(task.role.COLLECTOR):
            self._local_state['env_step'] = 0
            self._local_state['env_episode'] = 0
        self._event_name = "context_exchanger_{role}"
        self._skip_n_iter = skip_n_iter
        self._storage_loader = storage_loader
        # Subscribe to the events emitted by every role this process does not hold.
        for role in task.role:
            if not task.has_role(role):
                task.on(self._event_name.format(role=role), self.put)
        if storage_loader:
            task.once("finish", lambda _: storage_loader.shutdown())

    def __new__(cls, *args, **kwargs):
        if not task.router.is_active:
            return task.void()

        if len(task.roles) == 0:
            logging.warning("The task does not have any roles defined, the ContextExchanger will not work.")
            return task.void()

        if len(task.roles) > 1:
            logging.warning(
                "Using multiple roles in one exchanger may lead to unexpected results, please check your code."
            )

        return super(ContextExchanger, cls).__new__(cls)

    def __call__(self, ctx: "Context"):
        # Merge any context received from other roles before this iteration runs,
        # then broadcast the locally fetched payload afterwards.
        self.merge(ctx)
        yield
        payload = self.fetch(ctx)
        if payload:
            if self._storage_loader and task.has_role(task.role.COLLECTOR):
                payload = self._storage_loader.save(payload)
            for role in task.roles:
                task.emit(self._event_name.format(role=role), payload, only_remote=True)

    def __del__(self):
        if self._storage_loader:
            self._storage_loader.shutdown()

    def put(self, payload: Union[Dict, Storage]):
        """
        Overview:
            Store attributes received from the event bus.
            Each attribute should have a standalone put handler, named `_put_{key}`.
        """

        def callback(payload: Dict):
            for key, item in payload.items():
                fn_name = "_put_{}".format(key)
                if hasattr(self, fn_name):
                    getattr(self, fn_name)(item)
                else:
                    logging.warning("Received unexpected key ({}) in context exchanger".format(key))

        if isinstance(payload, Storage):
            assert self._storage_loader is not None, "Storage loader is not defined when data is a storage object."
            self._storage_loader.load(payload, callback)
        else:
            callback(payload)

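    # For example, a payload {'trajectories': [...]} received on the learner is
    # dispatched to self._put_trajectories; keys without a matching `_put_*`
    # handler only trigger a warning and are dropped.
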
    def fetch(self, ctx: "Context") -> Dict[str, Any]:
        """
        Overview:
            Fetch attributes from ctx before emitting them on the event bus.
            Each attribute should have a standalone fetch handler, named `_fetch_{key}`.
        """
        payload = {}
        for field in fields(ctx):
            key, item = field.name, getattr(ctx, field.name)
            fn_name = "_fetch_{}".format(key)
            if hasattr(self, fn_name):
                value = getattr(self, fn_name)(item)
                if value is not None:
                    payload[key] = value
        return payload

    def merge(self, ctx: "Context"):
        if task.has_role(task.role.LEARNER):
            # The learner always blocks until at least one payload has arrived.
            while len(self._state) == 0:
                sleep(0.01)
        elif ctx.total_step >= self._skip_n_iter:
            # Other roles only start waiting after the first skip_n_iter iterations,
            # and give up after a 60s timeout.
            start = time()
            while len(self._state) == 0:
                if time() - start > 60:
                    logging.warning("Timeout when waiting for new context! Node id: {}".format(task.router.node_id))
                    break
                sleep(0.01)

        # Increment-style properties (prefixed with "increment_") are accumulated
        # onto the local value instead of overwriting it.
        for k, v in self._state.items():
            if not task.has_role(task.role.COLLECTOR) and k.startswith('increment_'):
                pure_k = k.split('increment_')[-1]
                setattr(ctx, pure_k, getattr(ctx, pure_k) + v)
            else:
                setattr(ctx, k, v)
        self._state = {}

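    # For example, a learner whose ctx.env_step is 100 that receives
    # {'increment_env_step': 5} ends up with ctx.env_step == 105, while a plain
    # {'train_iter': 3} simply overwrites ctx.train_iter.
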
    # Below are the put/fetch handlers for each exchanged property.

    def _put_trajectories(self, traj: List[Any]):
        if not task.has_role(task.role.LEARNER):
            return
        if "trajectories" not in self._state:
            self._state["trajectories"] = []
        self._state["trajectories"].extend(traj)

    def _fetch_trajectories(self, traj: List[Any]):
        if task.has_role(task.role.COLLECTOR):
            return traj

    def _put_episodes(self, episodes: List[Any]):
        if not task.has_role(task.role.LEARNER):
            return
        if "episodes" not in self._state:
            self._state["episodes"] = []
        self._state["episodes"].extend(episodes)

    def _fetch_episodes(self, episodes: List[Any]):
        if task.has_role(task.role.COLLECTOR):
            return episodes

    def _put_trajectory_end_idx(self, trajectory_end_idx: List[str]):
        if not task.has_role(task.role.LEARNER):
            return
        if "trajectory_end_idx" not in self._state:
            self._state["trajectory_end_idx"] = []
        self._state["trajectory_end_idx"].extend(trajectory_end_idx)

    def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[str]):
        if task.has_role(task.role.COLLECTOR):
            return trajectory_end_idx

    def _put_env_step(self, increment_env_step: int):
        if not task.has_role(task.role.COLLECTOR):
            if 'increment_env_step' not in self._state:
                self._state['increment_env_step'] = 0
            self._state["increment_env_step"] += increment_env_step

    def _fetch_env_step(self, env_step: int):
        # Collectors send only the increment since the last fetch, so the
        # contributions of multiple collectors can be accumulated on the receiver.
        if task.has_role(task.role.COLLECTOR):
            increment_env_step = env_step - self._local_state['env_step']
            self._local_state['env_step'] = env_step
            return increment_env_step

    def _put_env_episode(self, increment_env_episode: int):
        if not task.has_role(task.role.COLLECTOR):
            if 'increment_env_episode' not in self._state:
                self._state['increment_env_episode'] = 0
            self._state["increment_env_episode"] += increment_env_episode

    def _fetch_env_episode(self, env_episode: int):
        if task.has_role(task.role.COLLECTOR):
            increment_env_episode = env_episode - self._local_state['env_episode']
            self._local_state['env_episode'] = env_episode
            return increment_env_episode

    def _put_train_iter(self, train_iter: int):
        if not task.has_role(task.role.LEARNER):
            self._state["train_iter"] = train_iter

    def _fetch_train_iter(self, train_iter: int):
        if task.has_role(task.role.LEARNER):
            return train_iter

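# A minimal usage sketch (illustrative, not part of this module): in a parallel
# pipeline the exchanger is registered with `task.use`, so that context
# properties flow between collector and learner processes. The entry point
# `main` below is hypothetical.
#
#     from ding.framework import task, Parallel
#     from ding.framework.context import OnlineRLContext
#
#     def main():
#         with task.start(ctx=OnlineRLContext()):
#             task.use(ContextExchanger(skip_n_iter=1))
#             ...  # add collector / learner middleware here
#
#     Parallel.runner(n_parallel_workers=2)(main)

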
class ModelExchanger:

    def __init__(self, model: "Module", model_loader: Optional[ModelLoader] = None) -> None:
        """
        Overview:
            Exchange a model between processes: only the learner sends the model,
            while every other role only receives it.
            If you are using a shared model on a single host, there is no need for this middleware.
        Arguments:
            - model (:obj:`torch.nn.Module`): Pytorch module.
            - model_loader (:obj:`ModelLoader`): Encode the model in a subprocess.
        """
        self._model = model
        self._model_loader = model_loader
        self._event_name = "model_exchanger"
        self._state_dict_cache: Optional[Union[object, Storage]] = None
        self._is_learner = task.has_role(task.role.LEARNER)
        if not self._is_learner:
            task.on(self._event_name, self._cache_state_dict)
        if model_loader:
            task.once("finish", lambda _: model_loader.shutdown())

    def _cache_state_dict(self, state_dict: Union[object, Storage]):
        self._state_dict_cache = state_dict

    def __new__(cls, *args, **kwargs):
        if not task.router.is_active:
            return task.void()

        if len(task.roles) == 0:
            logging.warning("The task does not have any roles defined, the ModelExchanger will not work.")
            return task.void()

        if len(task.roles) > 1:
            logging.warning(
                "Using multiple roles in one exchanger may lead to unexpected results, please check your code."
            )

        return super(ModelExchanger, cls).__new__(cls)

    def __call__(self, ctx: "Context") -> Any:
        if self._model_loader:
            self._model_loader.start()

        if not self._is_learner:
            if ctx.total_step != 0:  # Skip the first iteration, no model has been sent yet
                self._update_model()
        else:
            yield
            self._send_model()

    def _update_model(self):
        start = time()
        while True:
            if task.finish:
                return
            if time() - start > 60:
                logging.warning("Timeout when waiting for new model! Node id: {}".format(task.router.node_id))
                break
            if self._state_dict_cache is None:
                sleep(0.01)
            else:
                if isinstance(self._state_dict_cache, Storage) and self._model_loader is not None:
                    try:
                        self._model.load_state_dict(self._model_loader.load(self._state_dict_cache))
                        self._state_dict_cache = None
                        break
                    except FileNotFoundError:
                        logging.warning(
                            "Model file has been deleted on node {}, consider increasing the TTL.".format(
                                task.router.node_id
                            )
                        )
                        self._state_dict_cache = None
                        continue
                else:
                    self._model.load_state_dict(self._state_dict_cache)
                    self._state_dict_cache = None
                    break

    def _send_model(self):
        if self._model_loader:
            self._model_loader.save(self._send_callback)
        else:
            task.emit(self._event_name, self._model.state_dict(), only_remote=True)

    def _send_callback(self, storage: Storage):
        if task.running:
            task.emit(self._event_name, storage, only_remote=True)

    def __del__(self):
        if self._model_loader:
            self._model_loader.shutdown()

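# A minimal usage sketch (illustrative): every process registers the exchanger
# with its local copy of the model; the learner then broadcasts its weights to
# the other roles after each training iteration. `model` is assumed to be any
# torch.nn.Module shared by the pipeline definition.
#
#     with task.start(ctx=OnlineRLContext()):
#         task.use(ModelExchanger(model))
#         ...  # collector / learner middleware operating on the same `model`

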
class PeriodicalModelExchanger:

    def __init__(
            self,
            model: "Module",
            mode: str,
            period: int = 1,
            delay_toleration: float = np.inf,
            stale_toleration: int = 1,
            event_name: str = "model_exchanger",
            model_loader: Optional[ModelLoader] = None
    ) -> None:
        """
        Overview:
            Exchange a model between processes periodically; set the mode to "send" or "receive"
            to specify the role of the process.
            If you are using a shared model on a single host, there is no need for this middleware.
        Arguments:
            - model (:obj:`torch.nn.Module`): Pytorch module.
            - mode (:obj:`str`): "send" or "receive".
            - period (:obj:`int`): The period of model exchange, in sender iterations.
            - delay_toleration (:obj:`float`): The longest time interval (in seconds) a received model \
                may lag behind its send time and still be loaded.
            - stale_toleration (:obj:`int`): The number of iterations a receiver may keep running \
                on a stale model before it blocks and waits for a new one.
            - event_name (:obj:`str`): The event name used for the model exchange.
            - model_loader (:obj:`ModelLoader`): The ModelLoader used to encode the model.
        """
        self._model = model
        self._model_loader = model_loader
        self._event_name = event_name
        self._period = period
        self._mode = mode
        if self._mode == "receive":
            self._id_counter = -1
            self._model_id = -1
        else:
            self._id_counter = 0
        self._stale_toleration = stale_toleration
        self._model_stale = stale_toleration
        self._delay_toleration = delay_toleration
        self._state_dict_cache: Optional[Union[object, Storage]] = None
        self._time = time()  # Timestamp of the most recently received model

        if self._mode == "receive":
            task.on(self._event_name, self._cache_state_dict)
        if model_loader:
            task.once("finish", lambda _: model_loader.shutdown())

    def _cache_state_dict(self, msg: Dict[str, Any]):
        # Only cache models whose id falls on the exchange period.
        if msg['id'] % self._period == 0:
            self._state_dict_cache = msg['model']
            self._id_counter = msg['id']
            self._time = msg['time']

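    # In the direct (no model_loader) path, messages are dicts of the form
    # {'id': iteration id, 'model': state_dict, 'time': send timestamp}.
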
    def __new__(cls, *args, **kwargs):
        return super(PeriodicalModelExchanger, cls).__new__(cls)

    def __call__(self, ctx: "Context") -> Any:
        if self._model_loader:
            self._model_loader.start()

        if self._mode == "receive":
            if ctx.total_step != 0:  # Skip the first iteration, no model has been sent yet
                self._update_model()
        elif self._mode == "send":
            yield
            if self._id_counter % self._period == 0:
                self._send_model(id=self._id_counter)
            self._id_counter += 1
        else:
            raise NotImplementedError

    def _update_model(self):
        start = time()
        while True:
            if task.finish:
                return
            if time() - start > 60:
                logging.warning("Timeout when waiting for new model! Node id: {}".format(task.router.node_id))
                self._model_stale += 1
                break
            if self._state_dict_cache is None:
                # No new model yet: keep running on the stale model while both the
                # staleness budget and the delay budget allow it, otherwise wait.
                if self._model_stale < self._stale_toleration and time() - self._time < self._delay_toleration:
                    self._model_stale += 1
                    break
                else:
                    sleep(0.01)
            else:
                if self._id_counter > self._model_id and time() - self._time < self._delay_toleration:
                    if isinstance(self._state_dict_cache, Storage) and self._model_loader is not None:
                        try:
                            self._model.load_state_dict(self._model_loader.load(self._state_dict_cache))
                            self._state_dict_cache = None
                            self._model_id = self._id_counter
                            self._model_stale = 1
                            break
                        except FileNotFoundError:
                            logging.warning(
                                "Model file has been deleted on node {}, consider increasing the TTL.".format(
                                    task.router.node_id
                                )
                            )
                            self._state_dict_cache = None
                            continue
                    else:
                        self._model.load_state_dict(self._state_dict_cache)
                        self._state_dict_cache = None
                        self._model_id = self._id_counter
                        self._model_stale = 1
                        break
                else:
                    # The cached model is not newer than the current one or has
                    # expired; count this iteration as stale and keep the current
                    # model instead of spinning on the same cache entry.
                    self._model_stale += 1
                    break

    def _send_model(self, id: int):
        if self._model_loader:
            # The model_loader path delegates encoding to a subprocess; see
            # _send_callback for the payload that is actually emitted.
            self._model_loader.save(self._send_callback)
        else:
            task.emit(self._event_name, {'id': id, 'model': self._model.state_dict(), 'time': time()}, only_remote=True)

    def _send_callback(self, storage: Storage):
        # Note: this path emits the raw storage object rather than the
        # {'id', 'model', 'time'} dict used by the direct path above.
        if task.running:
            task.emit(self._event_name, storage, only_remote=True)

    def __del__(self):
        if self._model_loader:
            self._model_loader.shutdown()

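# A minimal usage sketch (illustrative): the learner sends every 10th model while
# collectors receive, tolerating one stale iteration; the event name must match
# on both sides.
#
#     # On the learner process
#     task.use(PeriodicalModelExchanger(model, mode="send", period=10))
#     # On the collector processes
#     task.use(PeriodicalModelExchanger(model, mode="receive", stale_toleration=1))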