|
import random |
|
import time |
|
import socket |
|
import pytest |
|
import multiprocessing as mp |
|
from ditk import logging |
|
from ding.framework import task |
|
from ding.framework.parallel import Parallel |
|
from ding.framework.context import OnlineRLContext |
|
from ding.framework.middleware.barrier import Barrier |
|
|
|
PORTS_LIST = ["1235", "1236", "1237"] |
|
|
|
|
|
class EnvStepMiddleware: |
|
|
|
def __call__(self, ctx): |
|
yield |
|
ctx.env_step += 1 |
|
|
|
|
|
class SleepMiddleware: |
|
|
|
def __init__(self, node_id): |
|
self.node_id = node_id |
|
|
|
def random_sleep(self, diection, step): |
|
random.seed(self.node_id + step) |
|
sleep_second = random.randint(1, 5) |
|
logging.info("Node:[{}] env_step:[{}]-{} will sleep:{}s".format(self.node_id, step, diection, sleep_second)) |
|
for i in range(sleep_second): |
|
time.sleep(1) |
|
print("Node:[{}] sleepping...".format(self.node_id)) |
|
logging.info("Node:[{}] env_step:[{}]-{} wake up!".format(self.node_id, step, diection)) |
|
|
|
def __call__(self, ctx): |
|
self.random_sleep("forward", ctx.env_step) |
|
yield |
|
self.random_sleep("backward", ctx.env_step) |
|
|
|
|
|
def star_barrier(): |
|
with task.start(ctx=OnlineRLContext()): |
|
node_id = task.router.node_id |
|
if node_id == 0: |
|
attch_from_nums = 3 |
|
else: |
|
attch_from_nums = 0 |
|
barrier = Barrier(attch_from_nums) |
|
task.use(barrier, lock=False) |
|
task.use(SleepMiddleware(node_id), lock=False) |
|
task.use(barrier, lock=False) |
|
task.use(EnvStepMiddleware(), lock=False) |
|
try: |
|
task.run(2) |
|
except Exception as e: |
|
logging.error(e) |
|
assert False |
|
|
|
|
|
def mesh_barrier(): |
|
with task.start(ctx=OnlineRLContext()): |
|
node_id = task.router.node_id |
|
attch_from_nums = 3 - task.router.node_id |
|
barrier = Barrier(attch_from_nums) |
|
task.use(barrier, lock=False) |
|
task.use(SleepMiddleware(node_id), lock=False) |
|
task.use(barrier, lock=False) |
|
task.use(EnvStepMiddleware(), lock=False) |
|
try: |
|
task.run(2) |
|
except Exception as e: |
|
logging.error(e) |
|
assert False |
|
|
|
|
|
def unmatch_barrier(): |
|
with task.start(ctx=OnlineRLContext()): |
|
node_id = task.router.node_id |
|
attch_from_nums = 3 - task.router.node_id |
|
task.use(Barrier(attch_from_nums, 5), lock=False) |
|
if node_id != 2: |
|
task.use(Barrier(attch_from_nums, 5), lock=False) |
|
try: |
|
task.run(2) |
|
except TimeoutError as e: |
|
assert node_id != 2 |
|
logging.info("Node:[{}] timeout with barrier".format(node_id)) |
|
else: |
|
time.sleep(5) |
|
assert node_id == 2 |
|
logging.info("Node:[{}] finish barrier".format(node_id)) |
|
|
|
|
|
def launch_barrier(args): |
|
i, topo, fn, test_id = args |
|
address = socket.gethostbyname(socket.gethostname()) |
|
topology = "alone" |
|
attach_to = [] |
|
port_base = PORTS_LIST[test_id] |
|
port = port_base + str(i) |
|
if topo == 'star': |
|
if i != 0: |
|
attach_to = ['tcp://{}:{}{}'.format(address, port_base, 0)] |
|
elif topo == 'mesh': |
|
for j in range(i): |
|
attach_to.append('tcp://{}:{}{}'.format(address, port_base, j)) |
|
|
|
Parallel.runner( |
|
node_ids=i, |
|
ports=int(port), |
|
attach_to=attach_to, |
|
topology=topology, |
|
protocol="tcp", |
|
n_parallel_workers=1, |
|
startup_interval=0 |
|
)(fn) |
|
|
|
|
|
@pytest.mark.unittest |
|
def test_star_topology_barrier(): |
|
ctx = mp.get_context("spawn") |
|
with ctx.Pool(processes=4) as pool: |
|
pool.map(launch_barrier, [[i, 'star', star_barrier, 0] for i in range(4)]) |
|
pool.close() |
|
pool.join() |
|
|
|
|
|
@pytest.mark.unittest |
|
def test_mesh_topology_barrier(): |
|
ctx = mp.get_context("spawn") |
|
with ctx.Pool(processes=4) as pool: |
|
pool.map(launch_barrier, [[i, 'mesh', mesh_barrier, 1] for i in range(4)]) |
|
pool.close() |
|
pool.join() |
|
|
|
|
|
@pytest.mark.unittest |
|
def test_unmatch_barrier(): |
|
ctx = mp.get_context("spawn") |
|
with ctx.Pool(processes=4) as pool: |
|
pool.map(launch_barrier, [[i, 'mesh', unmatch_barrier, 2] for i in range(4)]) |
|
pool.close() |
|
pool.join() |
|
|