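# NOTE(review): the three items below look like residue from the hosting page
# (avatar caption, commit message, commit hash), not module content; kept as comments.
#   zjowowen's picture
#   init space
#   079c32c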
import multiprocessing as mp
import pytest
from threading import Lock
from time import sleep, time
import random
import dataclasses
from ding.framework import task, Context, Parallel
@dataclasses.dataclass
class TestContext(Context):
    """Context used by the pipeline tests; middleware record their execution order in ``pipeline``."""

    # The class name starts with "Test" and the dataclass generates an __init__, so pytest
    # would try (and fail) to collect it, emitting a collection warning. Opt out explicitly.
    # This is a plain class attribute (no annotation), so it does not become a dataclass field.
    __test__ = False

    # Values appended by the test middleware, in the order they ran.
    pipeline: list = dataclasses.field(default_factory=list)
@pytest.mark.unittest
def test_serial_pipeline():
    """Forwarded middleware run in call order, and ``task.renew()`` yields a context with an empty pipeline."""

    def step0(ctx):
        ctx.pipeline.append(0)

    def step1(ctx):
        ctx.pipeline.append(1)

    ordered_steps = (step0, step1)

    with task.start(ctx=TestContext()):
        # Execute step0, step1 twice
        for _ in range(2):
            for step in ordered_steps:
                task.forward(step)
        assert task.ctx.pipeline == [0, 1, 0, 1]

        # Renew and execute step0, step1 once more: the pipeline starts over
        # while the step counter has moved on to 1.
        task.renew()
        assert task.ctx.total_step == 1
        for step in ordered_steps:
            task.forward(step)
        assert task.ctx.pipeline == [0, 1]

        # Test context inheritance (a second renew; nothing is asserted afterwards)
        task.renew()
@pytest.mark.unittest
def test_serial_yield_pipeline():
    """A generator middleware runs up to its ``yield`` on forward and resumes on ``task.backward()``."""

    def step0(ctx):
        # Part before the yield: executed by task.forward.
        ctx.pipeline.append(0)
        yield
        # Part after the yield: expected to be executed by task.backward.
        ctx.pipeline.append(0)

    def step1(ctx):
        ctx.pipeline.append(1)

    with task.start(ctx=TestContext()):
        for step in (step0, step1):
            task.forward(step)
        task.backward()

        expected_order = [0, 1, 0]
        assert task.ctx.pipeline == expected_order
        # No pending continuation may be left on the backward stack.
        assert not task._backward_stack
@pytest.mark.unittest
def test_async_pipeline():
    """Same scenario as the serial pipeline test, but with the task started in async mode."""

    def step0(ctx):
        ctx.pipeline.append(0)

    def step1(ctx):
        ctx.pipeline.append(1)

    ordered_steps = (step0, step1)

    with task.start(async_mode=True, ctx=TestContext()):
        # Execute step0, step1 twice, pausing 0.1s after every forward call.
        for _ in range(2):
            for step in ordered_steps:
                task.forward(step)
                sleep(0.1)
        task.backward()
        assert task.ctx.pipeline == [0, 1, 0, 1]

        task.renew()
        assert task.ctx.total_step == 1
@pytest.mark.unittest
def test_async_yield_pipeline():
    """Async-mode variant of the yield test: forward a generator step and a plain step, then run backward."""

    def step0(ctx):
        sleep(0.1)
        ctx.pipeline.append(0)
        yield
        # Code after the yield is expected to run during task.backward() (see the final assert).
        ctx.pipeline.append(0)

    def step1(ctx):
        sleep(0.2)
        ctx.pipeline.append(1)

    with task.start(async_mode=True, ctx=TestContext()):
        task.forward(step0)
        task.forward(step1)
        # The steps sleep 0.1s and 0.2s before appending; waiting 0.3s presumably lets both
        # forward parts finish before backward starts — TODO confirm against Task's async scheduling.
        sleep(0.3)
        task.backward().sync()
        assert task.ctx.pipeline == [0, 1, 0]
        # No pending continuation may be left on the backward stack.
        assert len(task._backward_stack) == 0
def parallel_main():
    """Worker entry point: emit remote-only "count" events each step and check at least one was received."""
    received = 0

    def on_count():
        nonlocal received
        received += 1

    def counter(task):

        def _counter(ctx):
            # Randomised delay between 0.2s and 0.3s before notifying the other nodes.
            delay = 0.2 + random.random() / 10
            sleep(delay)
            task.emit("count", only_remote=True)

        return _counter

    with task.start():
        task.on("count", on_count)
        middleware = counter(task)
        task.use(middleware)
        task.run(max_step=10)
        assert received > 0
@pytest.mark.tmp
def test_parallel_pipeline():
    """Run ``parallel_main`` through ``Parallel.runner`` with two workers."""
    launch = Parallel.runner(n_parallel_workers=2, startup_interval=0.1)
    launch(parallel_main)
@pytest.mark.tmp
def test_emit():
    """Every "Greeting" emitted by the middleware reaches the listener registered with ``task.on``."""
    with task.start():
        greets = []

        def remember(msg):
            greets.append(msg)

        def step1(ctx):
            task.emit("Greeting", "Hi")

        task.on("Greeting", remember)
        task.use(step1)
        task.run(max_step=10)
        # Short grace period before counting the received greetings.
        sleep(0.1)
    # One greeting per executed step.
    assert len(greets) == 10
def emit_remote_main():
    """Worker entry point: node 0 listens for "Greeting"; every other node emits it with ``only_remote=True``."""
    with task.start():
        greets = []
        if task.router.node_id != 0:
            # Emitter side: send 20 remote-only greetings, one every 0.1s.
            for _ in range(20):
                task.emit("Greeting", "Hi", only_remote=True)
                sleep(0.1)
            # No listener is registered on this node, so nothing was collected locally.
            assert not greets
        else:
            # Receiver side: poll (at most 20 times, 0.1s apart) until a greeting arrives.
            task.on("Greeting", lambda msg: greets.append(msg))
            polls = 0
            while not greets and polls < 20:
                sleep(0.1)
                polls += 1
            assert greets
@pytest.mark.tmp
def test_emit_remote():
    """Run ``emit_remote_main`` through ``Parallel.runner`` with two workers."""
    launch = Parallel.runner(n_parallel_workers=2, startup_interval=0.1)
    launch(emit_remote_main)
@pytest.mark.tmp
def test_wait_for():
    """Check ``task.wait_for``: it returns the emitted payload, and raises ``TimeoutError`` when asked to.

    The first scenario pairs a waiting middleware with an emitting one for 10 steps.
    The second waits for an event nobody emits, with ``ignore_timeout_exception=False``.
    """
    # Wait for will only work in async or parallel mode
    with task.start(async_mode=True, n_async_workers=2):
        greets = []

        def step1(_):
            # wait_for's result is indexed [0][0] to obtain the first positional
            # argument passed to the matching emit.
            hi = task.wait_for("Greeting")[0][0]
            if hi:
                greets.append(hi)

        def step2(_):
            task.emit("Greeting", "Hi")

        task.use(step1)
        task.use(step2)
        task.run(max_step=10)

    assert len(greets) == 10
    assert all(hi == "Hi" for hi in greets)

    # Test timeout exception
    with task.start(async_mode=True, n_async_workers=2):

        # Distinct name so the first scenario's step1 is not silently shadowed.
        def wait_without_emitter(_):
            task.wait_for("Greeting", timeout=0.3, ignore_timeout_exception=False)

        task.use(wait_without_emitter)
        with pytest.raises(TimeoutError):
            task.run(max_step=1)
@pytest.mark.tmp
def test_async_exception():
    """An exception raised in one async middleware surfaces from ``task.run`` even while another one blocks."""
    with task.start(async_mode=True, n_async_workers=2):

        def wait_forever(_):
            task.wait_for("any_event")  # Never end

        def fail_after_delay(_):
            sleep(0.3)
            raise Exception("Oh")

        for middleware in (wait_forever, fail_after_delay):
            task.use(middleware)

        with pytest.raises(Exception):
            task.run(max_step=2)

        # The failing step never completed, so the step counter did not advance.
        assert task.ctx.total_step == 0
def early_stop_main():
    """Worker entry point: the node labelled "node.0" asks for 10 steps, the others for 2; all must stop before step 7."""

    def half_second_step(_):
        return sleep(0.5)

    with task.start():
        task.use(half_second_step)
        step_budget = 10 if task.match_labels("node.0") else 2
        task.run(max_step=step_budget)
        assert task.ctx.total_step < 7
@pytest.mark.tmp
def test_early_stop():
    """Run ``early_stop_main`` through ``Parallel.runner`` with two workers."""
    launch = Parallel.runner(n_parallel_workers=2, startup_interval=0.1)
    launch(early_stop_main)
@pytest.mark.tmp
def test_parallel_in_sequencial():
    """In a serial task, ``task.parallel(slow, fast)`` lets ``fast`` record its result before ``slow``."""
    result = []

    def begin(_):
        return result.append("begin")

    def fast(_):
        result.append("fast")

    def slow(_):
        sleep(0.1)
        result.append("slow")

    with task.start():
        task.use(begin)
        task.use(task.parallel(slow, fast))
        task.run(max_step=1)
        expected_order = ["begin", "fast", "slow"]
        assert result == expected_order
@pytest.mark.tmp
def test_serial_in_parallel():
    """In an async task, ``task.serial(slow, fast)`` keeps ``slow`` ahead of ``fast``."""
    result = []

    def begin(_):
        return result.append("begin")

    def fast(_):
        result.append("fast")

    def slow(_):
        sleep(0.1)
        result.append("slow")

    with task.start(async_mode=True):
        task.use(begin)
        task.use(task.serial(slow, fast))
        task.run(max_step=1)
        expected_order = ["begin", "slow", "fast"]
        assert result == expected_order
@pytest.mark.unittest
def test_nested_middleware():
    """
    When a middleware contains a yield, calling it directly from another
    middleware leads to an unexpected result.
    Using task.forward or task.wrap fixes this problem.
    """
    result = []

    def child():

        def _child(ctx: Context):
            # Forward half: records 3 * step (0 on step 0, 3 on step 1).
            result.append(3 * ctx.total_step)
            yield
            # Backward half: records 3 * step + 2 (2 on step 0, 5 on step 1).
            result.append(2 + 3 * ctx.total_step)

        return _child

    def mother():
        # Wrap the generator middleware so that calling it hands back a callable
        # for the part after the yield (used as child_back below).
        _child = task.wrap(child())

        def _mother(ctx: Context):
            child_back = _child(ctx)
            # Runs between the two halves of the child: records 3 * step + 1.
            result.append(1 + 3 * ctx.total_step)
            child_back()

        return _mother

    with task.start():
        task.use(mother())
        task.run(2)
    # Two steps, three records each, strictly in child-forward / mother / child-backward order.
    assert result == [0, 1, 2, 3, 4, 5]
@pytest.mark.unittest
def test_use_lock():
    """Check the ``lock`` argument of ``task.use`` in async mode, with the task's own lock and a custom one."""

    def slow(ctx):
        sleep(0.1)
        ctx.result = "slow"

    def fast(ctx):
        ctx.result = "fast"

    with task.start(async_mode=True):
        # The lock will turn async middleware into serial
        task.use(slow, lock=True)
        task.use(fast, lock=True)
        task.run(1)
        # fast was registered last and wrote last, despite slow's 0.1s sleep.
        assert task.ctx.result == "fast"

    # With custom lock, it will not affect the inner lock of task
    lock = Lock()

    def slowest(ctx):
        sleep(0.3)
        ctx.result = "slowest"

    with task.start(async_mode=True):
        task.use(slow, lock=lock)
        # If it receives other locks, it will not be the last one to finish execution
        task.use(slowest, lock=True)
        task.use(fast, lock=lock)
        task.run(1)
        # slowest (0.3s, lock=True) is not serialised with the two middleware sharing
        # the custom lock, so its write lands last.
        assert task.ctx.result == "slowest"
def broadcast_finish_main():
    """Worker entry point: run up to 20 one-second steps; node 1 sets ``task.finish`` during step 1."""
    with task.start():

        def tick(ctx: Context):
            if task.router.node_id == 1:
                if ctx.total_step == 1:
                    task.finish = True
            # Every step lasts one second, whether or not finish was requested.
            sleep(1)

        task.use(tick)
        task.run(20)
def broadcast_main_target():
    """Start one worker over TCP on 127.0.0.1:50555 with mesh topology, running ``broadcast_finish_main``."""
    runner_options = dict(
        n_parallel_workers=1,
        protocol="tcp",
        address="127.0.0.1",
        topology="mesh",
        ports=50555,
        startup_interval=0.1,
    )
    Parallel.runner(**runner_options)(broadcast_finish_main)
def broadcast_secondary_target():
    """Start two standalone processes and connect to the main process."""
    runner_options = dict(
        n_parallel_workers=2,
        protocol="tcp",
        address="127.0.0.1",
        topology="alone",
        ports=50556,
        # Address opened by broadcast_main_target.
        attach_to=["tcp://127.0.0.1:50555"],
        node_ids=[1, 2],
        startup_interval=0.1,
    )
    Parallel.runner(**runner_options)(broadcast_finish_main)
@pytest.mark.tmp  # gitlab ci and local test pass, github always fail
@pytest.mark.timeout(10)
def test_broadcast_finish():
    """Both runner processes must exit within 10 seconds.

    The workers execute ``broadcast_finish_main``, which asks for 20 steps of one
    second each, so finishing in under 10 seconds requires the ``finish`` flag set
    on node 1 to stop the other nodes as well.
    """
    start = time()
    ctx = mp.get_context("spawn")
    processes = [
        ctx.Process(target=broadcast_main_target),
        ctx.Process(target=broadcast_secondary_target),
    ]
    try:
        for process in processes:
            process.start()
        for process in processes:
            process.join()
    finally:
        # If start()/join() was interrupted (e.g. by the pytest timeout), do not leave
        # the runner processes behind. On the normal path both have already exited
        # and this loop does nothing.
        for process in processes:
            if process.is_alive():
                process.terminate()
                process.join(timeout=1)
    assert (time() - start) < 10