{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "pbVXrmEsLXDt" }, "source": [ "# ML-Agents run with Stable Baselines 3\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "WNKTwHU3d2-l" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "#@title Install Rendering Dependencies { display-mode: \"form\" }\n", "#@markdown (You only need to run this code when using Colab's hosted runtime)\n", "\n", "import os\n", "from IPython.display import HTML, display\n", "\n", "def progress(value, max=100):\n", " return HTML(\"\"\"\n", " \n", " {value}\n", " \n", " \"\"\".format(value=value, max=max))\n", "\n", "pro_bar = display(progress(0, 100), display_id=True)\n", "\n", "try:\n", " import google.colab\n", " INSTALL_XVFB = True\n", "except ImportError:\n", " INSTALL_XVFB = 'COLAB_ALWAYS_INSTALL_XVFB' in os.environ\n", "\n", "if INSTALL_XVFB:\n", " with open('frame-buffer', 'w') as writefile:\n", " writefile.write(\"\"\"#taken from https://gist.github.com/jterrace/2911875\n", "XVFB=/usr/bin/Xvfb\n", "XVFBARGS=\":1 -screen 0 1024x768x24 -ac +extension GLX +render -noreset\"\n", "PIDFILE=./frame-buffer.pid\n", "case \"$1\" in\n", " start)\n", " echo -n \"Starting virtual X frame buffer: Xvfb\"\n", " /sbin/start-stop-daemon --start --quiet --pidfile $PIDFILE --make-pidfile --background --exec $XVFB -- $XVFBARGS\n", " echo \".\"\n", " ;;\n", " stop)\n", " echo -n \"Stopping virtual X frame buffer: Xvfb\"\n", " /sbin/start-stop-daemon --stop --quiet --pidfile $PIDFILE\n", " rm $PIDFILE\n", " echo \".\"\n", " ;;\n", " restart)\n", " $0 stop\n", " $0 start\n", " ;;\n", " *)\n", " echo \"Usage: /etc/init.d/xvfb {start|stop|restart}\"\n", " exit 1\n", "esac\n", "exit 0\n", " \"\"\")\n", " !sudo apt-get update\n", " pro_bar.update(progress(10, 100))\n", " !sudo DEBIAN_FRONTEND=noninteractive apt install -y daemon wget gdebi-core build-essential libfontenc1 libfreetype6 
xorg-dev xorg\n", " pro_bar.update(progress(20, 100))\n", " !wget http://security.ubuntu.com/ubuntu/pool/main/libx/libxfont/libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb 2>&1\n", " pro_bar.update(progress(30, 100))\n", " !wget --output-document xvfb.deb http://security.ubuntu.com/ubuntu/pool/universe/x/xorg-server/xvfb_1.18.4-0ubuntu0.12_amd64.deb 2>&1\n", " pro_bar.update(progress(40, 100))\n", " !sudo dpkg -i libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb 2>&1\n", " pro_bar.update(progress(50, 100))\n", " !sudo dpkg -i xvfb.deb 2>&1\n", " pro_bar.update(progress(70, 100))\n", " !rm libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb\n", " pro_bar.update(progress(80, 100))\n", " !rm xvfb.deb\n", " pro_bar.update(progress(90, 100))\n", " !bash frame-buffer start\n", " os.environ[\"DISPLAY\"] = \":1\"\n", "pro_bar.update(progress(100, 100))" ] }, { "cell_type": "markdown", "metadata": { "id": "Pzj7wgapAcDs" }, "source": [ "### Installing ml-agents" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N8yfQqkbebQ5", "pycharm": { "is_executing": true } }, "outputs": [], "source": [ "try:\n", " import mlagents\n", " print(\"ml-agents already installed\")\n", "except ImportError:\n", " !python -m pip install -q mlagents==0.30.0\n", " print(\"Installed ml-agents\")" ] }, { "cell_type": "markdown", "metadata": { "id": "_u74YhSmW6gD" }, "source": [ "## Run the Environment" ] }, { "cell_type": "markdown", "metadata": { "id": "P-r_cB2rqp5x" }, "source": [ "### Import dependencies and set some high level parameters." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YSf-WhxbqtLw" }, "outputs": [], "source": [ "from dataclasses import dataclass\n", "from pathlib import Path\n", "from typing import Callable, Any\n", "\n", "import gym\n", "from gym import Env\n", "\n", "from stable_baselines3 import PPO\n", "from stable_baselines3.common.vec_env import VecMonitor, VecEnv, SubprocVecEnv\n", "from supersuit import observation_lambda_v0\n", "\n", "\n", "from mlagents_envs.environment import UnityEnvironment\n", "from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper\n", "from mlagents_envs.registry import UnityEnvRegistry, default_registry\n", "from mlagents_envs.side_channel.engine_configuration_channel import (\n", " EngineConfig,\n", " EngineConfigurationChannel,\n", ")\n", "\n", "NUM_ENVS = 8" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Environment and Engine Configurations" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Default values from CLI (See cli_utils.py)\n", "DEFAULT_ENGINE_CONFIG = EngineConfig(\n", " width=84,\n", " height=84,\n", " quality_level=4,\n", " time_scale=20,\n", " target_frame_rate=-1,\n", " capture_frame_rate=60,\n", ")\n", "\n", "# Some config subset of an actual config.yaml file for MLA.\n", "@dataclass\n", "class LimitedConfig:\n", " # The local path to a Unity executable or the name of an entry in the registry.\n", " env_path_or_name: str\n", " base_port: int\n", " base_seed: int = 0\n", " num_env: int = 1\n", " engine_config: EngineConfig = DEFAULT_ENGINE_CONFIG\n", " visual_obs: bool = False\n", " # TODO: Decide if we should just tell users to always use MultiInputPolicy so we can simplify the user workflow.\n", " # WARNING: Make sure to use MultiInputPolicy if you turn this on.\n", " allow_multiple_obs: bool = False\n", " env_registry: UnityEnvRegistry = default_registry" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Unity 
Environment SB3 Factory" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _unity_env_from_path_or_registry(\n", " env: str, registry: UnityEnvRegistry, **kwargs: Any\n", ") -> UnityEnvironment:\n", " if Path(env).expanduser().absolute().exists():\n", " return UnityEnvironment(file_name=env, **kwargs)\n", " elif env in registry:\n", " return registry.get(env).make(**kwargs)\n", " else:\n", " raise ValueError(f\"Environment '{env}' wasn't a local path or registry entry\")\n", " \n", "def make_mla_sb3_env(config: LimitedConfig, **kwargs: Any) -> VecEnv:\n", " def handle_obs(obs, space):\n", " if isinstance(space, gym.spaces.Tuple):\n", " if len(space) == 1:\n", " return obs[0]\n", " # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).\n", " return {str(i): v for i, v in enumerate(obs)}\n", " return obs\n", "\n", " def handle_obs_space(space):\n", " if isinstance(space, gym.spaces.Tuple):\n", " if len(space) == 1:\n", " return space[0]\n", " # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).\n", " return gym.spaces.Dict({str(i): v for i, v in enumerate(space)})\n", " return space\n", "\n", " def create_env(env: str, worker_id: int) -> Callable[[], Env]:\n", " def _f() -> Env:\n", " engine_configuration_channel = EngineConfigurationChannel()\n", " engine_configuration_channel.set_configuration(config.engine_config)\n", " kwargs[\"side_channels\"] = kwargs.get(\"side_channels\", []) + [\n", " engine_configuration_channel\n", " ]\n", " unity_env = _unity_env_from_path_or_registry(\n", " env=env,\n", " registry=config.env_registry,\n", " worker_id=worker_id,\n", " base_port=config.base_port,\n", " seed=config.base_seed + worker_id,\n", " **kwargs,\n", " )\n", " new_env = UnityToGymWrapper(\n", " unity_env=unity_env,\n", " uint8_visual=config.visual_obs,\n", " allow_multiple_obs=config.allow_multiple_obs,\n", " )\n", " new_env = 
observation_lambda_v0(new_env, handle_obs, handle_obs_space)\n", " return new_env\n", "\n", " return _f\n", "\n", " env_facts = [\n", " create_env(config.env_path_or_name, worker_id=x) for x in range(config.num_env)\n", " ]\n", " return SubprocVecEnv(env_facts)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Start Environment from the registry" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "is_executing": true, "name": "#%%\n" } }, "outputs": [], "source": [ "# -----------------\n", "# This code is used to close an env that might not have been closed before\n", "try:\n", " env.close()\n", "except:\n", " pass\n", "# -----------------\n", "\n", "env = make_mla_sb3_env(\n", " config=LimitedConfig(\n", " env_path_or_name='Basic', # Can use any name from a registry or a path to your own unity build.\n", " base_port=6006,\n", " base_seed=42,\n", " num_env=NUM_ENVS,\n", " allow_multiple_obs=True,\n", " ),\n", " no_graphics=True, # Set to false if you are running locally and want to watch the environments move around as they train.\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "is_executing": true, "name": "#%%\n" } }, "outputs": [], "source": [ "# 250K should train to a reward ~= 0.90 for the \"Basic\" environment.\n", "# We set the value lower here to demonstrate just a small amount of training.\n", "BATCH_SIZE = 32\n", "BUFFER_SIZE = 256\n", "UPDATES = 50\n", "TOTAL_TAINING_STEPS_GOAL = BUFFER_SIZE * UPDATES\n", "BETA = 0.0005\n", "N_EPOCHS = 3 \n", "STEPS_PER_UPDATE = BUFFER_SIZE / NUM_ENVS\n", "\n", "# Helps gather stats for our eval() calls later so we can see reward stats.\n", "env = VecMonitor(env)\n", "\n", "# Policy and Value function with 2 layers of 32 units each and no shared layers.\n", "policy_kwargs = {\"net_arch\" : [{\"pi\": [32,32], \"vf\": [32,32]}]}\n", "\n", "model = PPO(\n", " 
\"MlpPolicy\",\n", " env,\n", " verbose=1,\n", " learning_rate=lambda progress: 0.0003 * (1.0 - progress),\n", " clip_range=lambda progress: 0.2 * (1.0 - progress),\n", " clip_range_vf=lambda progress: 0.2 * (1.0 - progress),\n", " # Uncomment this if you want to log tensorboard results when running this notebook locally.\n", " # tensorboard_log=\"results\",\n", " policy_kwargs=policy_kwargs,\n", " n_steps=int(STEPS_PER_UPDATE),\n", " batch_size=BATCH_SIZE,\n", " n_epochs=N_EPOCHS,\n", " ent_coef=BETA,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "pycharm": { "is_executing": true, "name": "#%%\n" } }, "outputs": [], "source": [ "# 0.93 is considered solved for the Basic environment\n", "for i in range(UPDATES):\n", " print(f\"Training round {i + 1}/{UPDATES}\")\n", " # NOTE: reset_num_timesteps should only be True on the first round so that tensorboard logs are consistent.\n", " model.learn(total_timesteps=BUFFER_SIZE, reset_num_timesteps=(i == 0))\n", " model.policy.eval()" ] }, { "cell_type": "markdown", "metadata": { "id": "h1lIx3_l24OP" }, "source": [ "### Close the environment\n", "Frees up the ports being used." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vdWG6_SqtNtv", "pycharm": { "is_executing": true, "name": "#%%\n" } }, "outputs": [], "source": [ "env.close()\n", "print(\"Closed environment\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "colab": { "collapsed_sections": [], "name": "Colab-UnityEnvironment-1-Run.ipynb", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.8" } }, "nbformat": 4, "nbformat_minor": 4 }