File size: 731 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import pytest
import copy
from ding.framework import OnlineRLContext
from ding.framework.middleware import eps_greedy_handler, eps_greedy_masker
from ding.framework.middleware.tests import MockPolicy, MockEnv, CONFIG


@pytest.mark.unittest
def test_eps_greedy_handler():
    """Check the ``eps`` value written by ``eps_greedy_handler`` at two env steps.

    For each ``(env_step, expected_eps)`` pair the test sets ``ctx.env_step``,
    builds a fresh handler from a deep copy of ``CONFIG``, advances the object
    returned by calling the handler on the context once with ``next()``, and
    then asserts ``ctx.collect_kwargs['eps']``.
    """
    # Deep copy so that nothing done here can leak into the shared CONFIG.
    config = copy.deepcopy(CONFIG)
    context = OnlineRLContext()

    # (env_step, eps expected afterwards); presumably the start and the end
    # of the epsilon schedule defined in CONFIG -- confirm against CONFIG.
    expectations = (
        (0, 0.95),
        (1000000, 0.1),
    )
    for step, expected_eps in expectations:
        context.env_step = step
        middleware = eps_greedy_handler(config)
        # The handler call returns an iterator; one next() runs it up to
        # its first yield, after which the eps value is checked.
        next(middleware(context))
        assert context.collect_kwargs['eps'] == expected_eps


@pytest.mark.unittest
def test_eps_greedy_masker():
    """Check that ``eps_greedy_masker`` leaves ``collect_kwargs['eps']`` at -1.

    A new masker is created and applied to the same context ten times; the
    assertion runs once, after the final application.
    """
    context = OnlineRLContext()

    remaining = 10
    while remaining > 0:
        # Build a fresh masker on every pass, then apply it to the context.
        masker = eps_greedy_masker()
        masker(context)
        remaining -= 1

    assert context.collect_kwargs['eps'] == -1