zjowowen's picture
init space
079c32c
raw
history blame
852 Bytes
from easydict import EasyDict
from typing import Callable
def get_rollout_length_scheduler(cfg: EasyDict) -> Callable[[int], int]:
    """
    Overview:
        Build the rollout length scheduler, i.e. a function that maps the
        current environment step to the rollout length to use.
    Arguments:
        - cfg (:obj:`EasyDict`): Scheduler config. ``cfg.type`` selects the schedule:

            - ``'linear'``: reads ``rollout_start_step``, ``rollout_end_step``,
              ``rollout_length_min`` and ``rollout_length_max``. The length is
              interpolated linearly from the min (at the start step) to the max
              (at the end step), clamped to ``[min, max]`` and truncated to int.
              If start and end step are equal, the schedule degenerates to a
              step: min before that step, max from that step on.
            - ``'constant'``: reads ``rollout_length`` and always returns it.
    Returns:
        - scheduler (:obj:`Callable`): The function that takes envstep and
          returns the current rollout length.
    Raises:
        - KeyError: If ``cfg.type`` is neither ``'linear'`` nor ``'constant'``.
    """
    if cfg.type == 'linear':
        x0 = cfg.rollout_start_step
        x1 = cfg.rollout_end_step
        y0 = cfg.rollout_length_min
        y1 = cfg.rollout_length_max
        if x1 == x0:
            # Zero-width ramp: the slope below would divide by zero, so fall
            # back to a step function that switches from min to max at x0.
            return lambda x: int(y0 if x < x0 else y1)
        w = (y1 - y0) / (x1 - x0)
        # Clamp the interpolated value to [y0, y1] before truncating to int.
        return lambda x: int(min(max(w * (x - x0) + y0, y0), y1))
    elif cfg.type == 'constant':
        # cfg.rollout_length is read on every call, so later changes to the
        # config are reflected by the scheduler.
        return lambda x: cfg.rollout_length
    else:
        raise KeyError("not implemented key: {}".format(cfg.type))