File size: 9,269 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
import multiprocessing as mp
import ctypes
from time import sleep, time
from typing import Any, Dict, List
import pytest
from ding.framework.supervisor import RecvPayload, SendPayload, Supervisor, ChildType


class MockEnv:
    """Minimal stand-in environment used as the child object in the supervisor tests."""

    def __init__(self, _) -> None:
        # The single constructor argument is accepted but never used.
        self._counter = 0

    def step(self, _):
        """Increase the internal counter by one and return its new value (the argument is ignored)."""
        self._counter = self._counter + 1
        return self._counter

    @property
    def action_space(self):
        """Constant value 3."""
        return 3

    def block(self):
        """Sleep for 10 seconds."""
        sleep(10)

    def block_reset(self):
        """Sleep for 10 seconds (separate method name from ``block``)."""
        sleep(10)

    def sleep1(self):
        """Sleep for 1 second."""
        sleep(1)


@pytest.mark.tmp
@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
def test_supervisor(type_):
    """Check the basic send/recv round trip, a partial ``recv_all`` and ``action_space`` access.

    Three ``MockEnv`` children are registered. The supervisor is shut down in a
    ``finally`` clause so that a failing assertion does not leave children running.

    Args:
        type_: Child type the supervisor is created with (process or thread).
    """
    sv = Supervisor(type_=type_)
    for _ in range(3):
        sv.register(lambda: MockEnv("AnyArgs"))
    sv.start_link()

    try:
        for env_id in range(len(sv._children)):
            sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"]))

        recv_states: List[RecvPayload] = [sv.recv() for _ in range(3)]

        # One reply from each of proc 0, 1 and 2 sums to 3.
        assert sum(payload.proc_id for payload in recv_states) == 3
        # First step on every fresh MockEnv returns 1.
        assert all(payload.data == 1 for payload in recv_states)

        # Test recv_all
        send_payloads = []
        for env_id in range(len(sv._children)):
            payload = SendPayload(
                proc_id=env_id,
                method="step",
                args=["any action"],
            )
            send_payloads.append(payload)
            sv.send(payload)

        req_ids = [payload.req_id for payload in send_payloads]
        # Only wait for last two messages, keep the first one in the queue.
        recv_payloads = sv.recv_all(send_payloads[1:])
        assert len(recv_payloads) == 2
        for req_id, payload in zip(req_ids[1:], recv_payloads):
            assert req_id == payload.req_id

        # The reply that recv_all did not wait for must still be retrievable.
        recv_payload = sv.recv()
        assert recv_payload.req_id == req_ids[0]

        # Expect one action_space value (3) per registered child.
        assert len(sv.action_space) == 3
        assert all(a == 3 for a in sv.action_space)
    finally:
        sv.shutdown()


