# ML-Agents run with Stable Baselines 3


## Setup

In [None]:
#@title Install Rendering Dependencies { display-mode: "form" }
#@markdown (You only need to run this code when using Colab's hosted runtime)

import os
from IPython.display import HTML, display

def progress(value, max=100):
    """Build an HTML progress bar for display in the notebook.

    Args:
        value: Current progress amount.
        max: Amount that counts as complete. The name shadows the builtin
            but is kept so existing callers are unaffected.

    Returns:
        An ``HTML`` display object wrapping a ``<progress>`` element.
    """
    # The template has to contain the {value}/{max} placeholders; with an
    # empty template the .format() call is a no-op and nothing is rendered.
    return HTML("""
        <progress
            value='{value}'
            max='{max}'
            style='width: 25%'
        >
            {value}
        </progress>
    """.format(value=value, max=max))

# Create a display handle (display_id=True) so the same progress bar can be
# updated in place by the pro_bar.update(...) calls below.
pro_bar = display(progress(0, 100), display_id=True)

# Install Xvfb when google.colab is importable, or when the
# COLAB_ALWAYS_INSTALL_XVFB environment variable is set.
try:
 import google.colab
 INSTALL_XVFB = True
except ImportError:
 INSTALL_XVFB = 'COLAB_ALWAYS_INSTALL_XVFB' in os.environ

if INSTALL_XVFB:
 # Write a start/stop/restart helper script for Xvfb to ./frame-buffer.
 with open('frame-buffer', 'w') as writefile:
 writefile.write("""#taken from https://gist.github.com/jterrace/2911875
XVFB=/usr/bin/Xvfb
XVFBARGS=":1 -screen 0 1024x768x24 -ac +extension GLX +render -noreset"
PIDFILE=./frame-buffer.pid
case "$1" in
 start)
 echo -n "Starting virtual X frame buffer: Xvfb"
 /sbin/start-stop-daemon --start --quiet --pidfile $PIDFILE --make-pidfile --background --exec $XVFB -- $XVFBARGS
 echo "."
 ;;
 stop)
 echo -n "Stopping virtual X frame buffer: Xvfb"
 /sbin/start-stop-daemon --stop --quiet --pidfile $PIDFILE
 rm $PIDFILE
 echo "."
 ;;
 restart)
 $0 stop
 $0 start
 ;;
 *)
 echo "Usage: /etc/init.d/xvfb {start|stop|restart}"
 exit 1
esac
exit 0
 """)
 # Install system packages, then download and install the pinned libxfont1
 # and xvfb .deb files, bumping the progress bar after each step.
 !sudo apt-get update
 pro_bar.update(progress(10, 100))
 !sudo DEBIAN_FRONTEND=noninteractive apt install -y daemon wget gdebi-core build-essential libfontenc1 libfreetype6 xorg-dev xorg
 pro_bar.update(progress(20, 100))
 !wget http://security.ubuntu.com/ubuntu/pool/main/libx/libxfont/libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb 2>&1
 pro_bar.update(progress(30, 100))
 !wget --output-document xvfb.deb http://security.ubuntu.com/ubuntu/pool/universe/x/xorg-server/xvfb_1.18.4-0ubuntu0.12_amd64.deb 2>&1
 pro_bar.update(progress(40, 100))
 !sudo dpkg -i libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb 2>&1
 pro_bar.update(progress(50, 100))
 !sudo dpkg -i xvfb.deb 2>&1
 pro_bar.update(progress(70, 100))
 # Remove the downloaded packages once they are installed.
 !rm libxfont1_1.5.1-1ubuntu0.16.04.4_amd64.deb
 pro_bar.update(progress(80, 100))
 !rm xvfb.deb
 pro_bar.update(progress(90, 100))
 # Start the frame buffer via the helper script and set DISPLAY to ":1",
 # the display number used in XVFBARGS above.
 !bash frame-buffer start
 os.environ["DISPLAY"] = ":1"
# Runs on every path, including when no installation was needed.
pro_bar.update(progress(100, 100))

### Installing ml-agents

In [None]:
# Install ml-agents only when it cannot be imported yet.
try:
 import mlagents
 print("ml-agents already installed")
except ImportError:
 # Pinned to version 0.30.0; -q keeps pip's output quiet.
 !python -m pip install -q mlagents==0.30.0
 # NOTE(review): printed even if the pip command fails; its exit status is not checked.
 print("Installed ml-agents")

## Run the Environment

### Import dependencies and set some high level parameters.

In [None]:
from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Any

import gym
from gym import Env

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecMonitor, VecEnv, SubprocVecEnv
from supersuit import observation_lambda_v0


from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.envs.unity_gym_env import UnityToGymWrapper
from mlagents_envs.registry import UnityEnvRegistry, default_registry
from mlagents_envs.side_channel.engine_configuration_channel import (
    EngineConfig,
    EngineConfigurationChannel,
)

