"""Tests for the context/model exchanger middlewares running across parallel workers."""
import shutil
import tempfile
from os import path
from time import sleep

import numpy as np
import pytest
import torch

from ding.data.model_loader import FileModelLoader
from ding.data.storage_loader import FileStorageLoader
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware.distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
from ding.framework.parallel import Parallel
from ding.utils.default_helper import set_pkg_seed


def context_exchanger_main():
    with task.start(ctx=OnlineRLContext()):
        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
        elif task.router.node_id == 1:
            task.add_role(task.role.COLLECTOR)

        # Exchange context fields (trajectories, env_step, train_iter, ...) between the two roles.
        task.use(ContextExchanger(skip_n_iter=1))

        if task.has_role(task.role.LEARNER):

            def learner_context(ctx: OnlineRLContext):
                # Collected data should have arrived from the collector before this point.
                assert len(ctx.trajectories) == 2
                assert len(ctx.trajectory_end_idx) == 4
                assert len(ctx.episodes) == 8
                assert ctx.env_step > 0
                assert ctx.env_episode > 0
                yield
                ctx.train_iter += 1

            task.use(learner_context)
        elif task.has_role(task.role.COLLECTOR):

            def collector_context(ctx: OnlineRLContext):
                # After the first step, the learner's train_iter should have been received.
                if ctx.total_step > 0:
                    assert ctx.train_iter > 0
                yield
                ctx.trajectories = [np.random.rand(10, 10) for _ in range(2)]
                ctx.trajectory_end_idx = [1 for _ in range(4)]
                ctx.episodes = [np.random.rand(10, 10) for _ in range(8)]
                ctx.env_step += 1
                ctx.env_episode += 1

            task.use(collector_context)

        task.run(max_step=3)


@pytest.mark.tmp
def test_context_exchanger():
    Parallel.runner(n_parallel_workers=2)(context_exchanger_main)


def context_exchanger_with_storage_loader_main():
    with task.start(ctx=OnlineRLContext()):
        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
        elif task.router.node_id == 1:
            task.add_role(task.role.COLLECTOR)

        tempdir = path.join(tempfile.gettempdir(), "test_storage_loader")
        storage_loader = FileStorageLoader(dirname=tempdir)
        try:
            # Same exchange as above, but payloads are shipped through a file-based storage loader.
            task.use(ContextExchanger(skip_n_iter=1, storage_loader=storage_loader))

            if task.has_role(task.role.LEARNER):

                def learner_context(ctx: OnlineRLContext):
                    assert len(ctx.trajectories) == 2
                    assert len(ctx.trajectory_end_idx) == 4
                    assert len(ctx.episodes) == 8
                    assert ctx.env_step > 0
                    assert ctx.env_episode > 0
                    yield
                    ctx.train_iter += 1

                task.use(learner_context)
            elif task.has_role(task.role.COLLECTOR):

                def collector_context(ctx: OnlineRLContext):
                    if ctx.total_step > 0:
                        assert ctx.train_iter > 0
                    yield
                    ctx.trajectories = [np.random.rand(10, 10) for _ in range(2)]
                    ctx.trajectory_end_idx = [1 for _ in range(4)]
                    ctx.episodes = [np.random.rand(10, 10) for _ in range(8)]
                    ctx.env_step += 1
                    ctx.env_episode += 1

                task.use(collector_context)

            task.run(max_step=3)
        finally:
            storage_loader.shutdown()
            sleep(1)
            if path.exists(tempdir):
                shutil.rmtree(tempdir)


@pytest.mark.tmp
def test_context_exchanger_with_storage_loader():
    Parallel.runner(n_parallel_workers=2)(context_exchanger_with_storage_loader_main)


class MockPolicy:
    """A minimal policy wrapping a small MLP, used only to verify weight synchronization."""

    def __init__(self) -> None:
        self._model = self._get_model(10, 10)

    def _get_model(self, X_shape, y_shape) -> torch.nn.Module:
        return torch.nn.Sequential(
            torch.nn.Linear(X_shape, 24), torch.nn.ReLU(), torch.nn.Linear(24, 24), torch.nn.ReLU(),
            torch.nn.Linear(24, y_shape)
        )

    def train(self, X, y):
        loss_fn = torch.nn.MSELoss(reduction="mean")
        optimizer = torch.optim.Adam(self._model.parameters(), lr=0.01)
        y_pred = self._model(X)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def predict(self, X):
        with torch.no_grad():
            return self._model(X)


def model_exchanger_main():
    with task.start(ctx=OnlineRLContext()):
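        # Both workers are seeded identically, so their MockPolicy weights start out equal;
        # a later difference in the collector's predictions therefore proves a model exchange occurred.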
        set_pkg_seed(0, use_cuda=False)
        policy = MockPolicy()
        X = torch.rand(10)
        y = torch.rand(10)
        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
        else:
            task.add_role(task.role.COLLECTOR)

        task.use(ModelExchanger(policy._model))

        if task.has_role(task.role.LEARNER):

            def train(ctx):
                policy.train(X, y)
                sleep(0.3)

            task.use(train)
        else:
            y_pred1 = policy.predict(X)

            def pred(ctx):
                if ctx.total_step > 0:
                    y_pred2 = policy.predict(X)
                    # Ensure the model has been updated
                    assert any(y_pred1 != y_pred2)
                sleep(0.3)

            task.use(pred)

        task.run(2)


@pytest.mark.tmp
def test_model_exchanger():
    Parallel.runner(n_parallel_workers=2, startup_interval=0)(model_exchanger_main)


def model_exchanger_main_with_model_loader():
    with task.start(ctx=OnlineRLContext()):
        set_pkg_seed(0, use_cuda=False)
        policy = MockPolicy()
        X = torch.rand(10)
        y = torch.rand(10)
        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
        else:
            task.add_role(task.role.COLLECTOR)

        # Ship model state through a file-based loader instead of sending it inline.
        tempdir = path.join(tempfile.gettempdir(), "test_model_loader")
        model_loader = FileModelLoader(policy._model, dirname=tempdir)
        task.use(ModelExchanger(policy._model, model_loader=model_loader))

        try:
            if task.has_role(task.role.LEARNER):

                def train(ctx):
                    policy.train(X, y)
                    sleep(0.3)

                task.use(train)
            else:
                y_pred1 = policy.predict(X)

                def pred(ctx):
                    if ctx.total_step > 0:
                        y_pred2 = policy.predict(X)
                        # Ensure the model has been updated
                        assert any(y_pred1 != y_pred2)
                    sleep(0.3)

                task.use(pred)
            task.run(2)
        finally:
            model_loader.shutdown()
            sleep(0.3)
            if path.exists(tempdir):
                shutil.rmtree(tempdir)


@pytest.mark.tmp
def test_model_exchanger_with_model_loader():
    Parallel.runner(n_parallel_workers=2, startup_interval=0)(model_exchanger_main_with_model_loader)


def periodical_model_exchanger_main():
    with task.start(ctx=OnlineRLContext()):
        set_pkg_seed(0, use_cuda=False)
        policy = MockPolicy()
        X = torch.rand(10)
        y = torch.rand(10)
        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
            # The learner pushes its model every 3 iterations.
            task.use(PeriodicalModelExchanger(policy._model, mode="send", period=3))
        else:
            task.add_role(task.role.COLLECTOR)
            # The collector polls every iteration and tolerates a model up to 3 steps stale.
            task.use(PeriodicalModelExchanger(policy._model, mode="receive", period=1, stale_toleration=3))

        if task.has_role(task.role.LEARNER):

            def train(ctx):
                policy.train(X, y)
                sleep(0.3)

            task.use(train)
        else:
            y_pred1 = policy.predict(X)
            print("y_pred1: ", y_pred1)
            stale = 1

            def pred(ctx):
                nonlocal stale
                y_pred2 = policy.predict(X)
                print("y_pred2: ", y_pred2)
                stale += 1
                # An updated model must arrive before the staleness budget (3 steps) is exceeded.
                assert stale <= 3 or all(y_pred1 == y_pred2)
                if any(y_pred1 != y_pred2):
                    stale = 1
                sleep(0.3)

            task.use(pred)
        task.run(8)


@pytest.mark.tmp
def test_periodical_model_exchanger():
    Parallel.runner(n_parallel_workers=2, startup_interval=0)(periodical_model_exchanger_main)
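# Usage sketch (assumptions: these tests run under pytest and the ``tmp`` marker is
# registered in the project's pytest configuration; the file path is a placeholder).
# Each test spawns two parallel workers, one per role, e.g.:
#   pytest -m tmp -k test_model_exchanger path/to/this_file.py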