import shutil
import tempfile
from os import path
from time import sleep

import numpy as np
import pytest
import torch

from ding.data.model_loader import FileModelLoader
from ding.data.storage_loader import FileStorageLoader
from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware.distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
from ding.framework.parallel import Parallel
from ding.utils.default_helper import set_pkg_seed


def context_exchanger_main():
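    """Exchange context between two processes: node 0 takes the learner role,
    node 1 the collector role. ContextExchanger is expected to deliver the
    collector's trajectories/episodes to the learner and the learner's
    train_iter back to the collector on every iteration."""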
|
    with task.start(ctx=OnlineRLContext()):
        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
        elif task.router.node_id == 1:
            task.add_role(task.role.COLLECTOR)

        task.use(ContextExchanger(skip_n_iter=1))

        if task.has_role(task.role.LEARNER):

            def learner_context(ctx: OnlineRLContext):
                # These fields are produced by the collector and must arrive intact.
                assert len(ctx.trajectories) == 2
                assert len(ctx.trajectory_end_idx) == 4
                assert len(ctx.episodes) == 8
                assert ctx.env_step > 0
                assert ctx.env_episode > 0
                yield
                ctx.train_iter += 1

            task.use(learner_context)
        elif task.has_role(task.role.COLLECTOR):

            def collector_context(ctx: OnlineRLContext):
                # From the second iteration on, the learner's train_iter must have arrived.
                if ctx.total_step > 0:
                    assert ctx.train_iter > 0
                yield
                ctx.trajectories = [np.random.rand(10, 10) for _ in range(2)]
                ctx.trajectory_end_idx = [1 for _ in range(4)]
                ctx.episodes = [np.random.rand(10, 10) for _ in range(8)]
                ctx.env_step += 1
                ctx.env_episode += 1

            task.use(collector_context)

        task.run(max_step=3)


@pytest.mark.tmp
def test_context_exchanger():
    Parallel.runner(n_parallel_workers=2)(context_exchanger_main)


def context_exchanger_with_storage_loader_main():
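    """Same learner/collector exchange as above, but with a FileStorageLoader
    attached, so exchanged payloads should be staged on disk rather than sent
    inline; the temporary directory is cleaned up afterwards."""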
|
    with task.start(ctx=OnlineRLContext()):
        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
        elif task.router.node_id == 1:
            task.add_role(task.role.COLLECTOR)

        tempdir = path.join(tempfile.gettempdir(), "test_storage_loader")
        storage_loader = FileStorageLoader(dirname=tempdir)
        try:
            task.use(ContextExchanger(skip_n_iter=1, storage_loader=storage_loader))

            if task.has_role(task.role.LEARNER):

                def learner_context(ctx: OnlineRLContext):
                    assert len(ctx.trajectories) == 2
                    assert len(ctx.trajectory_end_idx) == 4
                    assert len(ctx.episodes) == 8
                    assert ctx.env_step > 0
                    assert ctx.env_episode > 0
                    yield
                    ctx.train_iter += 1

                task.use(learner_context)
            elif task.has_role(task.role.COLLECTOR):

                def collector_context(ctx: OnlineRLContext):
                    if ctx.total_step > 0:
                        assert ctx.train_iter > 0
                    yield
                    ctx.trajectories = [np.random.rand(10, 10) for _ in range(2)]
                    ctx.trajectory_end_idx = [1 for _ in range(4)]
                    ctx.episodes = [np.random.rand(10, 10) for _ in range(8)]
                    ctx.env_step += 1
                    ctx.env_episode += 1

                task.use(collector_context)

            task.run(max_step=3)
        finally:
            storage_loader.shutdown()
            sleep(1)
            if path.exists(tempdir):
                shutil.rmtree(tempdir)


@pytest.mark.tmp
def test_context_exchanger_with_storage_loader():
    Parallel.runner(n_parallel_workers=2)(context_exchanger_with_storage_loader_main)


class MockPolicy:
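    """A minimal policy stand-in: a small MLP with a one-step train method and
    a no-grad predict method, exposing ``_model`` for the model exchangers."""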
    def __init__(self) -> None:
        self._model = self._get_model(10, 10)

    def _get_model(self, X_shape, y_shape) -> torch.nn.Module:
        return torch.nn.Sequential(
            torch.nn.Linear(X_shape, 24), torch.nn.ReLU(), torch.nn.Linear(24, 24), torch.nn.ReLU(),
            torch.nn.Linear(24, y_shape)
        )

    def train(self, X, y):
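        """Take a single gradient step on (X, y). A fresh Adam optimizer per
        call is enough here, since the test only needs the weights to change."""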
|
        loss_fn = torch.nn.MSELoss(reduction="mean")
        optimizer = torch.optim.Adam(self._model.parameters(), lr=0.01)
        y_pred = self._model(X)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def predict(self, X):
        with torch.no_grad():
            return self._model(X)


def model_exchanger_main():
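    """The learner keeps training while the collector only predicts;
    ModelExchanger should propagate the updated weights, so the collector's
    predictions must change after the first step."""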
|
    with task.start(ctx=OnlineRLContext()):
        set_pkg_seed(0, use_cuda=False)
        policy = MockPolicy()
        X = torch.rand(10)
        y = torch.rand(10)

        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
        else:
            task.add_role(task.role.COLLECTOR)

        task.use(ModelExchanger(policy._model))

        if task.has_role(task.role.LEARNER):

            def train(ctx):
                policy.train(X, y)
                sleep(0.3)

            task.use(train)
        else:
            y_pred1 = policy.predict(X)

            def pred(ctx):
                if ctx.total_step > 0:
                    y_pred2 = policy.predict(X)
                    # The learner's update must be visible here from the second step on.
                    assert (y_pred1 != y_pred2).any()
                sleep(0.3)

            task.use(pred)

        task.run(2)


@pytest.mark.tmp
def test_model_exchanger():
    Parallel.runner(n_parallel_workers=2, startup_interval=0)(model_exchanger_main)


def model_exchanger_main_with_model_loader():
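    """Same weight-propagation check as model_exchanger_main, but with a
    FileModelLoader attached, so the model state should travel via files in a
    temporary directory that is removed afterwards."""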
|
    with task.start(ctx=OnlineRLContext()):
        set_pkg_seed(0, use_cuda=False)
        policy = MockPolicy()
        X = torch.rand(10)
        y = torch.rand(10)

        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
        else:
            task.add_role(task.role.COLLECTOR)

        tempdir = path.join(tempfile.gettempdir(), "test_model_loader")
        model_loader = FileModelLoader(policy._model, dirname=tempdir)
        task.use(ModelExchanger(policy._model, model_loader=model_loader))

        try:
            if task.has_role(task.role.LEARNER):

                def train(ctx):
                    policy.train(X, y)
                    sleep(0.3)

                task.use(train)
            else:
                y_pred1 = policy.predict(X)

                def pred(ctx):
                    if ctx.total_step > 0:
                        y_pred2 = policy.predict(X)
                        assert (y_pred1 != y_pred2).any()
                    sleep(0.3)

                task.use(pred)
            task.run(2)
        finally:
            model_loader.shutdown()
            sleep(0.3)
            if path.exists(tempdir):
                shutil.rmtree(tempdir)


@pytest.mark.tmp
def test_model_exchanger_with_model_loader():
    Parallel.runner(n_parallel_workers=2, startup_interval=0)(model_exchanger_main_with_model_loader)


def periodical_model_exchanger_main():
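    """The learner pushes its model every 3 iterations; the collector pulls every
    iteration with stale_toleration=3 and checks that its copy never stays stale
    beyond that budget."""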
|
    with task.start(ctx=OnlineRLContext()):
        set_pkg_seed(0, use_cuda=False)
        policy = MockPolicy()
        X = torch.rand(10)
        y = torch.rand(10)

        if task.router.node_id == 0:
            task.add_role(task.role.LEARNER)
            task.use(PeriodicalModelExchanger(policy._model, mode="send", period=3))
        else:
            task.add_role(task.role.COLLECTOR)
            task.use(PeriodicalModelExchanger(policy._model, mode="receive", period=1, stale_toleration=3))

        if task.has_role(task.role.LEARNER):

            def train(ctx):
                policy.train(X, y)
                sleep(0.3)

            task.use(train)
        else:
            y_pred1 = policy.predict(X)
            print("y_pred1: ", y_pred1)
            stale = 1

            def pred(ctx):
                nonlocal stale
                y_pred2 = policy.predict(X)
                print("y_pred2: ", y_pred2)
                stale += 1
                # Either the copy is within the staleness budget, or it was never
                # refreshed and must still match the initial prediction.
                assert stale <= 3 or (y_pred1 == y_pred2).all()
                if (y_pred1 != y_pred2).any():
                    stale = 1
                sleep(0.3)

            task.use(pred)
        task.run(8)


@pytest.mark.tmp
def test_periodical_model_exchanger():
    Parallel.runner(n_parallel_workers=2, startup_interval=0)(periodical_model_exchanger_main)