File size: 1,245 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import pytest
import pickle
import numpy as np
from ding.framework import Context, OnlineRLContext, OfflineRLContext
from dataclasses import dataclass
@dataclass()
class MockContext(Context):
    """Minimal ``Context`` subclass used by ``test_pickable``.

    Every field has a default, so ``MockContext()`` can be built without
    arguments. The fields exist only to be pickled, kept, or dropped by
    ``Context.keep`` / ``Context.renew`` in the tests of this module.
    """

    # Plain string field; checked after a pickle round trip.
    hello: str = "world"
    # Counter the test registers via ``keep`` and expects ``renew`` to carry over.
    keep_me: int = 0
    # Counter the test does NOT register via ``keep``; expected back at 0 after ``renew``.
    not_keep_me: int = 0
@pytest.mark.unittest
def test_pickable():
    """Check that a Context subclass pickles and that ``renew`` honours ``keep``.

    After a pickle round trip the default ``hello`` value must be intact.
    After ``renew``, only the field registered with ``keep`` is expected to
    retain its mutated value; the other field is expected back at 0.
    """
    original = MockContext()
    original.keep("keep_me")

    # Round-trip through pickle (data is produced locally, so loads is safe here).
    payload = pickle.dumps(original)
    restored = pickle.loads(payload)
    assert restored.hello == "world"

    # Bump both counters on the original, then renew it.
    original.keep_me += 1
    original.not_keep_me += 1
    renewed = original.renew()
    assert renewed.keep_me == 1  # kept key carried over
    assert renewed.not_keep_me == 0  # un-kept key back at its default
@pytest.mark.unittest
def test_online():
    """Check ``OnlineRLContext`` defaults and its behaviour across ``renew``.

    A fresh context starts with ``env_step == 0`` and ``eval_value == -np.inf``.
    After ``renew``, ``env_step`` is expected to carry over while
    ``eval_value`` is expected to return to ``-np.inf``.
    """
    context = OnlineRLContext()
    # Fresh context: no environment steps and no evaluation score yet.
    assert context.env_step == 0
    assert context.eval_value == -np.inf

    context.env_step = context.env_step + 1
    context.eval_value = 1
    assert context.env_step == 1
    assert context.eval_value == 1

    next_context = context.renew()
    assert next_context.env_step == 1  # carried over by renew
    assert next_context.eval_value == -np.inf  # reset to its initial value
@pytest.mark.unittest
def test_offline():
    """Check ``OfflineRLContext`` defaults and its behaviour across ``renew``.

    A fresh context starts with ``train_iter == 0`` and ``eval_value == -np.inf``.
    After ``renew``, ``train_iter`` is expected to carry over while
    ``eval_value`` is expected to return to ``-np.inf``.
    """
    context = OfflineRLContext()
    # Fresh context: no training iterations and no evaluation score yet.
    assert context.train_iter == 0
    assert context.eval_value == -np.inf

    context.train_iter = context.train_iter + 1
    context.eval_value = 1
    assert context.train_iter == 1
    assert context.eval_value == 1

    next_context = context.renew()
    assert next_context.train_iter == 1  # carried over by renew
    assert next_context.eval_value == -np.inf  # reset to its initial value
|