Commit a3d1c3a
ledmands committed
1 Parent(s): ba5687a

Updated dqn_pacmanv5_run1.ipynb to final version prior to training.
notebooks/dqn_pacmanv5_run1.ipynb CHANGED
@@ -29,20 +29,32 @@
     "import torch as th\n",
     "import numpy as np\n",
     "\n",
-    "# …
-    "…
-    "…
+    "# =====File names=====\n",
+    "MODEL_FILE_NAME = \"ALE-Pacman-v5\"\n",
+    "BUFFER_FILE_NAME = \"dqn_replay_buffer_pacman_v1\"\n",
+    "POLICY_FILE_NAME = \"dqn_policy_pacman_v1\"\n",
+    "\n",
+    "# =====Model Config=====\n",
+    "# Evaluate in tenths\n",
+    "EVAL_CALLBACK_FREQ = 150_000\n",
+    "# Record in quarters (the last one won't record, will have to do manually)\n",
+    "VIDEO_CALLBACK_FREQ = 250_000\n",
     "FRAMESKIP = 4\n",
-    "NUM_TIMESTEPS = …
+    "NUM_TIMESTEPS = 1_500_000\n",
     "\n",
-    "# Hyperparams …
-    "EXPLORATION_FRACTION = 0.…
+    "# =====Hyperparams=====\n",
+    "EXPLORATION_FRACTION = 0.3\n",
+    "# Buffer size needs to be less than about 60k in order to save it in a Kaggle instance\n",
     "BUFFER_SIZE = 60_000\n",
     "BATCH_SIZE = 8\n",
-    "LEARNING_STARTS = …
-    "LEARNING_RATE = 0.…
+    "LEARNING_STARTS = 50_000\n",
+    "LEARNING_RATE = 0.0002\n",
     "GAMMA = 0.999\n",
-    "FINAL_EPSILON = 0.1"
+    "FINAL_EPSILON = 0.1\n",
+    "# Target Update Interval is set to 10k by default and looks like it is set to \n",
+    "# 4 in the Nature paper. This is a large discrepency and makes me wonder if it \n",
+    "# is something different or measured differently...\n",
+    "TARGET_UPDATE_INTERVAL = 1_000"
    ]
   },
   {
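Taken together, the new constants pin down the exploration schedule: with stable-baselines3's linear annealing (and its default initial epsilon of 1.0, which this commit does not override), epsilon decays to FINAL_EPSILON over EXPLORATION_FRACTION of the run. A quick back-of-the-envelope check, with the relevant values restated so it runs on its own:

# Rough check of the exploration schedule implied by the config cell above.
# Assumes stable-baselines3's linear schedule and its default initial eps of 1.0.
NUM_TIMESTEPS = 1_500_000
EXPLORATION_FRACTION = 0.3
FINAL_EPSILON = 0.1

decay_steps = int(EXPLORATION_FRACTION * NUM_TIMESTEPS)  # 450_000 annealing steps
print(f"epsilon: 1.0 -> {FINAL_EPSILON} over the first {decay_steps:,} timesteps")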
@@ -245,7 +257,9 @@
     " gamma=GAMMA,\n",
     " learning_starts=LEARNING_STARTS,\n",
     " learning_rate=LEARNING_RATE,\n",
-    "…
+    " target_update_interval=TARGET_UPDATE_INTERVAL,\n",
+    " tensorboard_log=\"./\",\n",
+    " )"
    ]
   },
   {
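The hunk above shows only the tail end of the model constructor. A minimal sketch of how the config-cell constants could plug into a stable-baselines3 DQN, offered purely for orientation: the environment id, wrappers, and "CnnPolicy" choice are assumptions and appear nowhere in this commit.

# Sketch only: env id and policy are assumptions; the commit shows just the
# keyword arguments from gamma onward. Constants come from the config cell above.
import gymnasium as gym
from stable_baselines3 import DQN

env = gym.make("ALE/Pacman-v5", frameskip=FRAMESKIP)  # assumed from MODEL_FILE_NAME

model = DQN(
    "CnnPolicy",
    env,
    buffer_size=BUFFER_SIZE,
    batch_size=BATCH_SIZE,
    exploration_fraction=EXPLORATION_FRACTION,
    exploration_final_eps=FINAL_EPSILON,
    gamma=GAMMA,
    learning_starts=LEARNING_STARTS,
    learning_rate=LEARNING_RATE,
    target_update_interval=TARGET_UPDATE_INTERVAL,
    tensorboard_log="./",
)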
@@ -278,7 +292,7 @@
     "outputs": [],
     "source": [
     "# Train the model\n",
-    "…
+    "model.learn(total_timesteps=NUM_TIMESTEPS, callback=callback_list, tb_log_name=\"./tb/\")"
    ]
   },
   {
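callback_list is defined in a cell this commit does not touch, so its contents are not visible here. A minimal sketch of one way to assemble it from EVAL_CALLBACK_FREQ and VIDEO_CALLBACK_FREQ using stable-baselines3's built-in callbacks; the actual notebook may use a custom video-recording callback where the CheckpointCallback stands in below.

# Sketch only: callback_list is not part of this commit. EvalCallback evaluates
# every EVAL_CALLBACK_FREQ steps; CheckpointCallback is a stand-in for whatever
# recording callback actually fires every VIDEO_CALLBACK_FREQ steps.
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback

eval_callback = EvalCallback(eval_env, eval_freq=EVAL_CALLBACK_FREQ)  # eval_env assumed
checkpoint_callback = CheckpointCallback(save_freq=VIDEO_CALLBACK_FREQ, save_path="./checkpoints/")
callback_list = CallbackList([eval_callback, checkpoint_callback])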
@@ -288,24 +302,10 @@
     "outputs": [],
     "source": [
     "# Save the model, policy, and replay buffer for future loading and training\n",
-    "model.save(…
-    "model.save_replay_buffer(…
-    "model.policy.save(…
+    "model.save(MODEL_FILE_NAME)\n",
+    "model.save_replay_buffer(BUFFER_FILE_NAME)\n",
+    "model.policy.save(POLICY_FILE_NAME)"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {
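The save cell's comment mentions future loading and training; the matching stable-baselines3 load calls look roughly like the sketch below (recreating the environment first is assumed and not shown in this commit).

# Sketch only: reloading the artifacts saved above to continue training.
from stable_baselines3 import DQN

model = DQN.load(MODEL_FILE_NAME, env=env)   # env must be recreated beforehand (assumed)
model.load_replay_buffer(BUFFER_FILE_NAME)   # restores the saved 60k-transition buffer
model.learn(total_timesteps=NUM_TIMESTEPS, reset_num_timesteps=False)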