@pytest.mark.tmp
def test_supervisor_spawn():
    """Check the send/recv round trip for process children created with the ``spawn`` context.

    The supervisor is shut down in a ``finally`` clause so that a failing
    assertion does not leave child processes running.
    """
    sv = Supervisor(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
    for _ in range(3):
        # An instance (not a lambda) is registered here.
        # NOTE(review): presumably because a lambda cannot be pickled for spawn - confirm.
        sv.register(MockEnv("AnyArgs"))
    sv.start_link()

    try:
        for env_id in range(len(sv._children)):
            sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"]))

        recv_states: List[RecvPayload] = [sv.recv() for _ in range(3)]

        # One reply from each of proc 0, 1 and 2 sums to 3.
        assert sum(payload.proc_id for payload in recv_states) == 3
        assert all(payload.data == 1 for payload in recv_states)
    finally:
        sv.shutdown()


class MockCrashEnv(MockEnv):
    """``MockEnv`` variant whose ``step`` raises when the counter reaches 2."""

    def step(self, _):
        """Advance the counter via the parent; raise ``Exception("Ohh")`` if it is now 2, else return it."""
        current = super().step(_)
        if current == 2:
            raise Exception("Ohh")

        return current


@pytest.mark.tmp
@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
def test_crash_supervisor(type_):
    """Check that a child error is reported, the child can be restarted, and work resumes.

    Two ``MockEnv`` children and one ``MockCrashEnv`` child are registered. The
    supervisor is shut down in a ``finally`` clause so that a failing assertion
    does not leave children running.

    Args:
        type_: Child type the supervisor is created with (process or thread).
    """
    sv = Supervisor(type_=type_)
    for _ in range(2):
        sv.register(lambda: MockEnv("AnyArgs"))
    sv.register(lambda: MockCrashEnv("AnyArgs"))
    sv.start_link()

    try:
        # Send 6 messages, will cause the third subprocess crash
        for env_id in range(len(sv._children)):
            for _ in range(2):
                sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"]))

        # Find the error message
        recv_states: List[RecvPayload] = []
        for _ in range(6):
            recv_payload = sv.recv(ignore_err=True)
            if recv_payload.err:
                sv._children[recv_payload.proc_id].restart()
            recv_states.append(recv_payload)
        assert any(isinstance(payload.err, Exception) for payload in recv_states)

        # Resume
        for env_id in range(len(sv._children)):
            sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"]))
        resumed_states: List[RecvPayload] = [sv.recv() for _ in range(3)]

        # 3 + 3 + 1: the two healthy envs are on their third step, the restarted one on its first
        assert sum(p.data for p in resumed_states) == 7

        # The second step after the restart hits the crash condition again;
        # without ignore_err the error is expected to be raised.
        with pytest.raises(Exception):
            sv.send(SendPayload(proc_id=2, method="step", args=["any action"]))
            sv.recv(ignore_err=False)
    finally:
        sv.shutdown()


@pytest.mark.tmp
@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
def test_recv_all(type_):
    """Check ``recv_all`` with a callback that re-sends work and extends the awaited payloads.

    The callback re-issues a ``step`` request to each child twice. The supervisor
    is shut down in a ``finally`` clause so that a failing assertion does not
    leave children running.

    Args:
        type_: Child type the supervisor is created with (process or thread).
    """
    sv = Supervisor(type_=type_)
    for _ in range(3):
        sv.register(lambda: MockEnv("AnyArgs"))
    sv.start_link()

    try:
        # Test recv_all
        send_payloads = []
        for env_id in range(len(sv._children)):
            payload = SendPayload(
                proc_id=env_id,
                method="step",
                args=["any action"],
            )
            send_payloads.append(payload)
            sv.send(payload)

        retry_times = {env_id: 0 for env_id in range(len(sv._children))}

        def recv_callback(recv_payload: RecvPayload, remain_payloads: Dict[str, SendPayload]):
            # Stop re-sending once this child has been retried twice.
            if retry_times[recv_payload.proc_id] == 2:
                return
            retry_times[recv_payload.proc_id] += 1
            # Positional args are passed as a list, like every other request in this file.
            payload = SendPayload(proc_id=recv_payload.proc_id, method="step", args=["action"])
            sv.send(payload)
            # Register the new request so that recv_all keeps waiting for it.
            remain_payloads[payload.req_id] = payload

        recv_payloads = sv.recv_all(send_payloads=send_payloads, callback=recv_callback)
        assert len(recv_payloads) == 3
        assert all(v == 2 for v in retry_times.values())
    finally:
        sv.shutdown()


@pytest.mark.timeout(60)
@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
def test_timeout(type_):
    """Check ``recv_all`` timeout handling, with and without ``ignore_err``.

    The supervisor is started twice. Each run is shut down in a ``finally``
    clause so that a failing assertion does not leave blocked children running.

    Args:
        type_: Child type the supervisor is created with (process or thread).
    """
    sv = Supervisor(type_=type_)
    for _ in range(3):
        sv.register(lambda: MockEnv("AnyArgs"))

    sv.start_link()
    try:
        send_payloads = []
        for env_id in range(len(sv._children)):
            payload = SendPayload(proc_id=env_id, method="block")
            send_payloads.append(payload)
            sv.send(payload)

        # Test timeout exception: "block" sleeps for 10 seconds, far beyond the 1 second timeout
        with pytest.raises(TimeoutError):
            sv.recv_all(send_payloads=send_payloads, timeout=1)
    finally:
        sv.shutdown(timeout=1)

    # Test timeout with ignore error
    sv.start_link()
    try:
        send_payloads = []

        # 0 is block
        payload = SendPayload(proc_id=0, method="block")
        send_payloads.append(payload)
        sv.send(payload)

        # 1 is step
        payload = SendPayload(proc_id=1, method="step", args=[""])
        send_payloads.append(payload)
        sv.send(payload)

        payloads = sv.recv_all(send_payloads=send_payloads, timeout=1, ignore_err=True)
        # The blocked request carries the timeout error; the step request succeeds.
        assert isinstance(payloads[0].err, TimeoutError)
        assert payloads[1].err is None
    finally:
        sv.shutdown(timeout=1)


@pytest.mark.timeout(60)
@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
def test_timeout_with_callback(type_):
    """Check that the ``recv_all`` callback is invoked for timed-out payloads, including ones it adds.

    The supervisor is shut down in a ``finally`` clause so that a failing
    assertion does not leave blocked children running.

    Args:
        type_: Child type the supervisor is created with (process or thread).
    """
    sv = Supervisor(type_=type_)
    for _ in range(3):
        sv.register(lambda: MockEnv("AnyArgs"))
    sv.start_link()

    try:
        send_payloads = []

        # 0 is block
        payload = SendPayload(proc_id=0, method="block")
        send_payloads.append(payload)
        sv.send(payload)

        # 1 is step
        payload = SendPayload(proc_id=1, method="step", args=[""])
        send_payloads.append(payload)
        sv.send(payload)

        block_reset_callback = False

        # 1. Add another send payload in the callback
        # 2. Recv this send payload and check for the method
        def recv_callback(recv_payload: RecvPayload, remain_payloads: Dict[str, SendPayload]):
            nonlocal block_reset_callback
            if recv_payload.method == "block" and recv_payload.err:
                # The follow-up payload is only added to the awaited set here;
                # it is not passed to sv.send in this callback.
                new_send_payload = SendPayload(proc_id=recv_payload.proc_id, method="block_reset")
                remain_payloads[new_send_payload.req_id] = new_send_payload
                return

            if recv_payload.method == "block_reset" and recv_payload.err:
                block_reset_callback = True
                return

        payloads = sv.recv_all(send_payloads=send_payloads, timeout=1, ignore_err=True, callback=recv_callback)
        assert block_reset_callback
        assert isinstance(payloads[0].err, TimeoutError)
        assert payloads[1].err is None
    finally:
        sv.shutdown(timeout=1)


@pytest.mark.tmp  # gitlab ci and local test pass, github always fail
def test_shared_memory():
    """Check that ``shm_callback`` writes into the shared buffer and children work concurrently.

    The supervisor is shut down in a ``finally`` clause so that a failing
    assertion does not leave child processes running.
    """
    sv = Supervisor(type_=ChildType.PROCESS)

    def shm_callback(payload: RecvPayload, shm: Any):
        # Record the request id in the slot of the replying child and replace the data with 0.
        shm[payload.proc_id] = payload.req_id
        payload.data = 0

    # One unsigned byte per child.
    shm = mp.Array(ctypes.c_uint8, 3)
    for _ in range(3):
        sv.register(lambda: MockEnv("AnyArgs"), shm_buffer=shm, shm_callback=shm_callback)
    sv.start_link()

    try:
        # Send init request
        for env_id in range(len(sv._children)):
            sv.send(SendPayload(proc_id=env_id, req_id=env_id, method="sleep1", args=[]))

        start = time()
        for i in range(6):
            payload = sv.recv()
            assert payload.data == 0
            assert shm[payload.proc_id] == payload.req_id
            # Immediately hand the same child another 1-second job.
            sv.send(SendPayload(proc_id=payload.proc_id, req_id=i, method="sleep1", args=[]))

        # Non blocking: six 1-second jobs received in under 3 seconds means they did not run one after another
        assert time() - start < 3
    finally:
        sv.shutdown()


@pytest.mark.benchmark
@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
def test_supervisor_benchmark(type_):
    """Check that 1000 recv/send round trips over three children finish in under one second.

    The supervisor is shut down in a ``finally`` clause, outside the timed
    region, so that children are not left running after the test.

    Args:
        type_: Child type the supervisor is created with (process or thread).
    """
    sv = Supervisor(type_=type_)
    for _ in range(3):
        sv.register(lambda: MockEnv("AnyArgs"))
    sv.start_link()

    try:
        # Prime every child with one request so that recv always has something to wait for.
        for env_id in range(len(sv._children)):
            sv.send(SendPayload(proc_id=env_id, method="step", args=[""]))

        start = time()
        for _ in range(1000):
            payload = sv.recv()
            sv.send(SendPayload(proc_id=payload.proc_id, method="step", args=[""]))

        assert time() - start < 1
    finally:
        sv.shutdown()