ledmands committed
Commit ba5687a
1 Parent(s): 68fdd2c

Added notebook structure for training first agent.

Files changed (1)
  1. notebooks/dqn_pacmanv5_run1.ipynb +318 -0
notebooks/dqn_pacmanv5_run1.ipynb ADDED
@@ -0,0 +1,318 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "!pip install stable-baselines3[extra]\n",
+ "!pip install moviepy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from stable_baselines3 import DQN\n",
+ "from stable_baselines3.common.monitor import Monitor\n",
+ "from stable_baselines3.common.callbacks import BaseCallback, EvalCallback, CallbackList\n",
+ "from stable_baselines3.common.logger import Video, HParam, TensorBoardOutputFormat\n",
+ "from stable_baselines3.common.evaluation import evaluate_policy\n",
+ "\n",
+ "from typing import Any, Dict\n",
+ "\n",
+ "import gymnasium as gym\n",
+ "import torch as th\n",
+ "import numpy as np\n",
+ "\n",
+ "# Model Config\n",
+ "EVAL_CALLBACK_FREQ = 1_000\n",
+ "VIDEO_CALLBACK_FREQ = 2_000\n",
+ "FRAMESKIP = 4\n",
+ "NUM_TIMESTEPS = 10_000\n",
+ "\n",
+ "# Hyperparams\n",
+ "EXPLORATION_FRACTION = 0.35\n",
+ "BUFFER_SIZE = 60_000\n",
+ "BATCH_SIZE = 8\n",
+ "LEARNING_STARTS = 1_000\n",
+ "LEARNING_RATE = 0.0001\n",
+ "GAMMA = 0.999\n",
+ "FINAL_EPSILON = 0.1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# VideoRecorderCallback\n",
+ "# The VideoRecorderCallback should record a video of the agent in the evaluation environment\n",
+ "# every render_freq timesteps. It will record one episode. It will also record one episode when\n",
+ "# the training has been completed.\n",
+ "\n",
+ "class VideoRecorderCallback(BaseCallback):\n",
+ "    def __init__(self, eval_env: gym.Env, render_freq: int, n_eval_episodes: int = 1, deterministic: bool = True):\n",
+ "        \"\"\"\n",
+ "        Records a video of an agent's trajectory traversing ``eval_env`` and logs it to TensorBoard.\n",
+ "        :param eval_env: A gym environment from which the trajectory is recorded\n",
+ "        :param render_freq: Render the agent's trajectory every ``render_freq`` calls of the callback.\n",
+ "        :param n_eval_episodes: Number of episodes to render\n",
+ "        :param deterministic: Whether to use deterministic or stochastic policy\n",
+ "        \"\"\"\n",
+ "        super().__init__()\n",
+ "        self._eval_env = eval_env\n",
+ "        self._render_freq = render_freq\n",
+ "        self._n_eval_episodes = n_eval_episodes\n",
+ "        self._deterministic = deterministic\n",
+ "\n",
+ "    def _on_step(self) -> bool:\n",
+ "        if self.n_calls % self._render_freq == 0:\n",
+ "            screens = []\n",
+ "\n",
+ "            def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:\n",
+ "                \"\"\"\n",
+ "                Renders the environment in its current state, recording the screen in the captured `screens` list\n",
+ "                :param _locals: A dictionary containing all local variables of the callback's scope\n",
+ "                :param _globals: A dictionary containing all global variables of the callback's scope\n",
+ "                \"\"\"\n",
+ "                screen = self._eval_env.render()\n",
+ "                # PyTorch uses CxHxW vs HxWxC gym (and tensorflow) image convention\n",
+ "                screens.append(screen.transpose(2, 0, 1))\n",
+ "\n",
+ "            evaluate_policy(\n",
+ "                self.model,\n",
+ "                self._eval_env,\n",
+ "                callback=grab_screens,\n",
+ "                n_eval_episodes=self._n_eval_episodes,\n",
+ "                deterministic=self._deterministic,\n",
+ "            )\n",
+ "            self.logger.record(\n",
+ "                \"trajectory/video\",\n",
+ "                Video(th.from_numpy(np.array([screens])), fps=60),\n",
+ "                exclude=(\"stdout\", \"log\", \"json\", \"csv\"),\n",
+ "            )\n",
+ "        return True"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# HParamCallback\n",
+ "# This should log the hyperparameters specified and map the metrics that are logged to\n",
+ "# the appropriate run.\n",
+ "class HParamCallback(BaseCallback):\n",
+ "    \"\"\"\n",
+ "    Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard.\n",
+ "    \"\"\"\n",
+ "    def __init__(self):\n",
+ "        super().__init__()\n",
+ "\n",
+ "\n",
+ "    def _on_training_start(self) -> None:\n",
+ "\n",
+ "        hparam_dict = {\n",
+ "            \"algorithm\": self.model.__class__.__name__,\n",
+ "            \"policy\": self.model.policy.__class__.__name__,\n",
+ "            \"environment\": self.model.env.__class__.__name__,\n",
+ "            \"buffer_size\": self.model.buffer_size,\n",
+ "            \"batch_size\": self.model.batch_size,\n",
+ "            \"tau\": self.model.tau,\n",
+ "            \"gradient_steps\": self.model.gradient_steps,\n",
+ "            \"target_update_interval\": self.model.target_update_interval,\n",
+ "            \"exploration_fraction\": self.model.exploration_fraction,\n",
+ "            \"exploration_initial_eps\": self.model.exploration_initial_eps,\n",
+ "            \"exploration_final_eps\": self.model.exploration_final_eps,\n",
+ "            \"max_grad_norm\": self.model.max_grad_norm,\n",
+ "            \"tensorboard_log\": self.model.tensorboard_log,\n",
+ "            \"seed\": self.model.seed,\n",
+ "            \"learning_rate\": self.model.learning_rate,\n",
+ "            \"gamma\": self.model.gamma,\n",
+ "        }\n",
+ "        # define the metrics that will appear in the `HPARAMS` TensorBoard tab by referencing their tag\n",
+ "        # TensorBoard will find & display metrics from the `SCALARS` tab\n",
+ "        metric_dict = {\n",
+ "            \"eval/mean_ep_length\": 0,\n",
+ "            \"eval/mean_reward\": 0,\n",
+ "            \"rollout/ep_len_mean\": 0,\n",
+ "            \"rollout/ep_rew_mean\": 0,\n",
+ "            \"rollout/exploration_rate\": 0,\n",
+ "            \"time/_episode_num\": 0,\n",
+ "            \"time/fps\": 0,\n",
+ "            \"time/total_timesteps\": 0,\n",
+ "            \"train/learning_rate\": 0.0,\n",
+ "            \"train/loss\": 0.0,\n",
+ "            \"train/n_updates\": 0.0,\n",
+ "            \"locals/rewards\": 0.0,\n",
+ "            \"locals/infos_0_lives\": 0.0,\n",
+ "            \"locals/num_collected_steps\": 0.0,\n",
+ "            \"locals/num_collected_episodes\": 0.0\n",
+ "        }\n",
+ "\n",
+ "        self.logger.record(\n",
+ "            \"hparams\",\n",
+ "            HParam(hparam_dict, metric_dict),\n",
+ "            exclude=(\"stdout\", \"log\", \"json\", \"csv\"),\n",
+ "        )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# PlotTensorboardValuesCallback\n",
+ "# This callback should log values to tensorboard on every step.\n",
+ "# The self.logger class should plot a new scalar value when recording.\n",
+ "\n",
+ "class PlotTensorboardValuesCallback(BaseCallback):\n",
+ "    \"\"\"\n",
+ "    Custom callback for plotting additional values in tensorboard.\n",
+ "    \"\"\"\n",
+ "    def __init__(self, eval_env: gym.Env, train_env: gym.Env, model: DQN, verbose=0):\n",
+ "        super().__init__(verbose)\n",
+ "        self._eval_env = eval_env\n",
+ "        self._train_env = train_env\n",
+ "        self._model = model\n",
+ "\n",
+ "    def _on_training_start(self) -> None:\n",
+ "        output_formats = self.logger.output_formats\n",
+ "        # Save reference to the TensorBoard formatter object\n",
+ "        # note: if no TensorBoard formatter is found, next() raises StopIteration, caught below\n",
+ "        try:\n",
+ "            self.tb_formatter = next(formatter for formatter in output_formats if isinstance(formatter, TensorBoardOutputFormat))\n",
+ "        except StopIteration:\n",
+ "            print(\"Exception thrown in tb_formatter initialization.\")\n",
+ "\n",
+ "        self.tb_formatter.writer.add_text(\"metadata/eval_env\", str(self._eval_env.metadata), self.num_timesteps)\n",
+ "        self.tb_formatter.writer.flush()\n",
+ "        self.tb_formatter.writer.add_text(\"metadata/train_env\", str(self._train_env.metadata), self.num_timesteps)\n",
+ "        self.tb_formatter.writer.flush()\n",
+ "        self.tb_formatter.writer.add_text(\"model/q_net\", str(self._model.q_net), self.num_timesteps)\n",
+ "        self.tb_formatter.writer.flush()\n",
+ "        self.tb_formatter.writer.add_text(\"model/q_net_target\", str(self._model.q_net_target), self.num_timesteps)\n",
+ "        self.tb_formatter.writer.flush()\n",
+ "\n",
+ "    def _on_step(self) -> bool:\n",
+ "        self.logger.record(\"time/_episode_num\", self.model._episode_num, exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n",
+ "        self.logger.record(\"train/n_updates\", self.model._n_updates, exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n",
+ "        self.logger.record(\"locals/rewards\", self.locals[\"rewards\"], exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n",
+ "        self.logger.record(\"locals/infos_0_lives\", self.locals[\"infos\"][0][\"lives\"], exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n",
+ "        self.logger.record(\"locals/num_collected_steps\", self.locals[\"num_collected_steps\"], exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n",
+ "        self.logger.record(\"locals/num_collected_episodes\", self.locals[\"num_collected_episodes\"], exclude=(\"stdout\", \"log\", \"json\", \"csv\"))\n",
+ "\n",
+ "        return True\n",
+ "\n",
+ "    def _on_training_end(self) -> None:\n",
+ "        self.tb_formatter.writer.add_text(\"metadata/eval_env\", str(self._eval_env.metadata), self.num_timesteps)\n",
+ "        self.tb_formatter.writer.flush()\n",
+ "        self.tb_formatter.writer.add_text(\"metadata/train_env\", str(self._train_env.metadata), self.num_timesteps)\n",
+ "        self.tb_formatter.writer.flush()\n",
+ "        self.tb_formatter.writer.add_text(\"model/q_net\", str(self._model.q_net), self.num_timesteps)\n",
+ "        self.tb_formatter.writer.flush()\n",
+ "        self.tb_formatter.writer.add_text(\"model/q_net_target\", str(self._model.q_net_target), self.num_timesteps)\n",
+ "        self.tb_formatter.writer.flush()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# make the training and evaluation environments\n",
+ "eval_env = Monitor(gym.make(\"ALE/Pacman-v5\", render_mode=\"rgb_array\", frameskip=FRAMESKIP))\n",
+ "train_env = gym.make(\"ALE/Pacman-v5\", render_mode=\"rgb_array\", frameskip=FRAMESKIP)\n",
+ "\n",
+ "# Make the model with specified hyperparams\n",
+ "model = DQN(\n",
+ "    \"CnnPolicy\",\n",
+ "    train_env,\n",
+ "    verbose=1,\n",
+ "    buffer_size=BUFFER_SIZE,\n",
+ "    exploration_fraction=EXPLORATION_FRACTION,\n",
+ "    batch_size=BATCH_SIZE,\n",
+ "    exploration_final_eps=FINAL_EPSILON,\n",
+ "    gamma=GAMMA,\n",
+ "    learning_starts=LEARNING_STARTS,\n",
+ "    learning_rate=LEARNING_RATE,\n",
+ "    tensorboard_log=\"./\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define the callbacks and put them in a list\n",
+ "eval_callback = EvalCallback(\n",
+ "    eval_env,\n",
+ "    best_model_save_path=\"./best_model/\",\n",
+ "    log_path=\"./evals/\",\n",
+ "    eval_freq=EVAL_CALLBACK_FREQ,\n",
+ "    n_eval_episodes=10,\n",
+ "    deterministic=True,\n",
+ "    render=False)\n",
+ "\n",
+ "tbplot_callback = PlotTensorboardValuesCallback(eval_env=eval_env, train_env=train_env, model=model)\n",
+ "video_callback = VideoRecorderCallback(eval_env, render_freq=VIDEO_CALLBACK_FREQ)\n",
+ "hparam_callback = HParamCallback()\n",
+ "\n",
+ "callback_list = CallbackList([hparam_callback, eval_callback, video_callback, tbplot_callback])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Train the model\n",
+ "# model.learn(total_timesteps=NUM_TIMESTEPS, callback=callback_list, tb_log_name=\"./tb/\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Save the model, policy, and replay buffer for future loading and training\n",
+ "model.save(\"ALE-Pacman-v5\")\n",
+ "model.save_replay_buffer(\"dqn_replay_buffer_pacman\")\n",
+ "model.policy.save(\"dqn_policy_pacman\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+ }
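
The save cell above writes three artifacts "for future loading and training". A minimal sketch of how a later session might restore them, assuming the same file names and a `train_env`/`callback_list` set up as in this notebook; `loaded_model` is an illustrative name, and this is one way to resume rather than the notebook's prescribed workflow:

    from stable_baselines3 import DQN

    # Restore the trained model and re-attach an environment for further training
    loaded_model = DQN.load("ALE-Pacman-v5", env=train_env)
    # Restore the saved replay buffer so training can resume without refilling it
    loaded_model.load_replay_buffer("dqn_replay_buffer_pacman")
    # Continue training; reset_num_timesteps=False keeps the timestep counter running
    loaded_model.learn(total_timesteps=NUM_TIMESTEPS, callback=callback_list, reset_num_timesteps=False)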