|
from collections import defaultdict |
|
import pytest |
|
import time |
|
from ding.framework import Parallel |
|
|
|
|
|
def parallel_main():
    """Worker body checking that a ``test_callback`` event gets delivered.

    Each worker registers a ``test_callback`` handler that flags whatever key it
    receives. After a startup pause it emits ``"ping"`` up to 30 times, 0.03s
    apart, stopping as soon as the handler has seen ``"ping"``. It then asserts
    the ping arrived. Whether the ping comes from a peer or from this node's own
    emit is decided by ``Parallel.emit`` and is not visible here.
    """
    received = defaultdict(bool)

    def record(key):
        received[key] = True

    router = Parallel()
    router.on("test_callback", record)

    # Startup pause before emitting (presumably to let the other workers connect).
    time.sleep(0.7)

    attempts = 0
    while attempts < 30:
        router.emit("test_callback", "ping")
        if received["ping"]:
            break
        time.sleep(0.03)
        attempts += 1

    assert received["ping"]

    # Linger before exiting (presumably so peers can still reach this node).
    time.sleep(0.7)
|
|
|
|
|
@pytest.mark.tmp
def test_parallel_run():
    """Run ``parallel_main`` on two workers: default protocol first, then TCP."""
    for extra in ({}, {"protocol": "tcp"}):
        Parallel.runner(n_parallel_workers=2, startup_interval=0.1, **extra)(parallel_main)
|
|
|
|
|
def uncaught_exception_main():
    """Worker body in which node 0 raises an unhandled exception.

    Node 0 waits 0.1s and raises ``Exception("uncaught exception")``. Every other
    node sleeps 0.2s and returns normally.
    """
    router = Parallel()
    if router.node_id != 0:
        time.sleep(0.2)
        return
    time.sleep(0.1)
    raise Exception("uncaught exception")
|
|
|
|
|
@pytest.mark.tmp
def test_uncaught_exception():
    """An exception raised inside one worker must propagate out of ``Parallel.runner``.

    Runs ``uncaught_exception_main`` on a two-worker mesh and checks that the
    error surfacing from the runner carries the worker's message.
    """
    with pytest.raises(Exception) as exc_info:
        Parallel.runner(n_parallel_workers=2, topology="mesh", startup_interval=0.1)(uncaught_exception_main)
    # ``ExceptionInfo.value`` is the public accessor; ``_excinfo`` is a pytest internal.
    assert "uncaught exception" in str(exc_info.value)
|
|
|
|
|
def disconnected_main():
    """Worker body checking that emitting keeps working after a peer leaves.

    Node 0 waits 0.1s, subscribes to ``"greeting"``, polls (up to ~1s) until at
    least one greeting has arrived, and then returns. Returning presumably
    disconnects it from the mesh. Every other node emits ``"greeting"`` ten
    times, 0.1s apart, and must finish the loop without raising.
    """
    router = Parallel()

    if router.node_id == 0:
        time.sleep(0.1)

        greets = []
        router.on("greeting", lambda: greets.append("."))
        for _ in range(10):
            # Stop as soon as anything has arrived. The previous
            # ``len(greets) == 1`` test never fired when two or more greetings
            # landed between polls. That kept node 0 alive for the whole polling
            # window, by which point the emitter may already have finished, so
            # the "peer gone" path was not exercised.
            if greets:
                break
            time.sleep(0.1)
        assert len(greets) > 0

    else:
        for i in range(10):
            router.emit("greeting")
            time.sleep(0.1)
        # Reaching this line means every emit returned, including those issued
        # after node 0 may have exited.
        assert i == 9
|
|
|
|
|
@pytest.mark.tmp
def test_disconnected():
    """Run ``disconnected_main`` on a two-worker mesh; it must finish without errors."""
    runner = Parallel.runner(n_parallel_workers=2, topology="mesh", startup_interval=0.1)
    runner(disconnected_main)
|
|
|
|
|
class AutoRecover:
    """Worker bodies for the ``auto_recover`` test on a three-node mesh.

    * node 0 listens on ``greeting_0`` and waits for ``"recovered_p1"``;
    * node 1 raises on its first attempt and, when retried, emits
      ``"recovered_p1"`` on ``greeting_0``;
    * node 2 emits ``greeting_1``, which node 1 waits for.
    """

    @classmethod
    def main_p0(cls):
        """Wait up to ~5s for the last ``greeting_0`` payload to be ``"recovered_p1"``."""
        received = []
        router = Parallel()
        router.on("greeting_0", lambda payload: received.append(payload))

        def recovered():
            return bool(received) and received[-1] == "recovered_p1"

        for _ in range(50):
            if recovered():
                break
            time.sleep(0.1)
        assert recovered()

    @classmethod
    def main_p1(cls):
        """Fail on the first attempt, announce recovery on the retry.

        Branches on ``router._retries`` (presumably the number of times this
        worker has been restarted; confirm against ``Parallel``):

        * 0: emit ``""`` on ``greeting_0`` ten times, 0.1s apart, then raise
          ``"P1 Error"``;
        * 1: emit ``"recovered_p1"`` on ``greeting_0`` ten times, 0.1s apart;
        * anything else: raise ``"Failed too many times"`` immediately.

        On the surviving path, it then waits up to ~2s for at least one
        ``greeting_1`` event.
        """
        received = []
        router = Parallel()
        router.on("greeting_1", lambda payload: received.append(payload))

        retries = router._retries
        if retries not in (0, 1):
            raise Exception("Failed too many times")

        payload = "" if retries == 0 else "recovered_p1"
        for _ in range(10):
            router.emit("greeting_0", payload)
            time.sleep(0.1)
        if retries == 0:
            raise Exception("P1 Error")

        waited = 0
        while not received and waited < 20:
            time.sleep(0.1)
            waited += 1
        assert len(received) > 0

    @classmethod
    def main_p2(cls):
        """Emit ``greeting_1`` (the event node 1 waits on) twenty times, 0.1s apart."""
        router = Parallel()
        remaining = 20
        while remaining:
            router.emit("greeting_1", "")
            time.sleep(0.1)
            remaining -= 1

    @classmethod
    def main(cls):
        """Dispatch to the worker body for this node's id; only ids 0, 1 and 2 are valid."""
        router = Parallel()
        bodies = {0: cls.main_p0, 1: cls.main_p1, 2: cls.main_p2}
        body = bodies.get(router.node_id)
        if body is None:
            raise Exception("Invalid node id")
        body()
|
|
|
|
|
@pytest.mark.tmp
def test_auto_recover():
    """Exercise ``auto_recover`` with and without a retry budget.

    With ``max_retries=1`` the run is expected to complete: node 1's first
    failure is retried. With ``max_retries=0`` the ``"P1 Error"`` raised by
    node 1 must propagate out of the runner.
    """
    Parallel.runner(
        n_parallel_workers=3, topology="mesh", auto_recover=True, max_retries=1, startup_interval=0.1
    )(AutoRecover.main)

    with pytest.raises(Exception) as exc_info:
        Parallel.runner(
            n_parallel_workers=3, topology="mesh", auto_recover=True, max_retries=0, startup_interval=0.1
        )(AutoRecover.main)
    # Use the public ``ExceptionInfo.value`` rather than the private ``_excinfo`` tuple.
    assert "P1 Error" in str(exc_info.value)
|
|