File size: 1,069 Bytes
8bf4dee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import optuna
from gym.spaces import Box
from typing import Any, Dict
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecEnv,
single_action_space,
)
def sample_on_policy_hyperparams(
    trial: optuna.Trial, policy_hparams: Dict[str, Any], env: VecEnv
) -> Dict[str, Any]:
    """Draw on-policy network hyperparameters from an Optuna trial.

    The sampled values are written into ``policy_hparams`` in place, and that
    same dict is returned.

    Always sampled:
        * ``init_layers_orthogonal``: one of ``True`` / ``False``.
        * ``activation_fn``: one of ``"tanh"`` / ``"relu"``.

    Sampled only when the env's single action space is a ``gym.spaces.Box``:
        * ``log_std_init``: float in ``[-5, 0.5]``.
        * ``use_sde``: one of ``False`` / ``True``.
        * ``squash_output``: one of ``False`` / ``True``, sampled only when
          ``use_sde`` came out truthy. Otherwise any ``squash_output`` entry
          already in the dict is removed.

    Args:
        trial: Optuna trial that records and serves the suggestions.
        policy_hparams: Policy hyperparameter dict; mutated in place.
        env: Vectorized env; its per-env action space decides whether the
            Box-only hyperparameters are sampled.

    Returns:
        The ``policy_hparams`` object that was passed in, now updated.
    """
    # Inspect the action space before any suggestion is recorded on the trial.
    is_box = isinstance(single_action_space(env), Box)

    # Unconditional categorical choices, suggested in this fixed order.
    for param_name, choices in (
        ("init_layers_orthogonal", [True, False]),
        ("activation_fn", ["tanh", "relu"]),
    ):
        policy_hparams[param_name] = trial.suggest_categorical(param_name, choices)

    if not is_box:
        return policy_hparams

    policy_hparams["log_std_init"] = trial.suggest_float("log_std_init", -5, 0.5)
    use_sde = trial.suggest_categorical("use_sde", [False, True])
    policy_hparams["use_sde"] = use_sde

    if use_sde:
        policy_hparams["squash_output"] = trial.suggest_categorical(
            "squash_output", [False, True]
        )
    else:
        # Without SDE, drop any squash_output carried over from the base config.
        policy_hparams.pop("squash_output", None)

    return policy_hparams
|