File size: 1,069 Bytes
8bf4dee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import optuna

from gym.spaces import Box
from typing import Any, Dict

from rl_algo_impls.wrappers.vectorable_wrapper import (
    VecEnv,
    single_action_space,
)


def sample_on_policy_hyperparams(
    trial: optuna.Trial, policy_hparams: Dict[str, Any], env: VecEnv
) -> Dict[str, Any]:
    """Draw policy-network hyperparameters from an Optuna trial.

    ``policy_hparams`` is updated in place and the same dict object is
    returned.

    Always sampled:
      - ``init_layers_orthogonal``: one of ``True`` / ``False``.
      - ``activation_fn``: one of ``"tanh"`` / ``"relu"``.

    Sampled only when the single-env action space is a ``gym.spaces.Box``:
      - ``log_std_init``: float in ``[-5, 0.5]``.
      - ``use_sde``: one of ``False`` / ``True``.

    ``squash_output`` is sampled only when ``use_sde`` is truthy. That value
    is either the one just sampled or one already present in
    ``policy_hparams``. Otherwise any existing ``squash_output`` entry is
    removed.

    Args:
        trial: Optuna trial that supplies the sampled values.
        policy_hparams: Policy hyperparameter mapping; mutated in place.
        env: Vectorized environment. Its single-env action space decides
            whether the ``Box``-only hyperparameters are sampled.

    Returns:
        The same ``policy_hparams`` dict, after mutation.
    """
    action_space = single_action_space(env)
    hp = policy_hparams

    hp["init_layers_orthogonal"] = trial.suggest_categorical(
        "init_layers_orthogonal", [True, False]
    )
    hp["activation_fn"] = trial.suggest_categorical(
        "activation_fn", ["tanh", "relu"]
    )

    has_box_actions = isinstance(action_space, Box)
    if has_box_actions:
        hp["log_std_init"] = trial.suggest_float("log_std_init", -5, 0.5)
        hp["use_sde"] = trial.suggest_categorical("use_sde", [False, True])

    if not hp.get("use_sde", False):
        # squash_output is only tuned alongside SDE here; drop any value
        # carried over in the incoming dict so it is not left stale.
        hp.pop("squash_output", None)
        return hp

    hp["squash_output"] = trial.suggest_categorical(
        "squash_output", [False, True]
    )
    return hp