{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "pbVXrmEsLXDt"
},
"source": [
"# ML-Agents run with Stable Baselines 3\n",
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WNKTwHU3d2-l"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"#@title Install Rendering Dependencies { display-mode: \"form\" }\n",
"#@markdown (You only need to run this code when using Colab's hosted runtime)\n",
"\n",
"import os\n",
"from IPython.display import HTML, display\n",
"\n",
"def progress(value, max=100):\n",
"    \"\"\"Build the HTML progress bar displayed while dependencies install.\n",
"\n",
"    Args:\n",
"        value: Amount of work completed so far.\n",
"        max: Value that corresponds to a full bar. The name shadows the builtin\n",
"            but is kept so keyword callers keep working.\n",
"\n",
"    Returns:\n",
"        An ``IPython.display.HTML`` object wrapping a ``<progress>`` element.\n",
"    \"\"\"\n",
"    # The template needs the {value}/{max} placeholders that str.format fills in;\n",
"    # an empty template renders nothing at all.\n",
"    return HTML(\"\"\"\n",
"        <progress\n",
"            value='{value}'\n",
"            max='{max}'\n",
"            style='width: 100%'\n",
"        >\n",
"            {value}\n",
"        </progress>\n",
"    \"\"\".format(value=value, max=max))\n",
"\n",
"# Keep the display handle so later update() calls redraw the same bar.\n",
"pro_bar = display(progress(0, 100), display_id=True)\n",
"\n",
"# Install the virtual frame buffer when running on Colab, or whenever the\n",
"# COLAB_ALWAYS_INSTALL_XVFB environment variable is set.\n",
"try:\n",
"    import google.colab\n",
"    INSTALL_XVFB = True\n",
"except ImportError:\n",
"    INSTALL_XVFB = 'COLAB_ALWAYS_INSTALL_XVFB' in os.environ\n",
"\n",
"if INSTALL_XVFB:\n",
"    # Write a small init-style script that starts/stops Xvfb on display :1.\n",
"    with open('frame-buffer', 'w') as writefile:\n",
"        writefile.write(\"\"\"#taken from https://gist.github.com/jterrace/2911875\n",
"XVFB=/usr/bin/Xvfb\n",
"XVFBARGS=\":1 -screen 0 1024x768x24 -ac +extension GLX +render -noreset\"\n",
"PIDFILE=./frame-buffer.pid\n",
"case \"$1\" in\n",
" start)\n",
" echo -n \"Starting virtual X frame buffer: Xvfb\"\n",
" /sbin/start-stop-daemon --start --quiet --pidfile $PIDFILE --make-pidfile --background --exec $XVFB -- $XVFBARGS\n",
" echo \".\"\n",
" ;;\n",
" stop)\n",
" echo -n \"Stopping virtual X frame buffer: Xvfb\"\n",
" /sbin/start-stop-daemon --stop --quiet --pidfile $PIDFILE\n",
" rm $PIDFILE\n",
" echo \".\"\n",
" ;;\n",
" restart)\n",
" $0 stop\n",
" $0 start\n",
" ;;\n",
" *)\n",
" echo \"Usage: /etc/init.d/xvfb {start|stop|restart}\"\n",
" exit 1\n",
"esac\n",
"exit 0\n",
" \"\"\")\n",
"    # System packages needed by the X server.\n",
"    !sudo apt-get update\n",
"    pro_bar.update(progress(10, 100))\n",
"    !sudo DEBIAN_FRONTEND=noninteractive apt install -y daemon wget gdebi-core build-essential libfontenc1 libfreetype6 xorg-dev xorg\n",
"    pro_bar.update(progress(20, 100))\n",
"    # Download the pinned libxfont1 and xvfb packages.\n",
"    !wget http://security.ubuntu.com/ubuntu/pool/main/libx/libxfont/libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb 2>&1\n",
"    pro_bar.update(progress(30, 100))\n",
"    !wget --output-document xvfb.deb http://security.ubuntu.com/ubuntu/pool/universe/x/xorg-server/xvfb_1.18.4-0ubuntu0.12_amd64.deb 2>&1\n",
"    pro_bar.update(progress(40, 100))\n",
"    # Install them, libxfont1 first.\n",
"    !sudo dpkg -i libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb 2>&1\n",
"    pro_bar.update(progress(50, 100))\n",
"    !sudo dpkg -i xvfb.deb 2>&1\n",
"    pro_bar.update(progress(70, 100))\n",
"    # Remove the downloaded package files.\n",
"    !rm libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb\n",
"    pro_bar.update(progress(80, 100))\n",
"    !rm xvfb.deb\n",
"    pro_bar.update(progress(90, 100))\n",
"    # Start the frame buffer and point DISPLAY at it (display :1, as in XVFBARGS).\n",
"    !bash frame-buffer start\n",
"    os.environ[\"DISPLAY\"] = \":1\"\n",
"pro_bar.update(progress(100, 100))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Pzj7wgapAcDs"
},
"source": [
"### Installing ml-agents"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "N8yfQqkbebQ5",
"pycharm": {
"is_executing": true
}
},
"outputs": [],
"source": [
"# Install ml-agents only when it is missing, so re-running this cell stays cheap.\n",
"try:\n",
"    import mlagents\n",
"    print(\"ml-agents already installed\")\n",
"except ImportError:\n",
"    # %pip installs into the environment of the running kernel, which a bare\n",
"    # `!python -m pip` is not guaranteed to do.\n",
"    %pip install -q mlagents==0.30.0\n",
"    print(\"Installed ml-agents\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_u74YhSmW6gD"
},
"source": [
"## Run the Environment"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P-r_cB2rqp5x"
},
"source": [
"### Import dependencies and set some high-level parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YSf-WhxbqtLw"
},
"outputs": [],
"source": [
"from dataclasses import dataclass, field\n",
"from pathlib import Path\n",
"from typing import Callable, Any\n",
"\n",
"import gym\n",
"from gym import Env\n",
"\n",
"from stable_baselines3 import PPO\n",
"from stable_baselines3.common.vec_env import VecMonitor, VecEnv, SubprocVecEnv\n",
"from supersuit import observation_lambda_v0\n",
"\n",
"from mlagents_envs.environment import UnityEnvironment\n",
"from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper\n",
"from mlagents_envs.registry import UnityEnvRegistry, default_registry\n",
"from mlagents_envs.side_channel.engine_configuration_channel import (\n",
"    EngineConfig,\n",
"    EngineConfigurationChannel,\n",
")\n",
"\n",
"NUM_ENVS = 8"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Environment and Engine Configurations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Default values from CLI (See cli_utils.py)\n",
"DEFAULT_ENGINE_CONFIG = EngineConfig(\n",
"    width=84,\n",
"    height=84,\n",
"    quality_level=4,\n",
"    time_scale=20,\n",
"    target_frame_rate=-1,\n",
"    capture_frame_rate=60,\n",
")\n",
"\n",
"\n",
"# Some config subset of an actual config.yaml file for MLA.\n",
"@dataclass\n",
"class LimitedConfig:\n",
"    \"\"\"Settings used to launch a group of Unity environments for Stable Baselines 3.\"\"\"\n",
"\n",
"    # The local path to a Unity executable or the name of an entry in the registry.\n",
"    env_path_or_name: str\n",
"    # Passed to every Unity environment as ``base_port``.\n",
"    base_port: int\n",
"    # Worker ``i`` is seeded with ``base_seed + i``.\n",
"    base_seed: int = 0\n",
"    # Number of environments to create.\n",
"    num_env: int = 1\n",
"    # Engine settings sent to each environment through an EngineConfigurationChannel.\n",
"    engine_config: EngineConfig = DEFAULT_ENGINE_CONFIG\n",
"    # Forwarded to UnityToGymWrapper as ``uint8_visual``.\n",
"    visual_obs: bool = False\n",
"    # TODO: Decide if we should just tell users to always use MultiInputPolicy so we can simplify the user workflow.\n",
"    # WARNING: Make sure to use MultiInputPolicy if you turn this on.\n",
"    allow_multiple_obs: bool = False\n",
"    # A default_factory is used because the registry instance is unhashable and\n",
"    # Python 3.11+ rejects unhashable objects as plain dataclass defaults.\n",
"    # The factory still hands out the shared ``default_registry`` object.\n",
"    env_registry: UnityEnvRegistry = field(default_factory=lambda: default_registry)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Unity Environment SB3 Factory"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def _unity_env_from_path_or_registry(\n",
"    env: str, registry: UnityEnvRegistry, **kwargs: Any\n",
") -> UnityEnvironment:\n",
"    \"\"\"Create a Unity environment from a local build path or a registry name.\n",
"\n",
"    Args:\n",
"        env: Path to a Unity executable, or the name of a registry entry.\n",
"        registry: Registry searched when ``env`` is not an existing path.\n",
"        **kwargs: Forwarded unchanged to the environment constructor.\n",
"\n",
"    Returns:\n",
"        The created ``UnityEnvironment``.\n",
"\n",
"    Raises:\n",
"        ValueError: If ``env`` is neither an existing path nor a registry entry.\n",
"    \"\"\"\n",
"    local_build = Path(env).expanduser().absolute()\n",
"    if local_build.exists():\n",
"        # Launch from the exact location that was just verified, so inputs such\n",
"        # as \"~/builds/Basic\" are not checked one way and opened another.\n",
"        return UnityEnvironment(file_name=str(local_build), **kwargs)\n",
"    if env in registry:\n",
"        return registry.get(env).make(**kwargs)\n",
"    raise ValueError(f\"Environment '{env}' wasn't a local path or registry entry\")\n",
" \n",
"def make_mla_sb3_env(config: LimitedConfig, **kwargs: Any) -> VecEnv:\n",
"    \"\"\"Create a vectorized Stable Baselines 3 environment backed by Unity workers.\n",
"\n",
"    Args:\n",
"        config: Environment name or path, port, seed, worker count and wrapper options.\n",
"        **kwargs: Extra keyword arguments forwarded to every Unity environment\n",
"            (for example ``no_graphics``). An optional ``side_channels`` sequence\n",
"            is extended with one engine configuration channel per worker.\n",
"\n",
"    Returns:\n",
"        A ``SubprocVecEnv`` built from ``config.num_env`` environment factories.\n",
"    \"\"\"\n",
"\n",
"    def handle_obs(obs, space):\n",
"        if isinstance(space, gym.spaces.Tuple):\n",
"            if len(space) == 1:\n",
"                return obs[0]\n",
"            # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).\n",
"            return {str(i): v for i, v in enumerate(obs)}\n",
"        return obs\n",
"\n",
"    def handle_obs_space(space):\n",
"        if isinstance(space, gym.spaces.Tuple):\n",
"            if len(space) == 1:\n",
"                return space[0]\n",
"            # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).\n",
"            return gym.spaces.Dict({str(i): v for i, v in enumerate(space)})\n",
"        return space\n",
"\n",
"    def create_env(env: str, worker_id: int) -> Callable[[], Env]:\n",
"        def _f() -> Env:\n",
"            engine_configuration_channel = EngineConfigurationChannel()\n",
"            engine_configuration_channel.set_configuration(config.engine_config)\n",
"            # Copy before editing: ``kwargs`` is shared by every factory, so writing\n",
"            # into it would add one more engine channel each time a factory runs\n",
"            # in the same process.\n",
"            env_kwargs = dict(kwargs)\n",
"            env_kwargs[\"side_channels\"] = list(kwargs.get(\"side_channels\") or []) + [\n",
"                engine_configuration_channel\n",
"            ]\n",
"            unity_env = _unity_env_from_path_or_registry(\n",
"                env=env,\n",
"                registry=config.env_registry,\n",
"                worker_id=worker_id,\n",
"                base_port=config.base_port,\n",
"                seed=config.base_seed + worker_id,\n",
"                **env_kwargs,\n",
"            )\n",
"            new_env = UnityToGymWrapper(\n",
"                unity_env=unity_env,\n",
"                uint8_visual=config.visual_obs,\n",
"                allow_multiple_obs=config.allow_multiple_obs,\n",
"            )\n",
"            # Unwrap single-element tuples and convert larger tuples to dicts.\n",
"            new_env = observation_lambda_v0(new_env, handle_obs, handle_obs_space)\n",
"            return new_env\n",
"\n",
"        return _f\n",
"\n",
"    # One factory per worker; the worker id also offsets the seed.\n",
"    env_facts = [\n",
"        create_env(config.env_path_or_name, worker_id=x) for x in range(config.num_env)\n",
"    ]\n",
"    return SubprocVecEnv(env_facts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Start Environment from the registry"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# -----------------\n",
"# This code is used to close an env that might not have been closed before\n",
"try:\n",
"    env.close()\n",
"except NameError:\n",
"    # Fresh kernel: no earlier `env` exists, so there is nothing to close.\n",
"    pass\n",
"except Exception as close_error:\n",
"    # Best effort: report the problem but still go on to start a new environment.\n",
"    print(f\"Ignoring error while closing the previous environment: {close_error!r}\")\n",
"# -----------------\n",
"\n",
"env = make_mla_sb3_env(\n",
"    config=LimitedConfig(\n",
"        env_path_or_name='Basic',  # Can use any name from a registry or a path to your own unity build.\n",
"        base_port=6006,\n",
"        base_seed=42,\n",
"        num_env=NUM_ENVS,\n",
"        allow_multiple_obs=True,\n",
"    ),\n",
"    no_graphics=True,  # Set to false if you are running locally and want to watch the environments move around as they train.\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# 250K should train to a reward ~= 0.90 for the \"Basic\" environment.\n",
"# We set the value lower here to demonstrate just a small amount of training.\n",
"BATCH_SIZE = 32\n",
"BUFFER_SIZE = 256\n",
"UPDATES = 50\n",
"TOTAL_TRAINING_STEPS_GOAL = BUFFER_SIZE * UPDATES\n",
"# Misspelled original name, kept as an alias for anything that still reads it.\n",
"TOTAL_TAINING_STEPS_GOAL = TOTAL_TRAINING_STEPS_GOAL\n",
"BETA = 0.0005\n",
"N_EPOCHS = 3\n",
"# Steps collected per environment for each update; integer division keeps it an int.\n",
"STEPS_PER_UPDATE = BUFFER_SIZE // NUM_ENVS\n",
"\n",
"# VecMonitor tracks episode statistics so reward stats can be reported during training.\n",
"# The guard keeps a re-run of this cell from stacking a second monitor on top.\n",
"if not isinstance(env, VecMonitor):\n",
"    env = VecMonitor(env)\n",
"\n",
"# Policy and value function with 2 layers of 32 units each and no shared layers.\n",
"policy_kwargs = {\"net_arch\": [{\"pi\": [32, 32], \"vf\": [32, 32]}]}\n",
"\n",
"# NOTE(review): Stable Baselines 3 calls schedule functions with the *remaining*\n",
"# progress (1.0 at the start of learn(), 0.0 at the end), so `1.0 - progress`\n",
"# grows as that value shrinks. Confirm this is the intended schedule.\n",
"model = PPO(\n",
"    \"MlpPolicy\",\n",
"    env,\n",
"    verbose=1,\n",
"    learning_rate=lambda progress: 0.0003 * (1.0 - progress),\n",
"    clip_range=lambda progress: 0.2 * (1.0 - progress),\n",
"    clip_range_vf=lambda progress: 0.2 * (1.0 - progress),\n",
"    # Uncomment this if you want to log tensorboard results when running this notebook locally.\n",
"    # tensorboard_log=\"results\",\n",
"    policy_kwargs=policy_kwargs,\n",
"    n_steps=STEPS_PER_UPDATE,\n",
"    batch_size=BATCH_SIZE,\n",
"    n_epochs=N_EPOCHS,\n",
"    ent_coef=BETA,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"pycharm": {
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# 0.93 is considered solved for the Basic environment\n",
"for round_number in range(1, UPDATES + 1):\n",
"    print(f\"Training round {round_number}/{UPDATES}\")\n",
"    # NOTE: reset_num_timesteps should only happen the first time so that tensorboard logs are consistent.\n",
"    is_first_round = round_number == 1\n",
"    model.learn(total_timesteps=BUFFER_SIZE, reset_num_timesteps=is_first_round)\n",
"    # NOTE(review): presumably torch's Module.eval(), which switches the policy\n",
"    # network to evaluation mode; it does not run an evaluation or print stats.\n",
"    model.policy.eval()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h1lIx3_l24OP"
},
"source": [
"### Close the environment\n",
"Frees up the ports being used."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vdWG6_SqtNtv",
"pycharm": {
"is_executing": true,
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Close the vectorized environment created earlier in this notebook.\n",
"env.close()\n",
"print(\"Closed environment\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Colab-UnityEnvironment-1-Run.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}