File size: 1,541 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
from typing import TYPE_CHECKING, Callable
from easydict import EasyDict
from ding.rl_utils import get_epsilon_greedy_fn
from ding.framework import task
if TYPE_CHECKING:
from ding.framework import OnlineRLContext
def eps_greedy_handler(cfg: EasyDict) -> Callable:
    """
    Overview:
        The middleware that computes the epsilon value according to the env_step and \
        publishes it in ``ctx.collect_kwargs['eps']``.
    Arguments:
        - cfg (:obj:`EasyDict`): Config. The fields ``start``, ``end``, ``decay`` and ``type`` \
            under ``cfg.policy.other.eps`` are read and passed to ``get_epsilon_greedy_fn``.
    Returns:
        - middleware (:obj:`Callable`): ``task.void()`` when the task router is active and the \
            task does not have the collector role; otherwise the ``_eps_greedy`` generator function.
    """
    # Only the collector role needs epsilon when the router is active.
    if task.router.is_active and not task.has_role(task.role.COLLECTOR):
        return task.void()

    eps_cfg = cfg.policy.other.eps
    handle = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)

    def _eps_greedy(ctx: "OnlineRLContext"):
        """
        Input of ctx:
            - env_step (:obj:`int`): The env steps count.
        Output of ctx:
            - collect_kwargs['eps'] (:obj:`float`): The eps conditioned on env_step and cfg. \
                The key is removed again once the generator is resumed after its ``yield``.
        """
        ctx.collect_kwargs['eps'] = handle(ctx.env_step)
        yield
        # Best-effort cleanup: the key may already have been removed or replaced, so a
        # missing key is tolerated. Unlike a bare ``except``, this does not hide unrelated
        # errors or swallow KeyboardInterrupt/SystemExit.
        ctx.collect_kwargs.pop('eps', None)

    return _eps_greedy
def eps_greedy_masker():
    """
    Overview:
        Build a middleware that overwrites the epsilon entry of the collect kwargs with a \
        fixed masked value, intended to stop actions from being generated by the \
        e_greedy method.
    Returns:
        - middleware: The ``_masker`` function, which takes the context as its only argument.
    """
    masked_eps = -1

    def _masker(ctx: "OnlineRLContext"):
        """
        Output of ctx:
            - collect_kwargs['eps']: Always set to the masked eps value, -1.
        """
        kwargs = ctx.collect_kwargs
        kwargs['eps'] = masked_eps

    return _masker
|