|
from easydict import EasyDict |
|
from typing import Callable |
|
|
|
|
|
def get_rollout_length_scheduler(cfg: EasyDict) -> Callable[[int], int]: |
|
""" |
|
Overview: |
|
Get the rollout length scheduler that adapts rollout length based\ |
|
on the current environment steps. |
|
Returns: |
|
- scheduler (:obj:`Callble`): The function that takes envstep and\ |
|
return the current rollout length. |
|
""" |
|
if cfg.type == 'linear': |
|
x0 = cfg.rollout_start_step |
|
x1 = cfg.rollout_end_step |
|
y0 = cfg.rollout_length_min |
|
y1 = cfg.rollout_length_max |
|
w = (y1 - y0) / (x1 - x0) |
|
b = y0 |
|
return lambda x: int(min(max(w * (x - x0) + b, y0), y1)) |
|
elif cfg.type == 'constant': |
|
return lambda x: cfg.rollout_length |
|
else: |
|
raise KeyError("not implemented key: {}".format(cfg.type)) |
|
|