|
import pytest |
|
import copy |
|
from ding.framework import OnlineRLContext |
|
from ding.framework.middleware import eps_greedy_handler, eps_greedy_masker |
|
from ding.framework.middleware.tests import MockPolicy, MockEnv, CONFIG |
|
|
|
|
|
@pytest.mark.unittest
def test_eps_greedy_handler():
    """Check the epsilon that ``eps_greedy_handler`` stores for two env steps.

    For each case a fresh handler is built from a private deep copy of
    ``CONFIG``, applied to the shared context, and the resulting generator is
    advanced exactly once with ``next``. The test then expects
    ``ctx.collect_kwargs['eps']`` to hold the paired value.
    """
    config = copy.deepcopy(CONFIG)
    context = OnlineRLContext()

    # (env_step, expected eps) pairs; presumably the start and end values of
    # the schedule configured in CONFIG -- confirm against the test config.
    cases = (
        (0, 0.95),
        (1000000, 0.1),
    )
    for env_step, expected_eps in cases:
        context.env_step = env_step
        handler = eps_greedy_handler(config)
        # Only the first step of the generator is run; it is never resumed.
        next(handler(context))
        assert context.collect_kwargs['eps'] == expected_eps
|
|
|
|
|
@pytest.mark.unittest
def test_eps_greedy_masker():
    """Check that ``eps_greedy_masker`` leaves ``eps`` at ``-1`` in the context.

    A new masker is created and called directly on the same context ten
    times; afterwards ``ctx.collect_kwargs['eps']`` is expected to equal
    ``-1``.
    """
    context = OnlineRLContext()
    repeat_count = 10
    for _round in range(repeat_count):
        masker = eps_greedy_masker()
        masker(context)
    assert context.collect_kwargs['eps'] == -1
|
|