# Number of Unity environment workers to create (passed as num_env below);
# also used to derive the per-environment rollout length for PPO.
NUM_ENVS = 8

### Environment and Engine Configurations

In [None]:
# Default values from CLI (See cli_utils.py)
# Used as the default for LimitedConfig.engine_config and applied to each
# Unity instance through an EngineConfigurationChannel when it is created.
DEFAULT_ENGINE_CONFIG = EngineConfig(
 width=84,
 height=84,
 quality_level=4,
 time_scale=20,
 target_frame_rate=-1,
 capture_frame_rate=60,
)

# Some config subset of an actual config.yaml file for MLA.
@dataclass
class LimitedConfig:
    """Settings consumed by ``make_mla_sb3_env`` to build the vectorized env."""

    # The local path to a Unity executable or the name of an entry in the registry.
    env_path_or_name: str
    # Base port passed to every Unity instance, together with its worker id.
    base_port: int
    # Worker ``i`` is seeded with ``base_seed + i``.
    base_seed: int = 0
    # Number of Unity environment workers to create.
    num_env: int = 1
    # Engine settings sent to every worker through an EngineConfigurationChannel.
    engine_config: EngineConfig = DEFAULT_ENGINE_CONFIG
    # Forwarded to UnityToGymWrapper as ``uint8_visual``.
    visual_obs: bool = False
    # TODO: Decide if we should just tell users to always use MultiInputPolicy so we can simplify the user workflow.
    # WARNING: Make sure to use MultiInputPolicy if you turn this on.
    allow_multiple_obs: bool = False
    # Registry used to resolve ``env_path_or_name`` when it is not a local path.
    # A default_factory is used instead of a plain default because Python 3.11+
    # dataclasses refuse unhashable default values at class-creation time.
    # NOTE(review): assumes the registry object is unhashable (mapping-like);
    # the factory is harmless either way and still yields the shared
    # ``default_registry`` instance.
    env_registry: UnityEnvRegistry = field(default_factory=lambda: default_registry)

### Unity Environment SB3 Factory

In [None]:
def _unity_env_from_path_or_registry(
    env: str, registry: UnityEnvRegistry, **kwargs: Any
) -> UnityEnvironment:
    """Create a Unity environment from a local build or a registry entry.

    ``env`` is first treated as a filesystem path (``~`` expanded and made
    absolute). If nothing exists there, it is looked up in ``registry``.
    Extra keyword arguments are forwarded unchanged to whichever
    constructor ends up being used.

    Raises:
        ValueError: If ``env`` is neither an existing path nor a registry key.
    """
    local_build = Path(env).expanduser().absolute()
    if local_build.exists():
        # The original string, not the expanded path, is handed to Unity.
        return UnityEnvironment(file_name=env, **kwargs)
    if env not in registry:
        raise ValueError(f"Environment '{env}' wasn't a local path or registry entry")
    return registry.get(env).make(**kwargs)
 
def make_mla_sb3_env(config: LimitedConfig, **kwargs: Any) -> VecEnv:
    """Build a Stable Baselines 3 vectorized env backed by Unity workers.

    One environment factory is created per worker (``config.num_env`` in
    total) and the factories are handed to ``SubprocVecEnv``. Each factory
    starts a Unity environment, wraps it with ``UnityToGymWrapper`` and
    adapts tuple observations so Stable Baselines can consume them.

    Args:
        config: Environment location, port/seed bases, worker count and
            engine settings.
        **kwargs: Extra keyword arguments forwarded to the Unity environment
            constructor. An optional ``side_channels`` iterable is extended
            with an engine configuration channel for each worker.

    Returns:
        A ``SubprocVecEnv`` built from ``config.num_env`` factories.
    """

    def handle_obs(obs, space):
        """Convert an observation to match ``handle_obs_space``."""
        if isinstance(space, gym.spaces.Tuple):
            if len(space) == 1:
                return obs[0]
            # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).
            return {str(i): v for i, v in enumerate(obs)}
        return obs

    def handle_obs_space(space):
        """Unwrap a 1-tuple space, or turn a longer tuple space into a Dict."""
        if isinstance(space, gym.spaces.Tuple):
            if len(space) == 1:
                return space[0]
            # Turn the tuple into a dict (stable baselines can handle spaces.Dict but not spaces.Tuple).
            return gym.spaces.Dict({str(i): v for i, v in enumerate(space)})
        return space

    def create_env(env: str, worker_id: int) -> Callable[[], Env]:
        """Return a zero-argument factory for the worker with ``worker_id``."""

        def _f() -> Env:
            engine_configuration_channel = EngineConfigurationChannel()
            engine_configuration_channel.set_configuration(config.engine_config)
            # Build a per-call copy instead of writing back into the shared
            # ``kwargs``: mutating it would add one more engine channel every
            # time a factory runs in the same process.
            env_kwargs = dict(kwargs)
            env_kwargs["side_channels"] = [
                *(kwargs.get("side_channels") or []),
                engine_configuration_channel,
            ]
            unity_env = _unity_env_from_path_or_registry(
                env=env,
                registry=config.env_registry,
                worker_id=worker_id,
                base_port=config.base_port,
                # Give every worker its own seed.
                seed=config.base_seed + worker_id,
                **env_kwargs,
            )
            new_env = UnityToGymWrapper(
                unity_env=unity_env,
                uint8_visual=config.visual_obs,
                allow_multiple_obs=config.allow_multiple_obs,
            )
            new_env = observation_lambda_v0(new_env, handle_obs, handle_obs_space)
            return new_env

        return _f

    env_facts = [
        create_env(config.env_path_or_name, worker_id=x) for x in range(config.num_env)
    ]
    return SubprocVecEnv(env_facts)

