import multiprocessing as mp
import pytest
from threading import Lock
from time import sleep, time
import random
import dataclasses
from ding.framework import task, Context, Parallel


@dataclasses.dataclass
class TestContext(Context):
    pipeline: list = dataclasses.field(default_factory=list)


@pytest.mark.unittest
def test_serial_pipeline():

    def step0(ctx):
        ctx.pipeline.append(0)

    def step1(ctx):
        ctx.pipeline.append(1)

    with task.start(ctx=TestContext()):
        # Run step0 and step1 twice in order
        for _ in range(2):
            task.forward(step0)
            task.forward(step1)
        assert task.ctx.pipeline == [0, 1, 0, 1]

        # Renew the context and run the pipeline once more
        task.renew()
        assert task.ctx.total_step == 1
        task.forward(step0)
        task.forward(step1)
        assert task.ctx.pipeline == [0, 1]

        task.renew()


@pytest.mark.unittest
def test_serial_yield_pipeline():

    def step0(ctx):
        ctx.pipeline.append(0)
        yield
        ctx.pipeline.append(0)

    def step1(ctx):
        ctx.pipeline.append(1)

    with task.start(ctx=TestContext()):
        task.forward(step0)
        task.forward(step1)
        # backward resumes the generator middleware, so step0 appends once more
        task.backward()
        assert task.ctx.pipeline == [0, 1, 0]
        assert len(task._backward_stack) == 0


@pytest.mark.unittest
def test_async_pipeline():

    def step0(ctx):
        ctx.pipeline.append(0)

    def step1(ctx):
        ctx.pipeline.append(1)

    with task.start(async_mode=True, ctx=TestContext()):
        # Run step0 and step1 twice; the sleeps keep the async execution order deterministic
        for _ in range(2):
            task.forward(step0)
            sleep(0.1)
            task.forward(step1)
            sleep(0.1)
        task.backward()
        assert task.ctx.pipeline == [0, 1, 0, 1]
        task.renew()
        assert task.ctx.total_step == 1


@pytest.mark.unittest
def test_async_yield_pipeline():

    def step0(ctx):
        sleep(0.1)
        ctx.pipeline.append(0)
        yield
        ctx.pipeline.append(0)

    def step1(ctx):
        sleep(0.2)
        ctx.pipeline.append(1)

    with task.start(async_mode=True, ctx=TestContext()):
        task.forward(step0)
        task.forward(step1)
        sleep(0.3)
        task.backward().sync()
        assert task.ctx.pipeline == [0, 1, 0]
        assert len(task._backward_stack) == 0


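# Entry point for each parallel worker: every worker emits a "count" event to its
# remote peers only, so each worker should receive counts from the others.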
def parallel_main():
    sync_count = 0

    def on_count():
        nonlocal sync_count
        sync_count += 1

    def counter(task):

        def _counter(ctx):
            sleep(0.2 + random.random() / 10)
            task.emit("count", only_remote=True)

        return _counter

    with task.start():
        task.on("count", on_count)
        task.use(counter(task))
        task.run(max_step=10)
        assert sync_count > 0


@pytest.mark.tmp
def test_parallel_pipeline():
    Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(parallel_main)


@pytest.mark.tmp
def test_emit():
    with task.start():
        greets = []
        task.on("Greeting", lambda msg: greets.append(msg))

        def step1(ctx):
            task.emit("Greeting", "Hi")

        task.use(step1)
        task.run(max_step=10)
        sleep(0.1)
        assert len(greets) == 10


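# Node 0 only listens for "Greeting", while the other node emits it with
# only_remote=True and therefore must never receive its own message.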
def emit_remote_main():
    with task.start():
        greets = []
        if task.router.node_id == 0:
            task.on("Greeting", lambda msg: greets.append(msg))
            for _ in range(20):
                if greets:
                    break
                sleep(0.1)
            assert len(greets) > 0
        else:
            for _ in range(20):
                task.emit("Greeting", "Hi", only_remote=True)
                sleep(0.1)
            assert len(greets) == 0


@pytest.mark.tmp
def test_emit_remote():
    Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(emit_remote_main)


@pytest.mark.tmp
def test_wait_for():
    # Wait for an event emitted by another middleware in the same step
    with task.start(async_mode=True, n_async_workers=2):
        greets = []

        def step1(_):
            hi = task.wait_for("Greeting")[0][0]
            if hi:
                greets.append(hi)

        def step2(_):
            task.emit("Greeting", "Hi")

        task.use(step1)
        task.use(step2)
        task.run(max_step=10)

        assert len(greets) == 10
        assert all(map(lambda hi: hi == "Hi", greets))

    # Timeout while waiting for an event that is never emitted
    with task.start(async_mode=True, n_async_workers=2):

        def step1(_):
            task.wait_for("Greeting", timeout=0.3, ignore_timeout_exception=False)

        task.use(step1)
        with pytest.raises(TimeoutError):
            task.run(max_step=1)


@pytest.mark.tmp
def test_async_exception():
    with task.start(async_mode=True, n_async_workers=2):

        def step1(_):
            task.wait_for("any_event")  # Never emitted

        def step2(_):
            sleep(0.3)
            raise Exception("Oh")

        task.use(step1)
        task.use(step2)
        with pytest.raises(Exception):
            task.run(max_step=2)

        assert task.ctx.total_step == 0


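# One worker stops after 2 steps; the finish signal should propagate so that the
# worker configured for 10 steps also stops early (before step 7).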
def early_stop_main():
    with task.start():
        task.use(lambda _: sleep(0.5))
        if task.match_labels("node.0"):
            task.run(max_step=10)
        else:
            task.run(max_step=2)
        assert task.ctx.total_step < 7


@pytest.mark.tmp
def test_early_stop():
    Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(early_stop_main)


@pytest.mark.tmp
def test_parallel_in_sequential():
    result = []

    def fast(_):
        result.append("fast")

    def slow(_):
        sleep(0.1)
        result.append("slow")

    # In a serial task, task.parallel runs the wrapped middleware concurrently,
    # so "fast" finishes before "slow"
    with task.start():
        task.use(lambda _: result.append("begin"))
        task.use(task.parallel(slow, fast))
        task.run(max_step=1)
        assert result == ["begin", "fast", "slow"]


@pytest.mark.tmp
def test_serial_in_parallel():
    result = []

    def fast(_):
        result.append("fast")

    def slow(_):
        sleep(0.1)
        result.append("slow")

    # In an async task, task.serial forces the wrapped middleware to run in order,
    # so "slow" finishes before "fast"
    with task.start(async_mode=True):
        task.use(lambda _: result.append("begin"))
        task.use(task.serial(slow, fast))
        task.run(max_step=1)

        assert result == ["begin", "slow", "fast"]


@pytest.mark.unittest
def test_nested_middleware():
    """
    When there is a yield in a middleware, calling that middleware directly from
    another one leads to an unexpected result.
    Wrapping it with task.forward or task.wrap fixes this problem.
    """
    result = []

    def child():

        def _child(ctx: Context):
            result.append(3 * ctx.total_step)
            yield
            result.append(2 + 3 * ctx.total_step)

        return _child

    def mother():
        # Wrap the child middleware so its yield is handled by the task
        _child = task.wrap(child())

        def _mother(ctx: Context):
            child_back = _child(ctx)
            result.append(1 + 3 * ctx.total_step)
            child_back()

        return _mother

    with task.start():
        task.use(mother())
        task.run(2)
        assert result == [0, 1, 2, 3, 4, 5]


@pytest.mark.unittest
def test_use_lock():

    def slow(ctx):
        sleep(0.1)
        ctx.result = "slow"

    def fast(ctx):
        ctx.result = "fast"

    # With the default lock, slow and fast are serialized even in async mode,
    # so fast writes last
    with task.start(async_mode=True):
        task.use(slow, lock=True)
        task.use(fast, lock=True)
        task.run(1)
        assert task.ctx.result == "fast"

    # With a custom lock, only the middleware sharing that lock are serialized;
    # slowest uses the default lock and finishes last
    lock = Lock()

    def slowest(ctx):
        sleep(0.3)
        ctx.result = "slowest"

    with task.start(async_mode=True):
        task.use(slow, lock=lock)
        task.use(slowest, lock=True)
        task.use(fast, lock=lock)
        task.run(1)
        assert task.ctx.result == "slowest"


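# Helpers for test_broadcast_finish: node 1 raises the finish flag on its second
# step, and the signal should be broadcast to every attached process.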
def broadcast_finish_main():
    with task.start():

        def tick(ctx: Context):
            if task.router.node_id == 1 and ctx.total_step == 1:
                task.finish = True
            sleep(1)

        task.use(tick)
        task.run(20)


def broadcast_main_target():
    Parallel.runner(
        n_parallel_workers=1, protocol="tcp", address="127.0.0.1", topology="mesh", ports=50555, startup_interval=0.1
    )(broadcast_finish_main)


def broadcast_secondary_target():
    """Start two standalone processes and connect them to the main process."""
    Parallel.runner(
        n_parallel_workers=2,
        protocol="tcp",
        address="127.0.0.1",
        topology="alone",
        ports=50556,
        attach_to=["tcp://127.0.0.1:50555"],
        node_ids=[1, 2],
        startup_interval=0.1
    )(broadcast_finish_main)


@pytest.mark.tmp
@pytest.mark.timeout(10)
def test_broadcast_finish():
    start = time()
    ctx = mp.get_context("spawn")
    main_process = ctx.Process(target=broadcast_main_target)
    secondary_process = ctx.Process(target=broadcast_secondary_target)
    main_process.start()
    secondary_process.start()
    main_process.join()
    secondary_process.join()
    assert (time() - start) < 10