zjowowen's picture
init space
079c32c
raw
history blame
4.58 kB
from collections import defaultdict
import pytest
import time
from ding.framework import Parallel
def parallel_main():
msg = defaultdict(bool)
def test_callback(key):
msg[key] = True
router = Parallel()
router.on("test_callback", test_callback)
# Wait for nodes to bind
time.sleep(0.7)
for _ in range(30):
router.emit("test_callback", "ping")
if msg["ping"]:
break
time.sleep(0.03)
assert msg["ping"]
# Avoid can not receiving messages from each other after exit parallel
time.sleep(0.7)
@pytest.mark.tmp
def test_parallel_run():
Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(parallel_main)
Parallel.runner(n_parallel_workers=2, protocol="tcp", startup_interval=0.1)(parallel_main)
def uncaught_exception_main():
router = Parallel()
if router.node_id == 0:
time.sleep(0.1)
raise Exception("uncaught exception")
else:
time.sleep(0.2)
@pytest.mark.tmp
def test_uncaught_exception():
# Make one process crash, then the parent process will also crash and output the stack of the wrong process.
with pytest.raises(Exception) as exc_info:
Parallel.runner(n_parallel_workers=2, topology="mesh", startup_interval=0.1)(uncaught_exception_main)
e = exc_info._excinfo[1]
assert "uncaught exception" in str(e)
def disconnected_main():
router = Parallel()
if router.node_id == 0:
time.sleep(0.1)
# Receive two messages then exit
greets = []
router.on("greeting", lambda: greets.append("."))
for _ in range(10):
if len(greets) == 1:
break
else:
time.sleep(0.1)
assert len(greets) > 0
else:
# Send 10 greetings even if the target process is exited
for i in range(10):
router.emit("greeting")
time.sleep(0.1)
assert i == 9
@pytest.mark.tmp
def test_disconnected():
# Make one process exit normally and the rest will still run, even if the network request
# is not received by other processes.
Parallel.runner(n_parallel_workers=2, topology="mesh", startup_interval=0.1)(disconnected_main)
class AutoRecover:
@classmethod
def main_p0(cls):
# Wait for p1's message and recovered message from p1
greets = []
router = Parallel()
router.on("greeting_0", lambda msg: greets.append(msg))
for _ in range(50):
if greets and greets[-1] == "recovered_p1":
break
time.sleep(0.1)
assert greets and greets[-1] == "recovered_p1"
@classmethod
def main_p1(cls):
# Send empty message to p0
# When recovered from exception, send recovered_p1 to p0
# Listen msgs from p2
greets = []
router = Parallel()
router.on("greeting_1", lambda msg: greets.append(msg))
# Test sending message to p0
if router._retries == 0:
for _ in range(10):
router.emit("greeting_0", "")
time.sleep(0.1)
raise Exception("P1 Error")
elif router._retries == 1:
for _ in range(10):
router.emit("greeting_0", "recovered_p1")
time.sleep(0.1)
else:
raise Exception("Failed too many times")
# Test recover and receving message from p2
for _ in range(20):
if greets:
break
time.sleep(0.1)
assert len(greets) > 0
@classmethod
def main_p2(cls):
# Simply send message to p1
router = Parallel()
for _ in range(20):
router.emit("greeting_1", "")
time.sleep(0.1)
@classmethod
def main(cls):
router = Parallel()
if router.node_id == 0:
cls.main_p0()
elif router.node_id == 1:
cls.main_p1()
elif router.node_id == 2:
cls.main_p2()
else:
raise Exception("Invalid node id")
@pytest.mark.tmp
def test_auto_recover():
# With max_retries=1
Parallel.runner(
n_parallel_workers=3, topology="mesh", auto_recover=True, max_retries=1, startup_interval=0.1
)(AutoRecover.main)
# With max_retries=0
with pytest.raises(Exception) as exc_info:
Parallel.runner(
n_parallel_workers=3, topology="mesh", auto_recover=True, max_retries=0, startup_interval=0.1
)(AutoRecover.main)
e = exc_info._excinfo[1]
assert "P1 Error" in str(e)