### Start Environment from the registry

In [None]:
# -----------------
# Close an environment that might be left over from a previous run of this
# cell before creating a new one.
try:
    env.close()
except NameError:
    # First run: no environment has been created yet.
    pass
except Exception as close_error:
    # Best effort: a failed close must not block creating a new environment,
    # but report it rather than swallowing it silently.
    print(f"Ignoring error while closing previous environment: {close_error}")
# -----------------

env = make_mla_sb3_env(
    config=LimitedConfig(
        env_path_or_name='Basic',  # Can use any name from a registry or a path to your own unity build.
        base_port=6006,
        base_seed=42,
        num_env=NUM_ENVS,
        allow_multiple_obs=True,
    ),
    no_graphics=True,  # Set to false if you are running locally and want to watch the environments move around as they train.
)

### Create the model

In [None]:
# 250K should train to a reward ~= 0.90 for the "Basic" environment.
# We set the value lower here to demonstrate just a small amount of training.
BATCH_SIZE = 32
BUFFER_SIZE = 256
UPDATES = 50
TOTAL_TAINING_STEPS_GOAL = BUFFER_SIZE * UPDATES
# Correctly spelled alias; the misspelled name above is kept for existing references.
TOTAL_TRAINING_STEPS_GOAL = TOTAL_TAINING_STEPS_GOAL
BETA = 0.0005
N_EPOCHS = 3
# Steps collected per environment for each update. Integer division keeps the
# value an int, which is what PPO's n_steps receives below.
STEPS_PER_UPDATE = BUFFER_SIZE // NUM_ENVS

# Helps gather stats for our eval() calls later so we can see reward stats.
env = VecMonitor(env)

# Policy and value function with 2 layers of 32 units each and no shared layers.
policy_kwargs = {"net_arch": [{"pi": [32, 32], "vf": [32, 32]}]}

# The env factory turns multi-observation tuple spaces into spaces.Dict, which
# needs MultiInputPolicy; any other observation space keeps MlpPolicy.
policy_name = (
    "MultiInputPolicy"
    if isinstance(env.observation_space, gym.spaces.Dict)
    else "MlpPolicy"
)

model = PPO(
    policy_name,
    env,
    verbose=1,
    # NOTE(review): SB3 calls schedule functions with the *remaining* progress
    # (1.0 at the start, 0.0 at the end). Confirm these ramps are intended in
    # combination with the per-round learn() calls used for training.
    learning_rate=lambda progress: 0.0003 * (1.0 - progress),
    clip_range=lambda progress: 0.2 * (1.0 - progress),
    clip_range_vf=lambda progress: 0.2 * (1.0 - progress),
    # Uncomment this if you want to log tensorboard results when running this notebook locally.
    # tensorboard_log="results",
    policy_kwargs=policy_kwargs,
    n_steps=STEPS_PER_UPDATE,
    batch_size=BATCH_SIZE,
    n_epochs=N_EPOCHS,
    ent_coef=BETA,
)

### Train the model

In [None]:
# 0.93 is considered solved for the Basic environment
# Training is split into UPDATES rounds of BUFFER_SIZE timesteps each.
for i in range(UPDATES):
 print(f"Training round {i + 1}/{UPDATES}")
 # NOTE: reset_num_timesteps should only be True the first time so that tensorboard logs are consistent.
 model.learn(total_timesteps=BUFFER_SIZE, reset_num_timesteps=(i == 0))
 # Calls eval() on the policy after each round (presumably switching the
 # network to evaluation mode; confirm against the policy implementation).
 model.policy.eval()

### Close the environment
Frees up the ports being used.

In [None]:
# Close the (monitored) vectorized environment created earlier in the notebook.
env.close()
print("Closed environment")
