File size: 1,541 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from typing import TYPE_CHECKING, Callable
from easydict import EasyDict
from ding.rl_utils import get_epsilon_greedy_fn
from ding.framework import task

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext


def eps_greedy_handler(cfg: EasyDict) -> Callable:
    """
    Overview:
        The middleware that computes the epsilon value for epsilon-greedy exploration \
            according to ``ctx.env_step``.
    Arguments:
        - cfg (:obj:`EasyDict`): Config. ``cfg.policy.other.eps`` must provide ``start``, ``end``, \
            ``decay`` and ``type``, which are forwarded to ``get_epsilon_greedy_fn``.
    Returns:
        - middleware (:obj:`Callable`): The eps middleware, or ``task.void()`` when the router is \
            active and the current process does not have the collector role.
    """
    # Only collector processes need an eps value when running distributed.
    if task.router.is_active and not task.has_role(task.role.COLLECTOR):
        return task.void()

    eps_cfg = cfg.policy.other.eps
    handle = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)

    def _eps_greedy(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - env_step (:obj:`int`): The env steps count.
        Output of ctx:
            - collect_kwargs['eps'] (:obj:`float`): The eps conditioned on env_step and cfg. \
                The key is removed again once the downstream middleware has run.
        """
        ctx.collect_kwargs['eps'] = handle(ctx.env_step)
        yield
        # Tolerate the key already being gone, but (unlike a bare ``except``) do not
        # silently swallow unrelated errors or KeyboardInterrupt/SystemExit.
        ctx.collect_kwargs.pop('eps', None)

    return _eps_greedy


def eps_greedy_masker():
    """
    Overview:
        Build a middleware that overwrites the exploration epsilon with the sentinel \
            value -1, intended to stop the collector from choosing actions via e-greedy.
    Returns:
        - masker (:obj:`Callable`): A function that takes the context and writes the masked eps.
    """
    # Sentinel eps value written into the context on every call.
    masked_eps = -1

    def _masker(ctx: "OnlineRLContext"):
        """
        Output of ctx:
            - collect_kwargs['eps'] (:obj:`int`): Always overwritten with ``-1``.
        """
        kwargs = ctx.collect_kwargs
        kwargs['eps'] = masked_eps

    return _masker