# sgoodfriend's picture
# PPO playing SpaceInvadersNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
# 14b68af
import optuna
from typing import Any, Dict
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, single_observation_space
def sample_env_hyperparams(
    trial: optuna.Trial, env_hparams: Dict[str, Any], env: VecEnv
) -> Dict[str, Any]:
    """Sample environment hyperparameters for a single Optuna trial.

    ``env_hparams`` is modified in place and the same dict is returned.

    Keys written:
      * ``n_envs``: ``2 ** n_envs_exp``, where ``n_envs_exp`` is an integer
        suggested by the trial with bounds 1 and 5. The resolved value is
        also stored on the trial as the ``n_envs`` user attribute.
      * ``normalize``: categorical choice between ``False`` and ``True``.
      * ``normalize_kwargs``: present only when ``normalize`` is true. Any
        entries already in the incoming dict are kept, and ``norm_obs`` /
        ``norm_reward`` are overwritten. When ``normalize`` is false the key
        is removed if it exists.

    Args:
        trial: Optuna trial used to draw the suggestions.
        env_hparams: Environment hyperparameters to update.
        env: Vectorized environment; only its single-environment observation
            space is inspected.

    Returns:
        The same ``env_hparams`` dict that was passed in, after mutation.
    """
    observation_space = single_observation_space(env)

    envs_count = 2 ** trial.suggest_int("n_envs_exp", 1, 5)
    trial.set_user_attr("n_envs", envs_count)
    env_hparams["n_envs"] = envs_count

    use_normalize = trial.suggest_categorical("normalize", [False, True])
    env_hparams["normalize"] = use_normalize

    if not use_normalize:
        # Drop any normalize_kwargs left over in the incoming config.
        env_hparams.pop("normalize_kwargs", None)
        return env_hparams

    # Start from the existing kwargs (if any) so unrelated entries survive.
    kwargs = env_hparams.get("normalize_kwargs", {})
    if len(observation_space.shape) != 3:
        obs_flag = trial.suggest_categorical("norm_obs", [True, False])
        reward_flag = trial.suggest_categorical("norm_reward", [True, False])
    else:
        # Rank-3 observation shape (presumably image observations — TODO
        # confirm): nothing is sampled; observation normalization is turned
        # off and reward normalization is turned on.
        obs_flag, reward_flag = False, True
    kwargs.update(norm_obs=obs_flag, norm_reward=reward_flag)
    env_hparams["normalize_kwargs"] = kwargs
    return env_hparams