File size: 852 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from easydict import EasyDict
from typing import Callable


def get_rollout_length_scheduler(cfg: EasyDict) -> Callable[[int], int]:
    """
    Overview:
        Get the rollout length scheduler that adapts rollout length based\
        on the current environment steps.
    Arguments:
        - cfg (:obj:`EasyDict`): Scheduler config; ``cfg.type`` selects the schedule.
            - ``'linear'``: reads ``rollout_start_step``, ``rollout_end_step``,
              ``rollout_length_min`` and ``rollout_length_max``. The length is linearly
              interpolated from ``rollout_length_min`` (at ``rollout_start_step``) to
              ``rollout_length_max`` (at ``rollout_end_step``), clamped to that range and
              truncated with ``int``. If the start and end steps are equal, the schedule
              is a step: ``rollout_length_min`` before that step, ``rollout_length_max``
              from it onward.
            - ``'constant'``: reads ``rollout_length`` once and always returns it.
    Returns:
        - scheduler (:obj:`Callable`): The function that takes envstep and\
          returns the current rollout length.
    Raises:
        - KeyError: If ``cfg.type`` is not ``'linear'`` or ``'constant'``.
    """
    if cfg.type == 'linear':
        x0 = cfg.rollout_start_step
        x1 = cfg.rollout_end_step
        y0 = cfg.rollout_length_min
        y1 = cfg.rollout_length_max
        if x1 == x0:
            # A zero-width ramp would divide by zero; use its limit, a step at x0.
            return lambda x: int(y0 if x < x0 else y1)
        w = (y1 - y0) / (x1 - x0)
        # Clamp so steps outside [x0, x1] stay at the min/max rollout length.
        return lambda x: int(min(max(w * (x - x0) + y0, y0), y1))
    elif cfg.type == 'constant':
        # Read the value once so a missing key fails here, not at first use,
        # and later mutation of ``cfg`` cannot silently change the schedule.
        rollout_length = cfg.rollout_length
        return lambda x: rollout_length
    else:
        raise KeyError("not implemented key: {}".format(cfg.type))