{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "!pip install stable-baselines3[extra]\n", "!pip install moviepy" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from stable_baselines3 import DQN\n", "from stable_baselines3.common.monitor import Monitor\n", "from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CallbackList\n", "from stable_baselines3.common.logger import Video, HParam, TensorBoardOutputFormat\n", "from stable_baselines3.common.evaluation import evaluate_policy\n", "\n", "from typing import Any, Dict\n", "\n", "import gymnasium as gym\n", "import torch as th\n", "import numpy as np\n", "\n", "# =====File names=====\n", "MODEL_FILE_NAME = \"ALE-Pacman-v5\"\n", "BUFFER_FILE_NAME = \"dqn_replay_buffer_pacman_v1\"\n", "POLICY_FILE_NAME = \"dqn_policy_pacman_v1\"\n", "\n", "# =====Model Config=====\n", "# Evaluate every 150k steps, i.e. ten times over the 1.5M-step run\n", "EVAL_CALLBACK_FREQ = 150_000\n", "# Record a video every 250k steps, i.e. six times over the run (the recording triggered at the\n", "# very end of training won't be written out, so it has to be done manually). The next cell\n", "# double-checks this schedule.\n", "VIDEO_CALLBACK_FREQ = 250_000\n", "FRAMESKIP = 4\n", "NUM_TIMESTEPS = 1_500_000\n", "\n", "# =====Hyperparams=====\n", "EXPLORATION_FRACTION = 0.3\n", "# Buffer size needs to be at most about 60k for the saved replay buffer to fit in a Kaggle instance\n", "BUFFER_SIZE = 60_000\n", "BATCH_SIZE = 8\n", "LEARNING_STARTS = 50_000\n", "LEARNING_RATE = 0.0002\n", "GAMMA = 0.999\n", "FINAL_EPSILON = 0.1\n", "# target_update_interval defaults to 10k in SB3, which matches the Nature paper's target\n", "# network update frequency (C = 10,000). The value of 4 in the paper is the SGD update\n", "# frequency (SB3's train_freq), a different knob, so there is no real discrepancy.\n", "TARGET_UPDATE_INTERVAL = 1_000" ] },
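{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Quick check of the callback schedule implied by the settings above (a sanity-check sketch,\n", "# not part of the training recipe): EvalCallback fires every EVAL_CALLBACK_FREQ steps and the\n", "# video callback every VIDEO_CALLBACK_FREQ steps, i.e. 10 evaluations and 6 recordings here.\n", "print(f\"evaluations during training: {NUM_TIMESTEPS // EVAL_CALLBACK_FREQ}\")\n", "print(f\"video recordings during training: {NUM_TIMESTEPS // VIDEO_CALLBACK_FREQ}\")" ] },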
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# VideoRecorderCallback\n", "# Records a video of the agent in the evaluation environment every render_freq timesteps,\n", "# one episode per recording. (The final recording at the end of training doesn't get written\n", "# out; see the note on VIDEO_CALLBACK_FREQ above.)\n", "\n", "class VideoRecorderCallback(BaseCallback):\n", "    def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 1, deterministic: bool = True):\n", "        \"\"\"\n", "        Records a video of an agent's trajectory traversing ``eval_env`` and logs it to TensorBoard.\n", "        :param eval_env: A gym environment from which the trajectory is recorded\n", "        :param render_freq: Render the agent's trajectory every render_freq call of the callback.\n", "        :param n_eval_episodes: Number of episodes to render\n", "        :param deterministic: Whether to use deterministic or stochastic policy\n", "        \"\"\"\n", "        super().__init__()\n", "        self._eval_env = eval_env\n", "        self._render_freq = render_freq\n", "        self._n_eval_episodes = n_eval_episodes\n", "        self._deterministic = deterministic\n", "\n", "    def _on_step(self) -> bool:\n", "        if self.n_calls % self._render_freq == 0:\n", "            screens = []\n", "\n", "            def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:\n", "                \"\"\"\n", "                Renders the environment in its current state, recording the screen in the captured `screens` list\n", "                :param _locals: A dictionary containing all local variables of the callback's scope\n", "                :param _globals: A dictionary containing all global variables of the callback's scope\n", "                \"\"\"\n", "                screen = self._eval_env.render()\n", "                # PyTorch uses CxHxW vs HxWxC gym (and TensorFlow) image convention\n", "                screens.append(screen.transpose(2, 0, 1))\n", "\n", "            evaluate_policy(\n", "                self.model,\n", "                self._eval_env,\n", "                callback=grab_screens,\n", "                n_eval_episodes=self._n_eval_episodes,\n", "                deterministic=self._deterministic,\n", "            )\n", "            self.logger.record(\n", "                \"trajectory/video\",\n", "                Video(th.from_numpy(np.array([screens])), fps=60),\n", "                exclude=(\"stdout\", \"log\", \"json\", \"csv\"),\n", "            )\n", "        return True" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# HParamCallback\n", "# Logs the specified hyperparameters at the start of training and maps the logged metrics\n", "# to the appropriate run.\n", "class HParamCallback(BaseCallback):\n", "    \"\"\"\n", "    Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard.\n", "    \"\"\"\n", "    def __init__(self):\n", "        super().__init__()\n", "\n", "    def _on_training_start(self) -> None:\n", "        hparam_dict = {\n", "            \"algorithm\": self.model.__class__.__name__,\n", "            \"policy\": self.model.policy.__class__.__name__,\n", "            \"environment\": self.model.env.__class__.__name__,\n", "            \"buffer_size\": self.model.buffer_size,\n", "            \"batch_size\": self.model.batch_size,\n", "            \"tau\": self.model.tau,\n", "            \"gradient_steps\": self.model.gradient_steps,\n", "            \"target_update_interval\": self.model.target_update_interval,\n", "            \"exploration_fraction\": self.model.exploration_fraction,\n", "            \"exploration_initial_eps\": self.model.exploration_initial_eps,\n", "            \"exploration_final_eps\": self.model.exploration_final_eps,\n", "            \"max_grad_norm\": self.model.max_grad_norm,\n", "            \"tensorboard_log\": self.model.tensorboard_log,\n", "            \"seed\": self.model.seed,\n", "            \"learning_rate\": self.model.learning_rate,\n", "            \"gamma\": self.model.gamma,\n", "        }\n", "        # define the metrics that will appear in the `HPARAMS` TensorBoard tab by referencing their tag\n", "        # TensorBoard will find & display metrics from the `SCALARS` tab\n", "        metric_dict = {\n", "            \"eval/mean_ep_length\": 0,\n", "            \"eval/mean_reward\": 0,\n", "            \"rollout/ep_len_mean\": 0,\n", "            \"rollout/ep_rew_mean\": 0,\n", "            \"rollout/exploration_rate\": 0,\n", "            \"time/_episode_num\": 0,\n", "            \"time/fps\": 0,\n", "            \"time/total_timesteps\": 0,\n", "            \"train/learning_rate\": 0.0,\n", "            \"train/loss\": 0.0,\n", "            \"train/n_updates\": 0.0,\n", "            \"locals/rewards\": 0.0,\n", "            \"locals/infos_0_lives\": 0.0,\n", "            \"locals/num_collected_steps\": 0.0,\n", "            \"locals/num_collected_episodes\": 0.0\n", "        }\n", "\n", "        self.logger.record(\n", "            \"hparams\",\n", "            HParam(hparam_dict, metric_dict),\n", "            exclude=(\"stdout\", \"log\", \"json\", \"csv\"),\n", "        )" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# PlotTensorboardValuesCallback\n", "# This callback logs additional values to TensorBoard on every step.\n", "# Each self.logger.record call adds a new point to the corresponding scalar series.\n", "\n", "class PlotTensorboardValuesCallback(BaseCallback):\n", "    \"\"\"\n", "    Custom callback for plotting additional values in TensorBoard.\n", "    \"\"\"\n", "    def __init__(self, eval_env: gym.Env, train_env: gym.Env, model: DQN, verbose=0):\n", "        super().__init__(verbose)\n", "        self._eval_env = eval_env\n", "        self._train_env = train_env\n", "        self._model = model\n", "\n", "    def _on_training_start(self) -> None:\n", "        output_formats = self.logger.output_formats\n", "        # Save a reference to the TensorBoard formatter object\n", "        # note: if no TensorBoard formatter is found, next() raises StopIteration, which is caught below.\n", "        try:\n", "            self.tb_formatter = next(formatter for formatter in output_formats if isinstance(formatter, TensorBoardOutputFormat))\n", "        except StopIteration:\n", "            print(\"Exception thrown in tb_formatter initialization.\")\n", "\n", "        self.tb_formatter.writer.add_text(\"metadata/eval_env\", str(self._eval_env.metadata), self.num_timesteps)\n", "        self.tb_formatter.writer.flush()\n", "        self.tb_formatter.writer.add_text(\"metadata/train_env\", str(self._train_env.metadata), self.num_timesteps)\n", "        self.tb_formatter.writer.flush()\n", "        self.tb_formatter.writer.add_text(\"model/q_net\", str(self._model.q_net), self.num_timesteps)\n", "        self.tb_formatter.writer.flush()\n", "        self.tb_formatter.writer.add_text(\"model/q_net_target\", str(self._model.q_net_target), self.num_timesteps)\n", "        self.tb_formatter.writer.flush()\n", "\n", "    def _on_step(self) -> bool:\n", "        self.logger.record(\"time/_episode_num\", self.model._episode_num, exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n", "        self.logger.record(\"train/n_updates\", self.model._n_updates, exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n", "        self.logger.record(\"locals/rewards\", self.locals[\"rewards\"], exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n", "        self.logger.record(\"locals/infos_0_lives\", self.locals[\"infos\"][0][\"lives\"], exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n", "        self.logger.record(\"locals/num_collected_steps\", self.locals[\"num_collected_steps\"], exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n", "        self.logger.record(\"locals/num_collected_episodes\", self.locals[\"num_collected_episodes\"], exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n", "\n", "        return True\n", "\n", "    def _on_training_end(self) -> None:\n", "        self.tb_formatter.writer.add_text(\"metadata/eval_env\", str(self._eval_env.metadata), self.num_timesteps)\n", "        self.tb_formatter.writer.flush()\n", "        self.tb_formatter.writer.add_text(\"metadata/train_env\", str(self._train_env.metadata), self.num_timesteps)\n", "        self.tb_formatter.writer.flush()\n", "        self.tb_formatter.writer.add_text(\"model/q_net\", str(self._model.q_net), self.num_timesteps)\n", "        self.tb_formatter.writer.flush()\n", "        self.tb_formatter.writer.add_text(\"model/q_net_target\", str(self._model.q_net_target), self.num_timesteps)\n", "        self.tb_formatter.writer.flush()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# make the training and evaluation environments\n", "eval_env = Monitor(gym.make(\"ALE/Pacman-v5\", render_mode=\"rgb_array\", frameskip=FRAMESKIP))\n", "train_env = gym.make(\"ALE/Pacman-v5\", render_mode=\"rgb_array\", frameskip=FRAMESKIP)\n", "\n", "# Make the model with specified hyperparams\n", "model = DQN(\n", "    \"CnnPolicy\",\n", "    train_env,\n", "    verbose=1,\n", "    buffer_size=BUFFER_SIZE,\n", "    exploration_fraction=EXPLORATION_FRACTION,\n", "    batch_size=BATCH_SIZE,\n", "    exploration_final_eps=FINAL_EPSILON,\n", "    gamma=GAMMA,\n", "    learning_starts=LEARNING_STARTS,\n", "    learning_rate=LEARNING_RATE,\n", "    target_update_interval=TARGET_UPDATE_INTERVAL,\n", "    tensorboard_log=\"./\",\n", ")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Define the callbacks and put them in a list\n", "eval_callback = EvalCallback(\n", "    eval_env,\n", "    best_model_save_path=\"./best_model/\",\n", "    log_path=\"./evals/\",\n", "    eval_freq=EVAL_CALLBACK_FREQ,\n", "    n_eval_episodes=10,\n", "    deterministic=True,\n", "    render=False)\n", "\n", "tbplot_callback = PlotTensorboardValuesCallback(eval_env=eval_env, train_env=train_env, model=model)\n", "video_callback = VideoRecorderCallback(eval_env, render_freq=VIDEO_CALLBACK_FREQ)\n", "hparam_callback = HParamCallback()\n", "\n", "callback_list = CallbackList([hparam_callback, eval_callback, video_callback, tbplot_callback])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Train the model\n", "model.learn(total_timesteps=NUM_TIMESTEPS, callback=callback_list, tb_log_name=\"./tb/\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Save the model, policy, and replay buffer for future loading and training\n", "model.save(MODEL_FILE_NAME)\n", "model.save_replay_buffer(BUFFER_FILE_NAME)\n", "model.policy.save(POLICY_FILE_NAME)" ] }
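, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Resuming later (a sketch, not executed in this run): DQN.load and load_replay_buffer restore\n", "# the artifacts saved above, and reset_num_timesteps=False keeps the timestep count going.\n", "# The RESUME flag and the extra 500_000 steps are illustrative values, not part of this run.\n", "RESUME = False\n", "if RESUME:\n", "    model = DQN.load(MODEL_FILE_NAME, env=train_env)\n", "    model.load_replay_buffer(BUFFER_FILE_NAME)\n", "    model.learn(total_timesteps=500_000, reset_num_timesteps=False)" ] }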
], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 2 }