sgoodfriend commited on
Commit
282f8ee
1 Parent(s): d030d40

A2C playing CartPole-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +20 -14
  2. benchmark_publish.py +2 -105
  3. colab/colab_atari1.sh +4 -0
  4. colab/colab_atari2.sh +4 -0
  5. colab/colab_basic.sh +4 -0
  6. colab/colab_benchmark.ipynb +195 -0
  7. colab/colab_carracing.sh +4 -0
  8. colab/colab_enjoy.ipynb +198 -0
  9. colab/colab_pybullet.sh +4 -0
  10. colab/colab_train.ipynb +200 -0
  11. compare_runs.py +2 -185
  12. enjoy.py +2 -28
  13. environment.yml +1 -6
  14. huggingface_publish.py +2 -187
  15. optimize.py +4 -0
  16. pyproject.toml +59 -30
  17. replay.meta.json +1 -1
  18. replay.mp4 +0 -0
  19. rl_algo_impls/a2c/a2c.py +209 -0
  20. rl_algo_impls/a2c/optimize.py +77 -0
  21. rl_algo_impls/benchmark_publish.py +111 -0
  22. rl_algo_impls/compare_runs.py +198 -0
  23. rl_algo_impls/dqn/dqn.py +182 -0
  24. rl_algo_impls/dqn/policy.py +55 -0
  25. rl_algo_impls/dqn/q_net.py +41 -0
  26. rl_algo_impls/enjoy.py +35 -0
  27. rl_algo_impls/huggingface_publish.py +193 -0
  28. rl_algo_impls/hyperparams/a2c.yml +138 -0
  29. rl_algo_impls/hyperparams/dqn.yml +130 -0
  30. rl_algo_impls/hyperparams/ppo.yml +383 -0
  31. rl_algo_impls/hyperparams/vpg.yml +197 -0
  32. rl_algo_impls/optimize.py +441 -0
  33. rl_algo_impls/ppo/ppo.py +353 -0
  34. rl_algo_impls/publish/markdown_format.py +210 -0
  35. rl_algo_impls/runner/config.py +189 -0
  36. rl_algo_impls/runner/env.py +292 -0
  37. rl_algo_impls/runner/evaluate.py +103 -0
  38. rl_algo_impls/runner/running_utils.py +192 -0
  39. rl_algo_impls/runner/train.py +143 -0
  40. rl_algo_impls/shared/algorithm.py +39 -0
  41. rl_algo_impls/shared/callbacks/callback.py +11 -0
  42. rl_algo_impls/shared/callbacks/eval_callback.py +199 -0
  43. rl_algo_impls/shared/callbacks/optimize_callback.py +117 -0
  44. rl_algo_impls/shared/gae.py +67 -0
  45. rl_algo_impls/shared/module/feature_extractor.py +215 -0
  46. rl_algo_impls/shared/module/module.py +40 -0
  47. rl_algo_impls/shared/policy/actor.py +310 -0
  48. rl_algo_impls/shared/policy/critic.py +28 -0
  49. rl_algo_impls/shared/policy/on_policy.py +226 -0
  50. rl_algo_impls/shared/policy/optimize_on_policy.py +35 -0
README.md CHANGED
@@ -23,17 +23,17 @@ model-index:
23
 
24
  This is a trained model of a **A2C** agent playing **CartPole-v1** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
25
 
26
- All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/eyvb72mv.
27
 
28
  ## Training Results
29
 
30
- This model was trained from 3 trainings of **A2C** agents using different initial seeds. These agents were trained by checking out [0760ef7](https://github.com/sgoodfriend/rl-algo-impls/tree/0760ef7d52b17f30219a27c18ba52c8895025ae3). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
31
 
32
  | algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
33
  |:-------|:------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
34
- | a2c | CartPole-v1 | 1 | 500 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/2ogq8udd) |
35
- | a2c | CartPole-v1 | 2 | 500 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/3e2qxbjj) |
36
- | a2c | CartPole-v1 | 3 | 500 | 0 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/k3ydp4p7) |
37
 
38
 
39
  ### Prerequisites: Weights & Biases (WandB)
@@ -53,10 +53,10 @@ login`.
53
  Note: While the model state dictionary and hyperaparameters are saved, the latest
54
  implementation could be sufficiently different to not be able to reproduce similar
55
  results. You might need to checkout the commit the agent was trained on:
56
- [0760ef7](https://github.com/sgoodfriend/rl-algo-impls/tree/0760ef7d52b17f30219a27c18ba52c8895025ae3).
57
  ```
58
  # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
59
- python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/k3ydp4p7
60
  ```
61
 
62
  Setup hasn't been completely worked out yet, so you might be best served by using Google
@@ -68,11 +68,11 @@ notebook.
68
 
69
  ## Training
70
  If you want the highest chance to reproduce these results, you'll want to checkout the
71
- commit the agent was trained on: [0760ef7](https://github.com/sgoodfriend/rl-algo-impls/tree/0760ef7d52b17f30219a27c18ba52c8895025ae3). While
72
  training is deterministic, different hardware will give different results.
73
 
74
  ```
75
- python train.py --algo a2c --env CartPole-v1 --seed 3
76
  ```
77
 
78
  Setup hasn't been completely worked out yet, so you might be best served by using Google
@@ -83,14 +83,14 @@ notebook.
83
 
84
 
85
  ## Benchmarking (with Lambda Labs instance)
86
- This and other models from https://api.wandb.ai/links/sgoodfriend/eyvb72mv were generated by running a script on a Lambda
87
  Labs instance. In a Lambda Labs instance terminal:
88
  ```
89
  git clone git@github.com:sgoodfriend/rl-algo-impls.git
90
  cd rl-algo-impls
91
  bash ./lambda_labs/setup.sh
92
  wandb login
93
- bash ./lambda_labs/benchmark.sh
94
  ```
95
 
96
  ### Alternative: Google Colab Pro+
@@ -106,16 +106,22 @@ This isn't exactly the format of hyperparams in hyperparams/a2c.yml, but instead
106
  close and has some additional data:
107
  ```
108
  algo: a2c
 
 
109
  env: CartPole-v1
110
  env_hyperparams:
111
  n_envs: 8
 
 
112
  n_timesteps: 500000
113
- seed: 3
 
114
  use_deterministic_algorithms: true
115
  wandb_entity: null
 
116
  wandb_project_name: rl-algo-impls-benchmarks
117
  wandb_tags:
118
- - benchmark_0760ef7
119
- - host_192-9-248-209
120
 
121
  ```
 
23
 
24
  This is a trained model of a **A2C** agent playing **CartPole-v1** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
25
 
26
+ All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/09frjfcs.
27
 
28
  ## Training Results
29
 
30
+ This model was trained from 3 trainings of **A2C** agents using different initial seeds. These agents were trained by checking out [2067e21](https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
31
 
32
  | algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
33
  |:-------|:------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
34
+ | a2c | CartPole-v1 | 1 | 500 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/gwtpzn86) |
35
+ | a2c | CartPole-v1 | 2 | 500 | 0 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/r5mp7hup) |
36
+ | a2c | CartPole-v1 | 3 | 500 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/1tcre33a) |
37
 
38
 
39
  ### Prerequisites: Weights & Biases (WandB)
 
53
  Note: While the model state dictionary and hyperaparameters are saved, the latest
54
  implementation could be sufficiently different to not be able to reproduce similar
55
  results. You might need to checkout the commit the agent was trained on:
56
+ [2067e21](https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0).
57
  ```
58
  # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
59
+ python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/r5mp7hup
60
  ```
61
 
62
  Setup hasn't been completely worked out yet, so you might be best served by using Google
 
68
 
69
  ## Training
70
  If you want the highest chance to reproduce these results, you'll want to checkout the
71
+ commit the agent was trained on: [2067e21](https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0). While
72
  training is deterministic, different hardware will give different results.
73
 
74
  ```
75
+ python train.py --algo a2c --env CartPole-v1 --seed 2
76
  ```
77
 
78
  Setup hasn't been completely worked out yet, so you might be best served by using Google
 
83
 
84
 
85
  ## Benchmarking (with Lambda Labs instance)
86
+ This and other models from https://api.wandb.ai/links/sgoodfriend/09frjfcs were generated by running a script on a Lambda
87
  Labs instance. In a Lambda Labs instance terminal:
88
  ```
89
  git clone git@github.com:sgoodfriend/rl-algo-impls.git
90
  cd rl-algo-impls
91
  bash ./lambda_labs/setup.sh
92
  wandb login
93
+ bash ./lambda_labs/benchmark.sh [-a {"ppo a2c dqn vpg"}] [-e ENVS] [-j {6}] [-p {rl-algo-impls-benchmarks}] [-s {"1 2 3"}]
94
  ```
95
 
96
  ### Alternative: Google Colab Pro+
 
106
  close and has some additional data:
107
  ```
108
  algo: a2c
109
+ algo_hyperparams: {}
110
+ device: auto
111
  env: CartPole-v1
112
  env_hyperparams:
113
  n_envs: 8
114
+ env_id: null
115
+ eval_params: {}
116
  n_timesteps: 500000
117
+ policy_hyperparams: {}
118
+ seed: 2
119
  use_deterministic_algorithms: true
120
  wandb_entity: null
121
+ wandb_group: null
122
  wandb_project_name: rl-algo-impls-benchmarks
123
  wandb_tags:
124
+ - benchmark_2067e21
125
+ - host_155-248-199-228
126
 
127
  ```
benchmark_publish.py CHANGED
@@ -1,107 +1,4 @@
1
- import argparse
2
- import subprocess
3
- import wandb
4
- import wandb.apis.public
5
-
6
- from collections import defaultdict
7
- from multiprocessing.pool import ThreadPool
8
- from typing import List, NamedTuple
9
-
10
-
11
- class RunGroup(NamedTuple):
12
- algo: str
13
- env_id: str
14
-
15
 
16
  if __name__ == "__main__":
17
- parser = argparse.ArgumentParser()
18
- parser.add_argument(
19
- "--wandb-project-name",
20
- type=str,
21
- default="rl-algo-impls-benchmarks",
22
- help="WandB project name to load runs from",
23
- )
24
- parser.add_argument(
25
- "--wandb-entity",
26
- type=str,
27
- default=None,
28
- help="WandB team of project. None uses default entity",
29
- )
30
- parser.add_argument("--wandb-tags", type=str, nargs="+", help="WandB tags")
31
- parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
32
- parser.add_argument(
33
- "--envs", type=str, nargs="*", help="Optional filter down to these envs"
34
- )
35
- parser.add_argument(
36
- "--exclude-envs",
37
- type=str,
38
- nargs="*",
39
- help="Environments to exclude from publishing",
40
- )
41
- parser.add_argument(
42
- "--huggingface-user",
43
- type=str,
44
- default=None,
45
- help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
46
- )
47
- parser.add_argument(
48
- "--pool-size",
49
- type=int,
50
- default=3,
51
- help="How many publish jobs can run in parallel",
52
- )
53
- parser.add_argument(
54
- "--virtual-display", action="store_true", help="Use headless virtual display"
55
- )
56
- # parser.set_defaults(
57
- # wandb_tags=["benchmark_e47a44c", "host_129-146-2-230"],
58
- # wandb_report_url="https://api.wandb.ai/links/sgoodfriend/v4wd7cp5",
59
- # envs=[],
60
- # exclude_envs=[],
61
- # )
62
- args = parser.parse_args()
63
- print(args)
64
-
65
- api = wandb.Api()
66
- all_runs = api.runs(
67
- f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
68
- )
69
-
70
- required_tags = set(args.wandb_tags)
71
- runs: List[wandb.apis.public.Run] = [
72
- r
73
- for r in all_runs
74
- if required_tags.issubset(set(r.config.get("wandb_tags", [])))
75
- ]
76
-
77
- runs_paths_by_group = defaultdict(list)
78
- for r in runs:
79
- if r.state != "finished":
80
- continue
81
- algo = r.config["algo"]
82
- env = r.config["env"]
83
- if args.envs and env not in args.envs:
84
- continue
85
- if args.exclude_envs and env in args.exclude_envs:
86
- continue
87
- run_group = RunGroup(algo, env)
88
- runs_paths_by_group[run_group].append("/".join(r.path))
89
-
90
- def run(run_paths: List[str]) -> None:
91
- publish_args = ["python", "huggingface_publish.py"]
92
- publish_args.append("--wandb-run-paths")
93
- publish_args.extend(run_paths)
94
- publish_args.append("--wandb-report-url")
95
- publish_args.append(args.wandb_report_url)
96
- if args.huggingface_user:
97
- publish_args.append("--huggingface-user")
98
- publish_args.append(args.huggingface_user)
99
- if args.virtual_display:
100
- publish_args.append("--virtual-display")
101
- subprocess.run(publish_args)
102
-
103
- tp = ThreadPool(args.pool_size)
104
- for run_paths in runs_paths_by_group.values():
105
- tp.apply_async(run, (run_paths,))
106
- tp.close()
107
- tp.join()
 
1
+ from rl_algo_impls.benchmark_publish import benchmark_publish
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  if __name__ == "__main__":
4
+ benchmark_publish()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
colab/colab_atari1.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ALGO="ppo"
2
+ ENVS="PongNoFrameskip-v4 BreakoutNoFrameskip-v4"
3
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
4
+ bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
colab/colab_atari2.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ALGO="ppo"
2
+ ENVS="SpaceInvadersNoFrameskip-v4 QbertNoFrameskip-v4"
3
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
4
+ bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
colab/colab_basic.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ALGO="ppo"
2
+ ENVS="CartPole-v1 MountainCar-v0 MountainCarContinuous-v0 Acrobot-v1 LunarLander-v2"
3
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
4
+ bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
colab/colab_benchmark.ipynb ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "machine_shape": "hm",
8
+ "authorship_tag": "ABX9TyOGIH7rqgasim3Sz7b1rpoE",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ },
18
+ "gpuClass": "standard",
19
+ "accelerator": "GPU"
20
+ },
21
+ "cells": [
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "view-in-github",
26
+ "colab_type": "text"
27
+ },
28
+ "source": [
29
+ "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/benchmarks/colab_benchmark.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "source": [
35
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
+ "## Parameters\n",
37
+ "\n",
38
+ "\n",
39
+ "1. Wandb\n",
40
+ "\n"
41
+ ],
42
+ "metadata": {
43
+ "id": "S-tXDWP8WTLc"
44
+ }
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "source": [
49
+ "from getpass import getpass\n",
50
+ "import os\n",
51
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
+ ],
53
+ "metadata": {
54
+ "id": "1ZtdYgxWNGwZ"
55
+ },
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "source": [
62
+ "## Setup\n",
63
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
64
+ ],
65
+ "metadata": {
66
+ "id": "bsG35Io0hmKG"
67
+ }
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "source": [
72
+ "%%capture\n",
73
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
74
+ ],
75
+ "metadata": {
76
+ "id": "k5ynTV25hdAf"
77
+ },
78
+ "execution_count": null,
79
+ "outputs": []
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "source": [
84
+ "Installing the correct packages:\n",
85
+ "\n",
86
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
87
+ ],
88
+ "metadata": {
89
+ "id": "jKxGok-ElYQ7"
90
+ }
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "source": [
95
+ "%%capture\n",
96
+ "!apt install python-opengl\n",
97
+ "!apt install ffmpeg\n",
98
+ "!apt install xvfb\n",
99
+ "!apt install swig"
100
+ ],
101
+ "metadata": {
102
+ "id": "nn6EETTc2Ewf"
103
+ },
104
+ "execution_count": null,
105
+ "outputs": []
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "source": [
110
+ "%%capture\n",
111
+ "%cd /content/rl-algo-impls\n",
112
+ "python -m pip install ."
113
+ ],
114
+ "metadata": {
115
+ "id": "AfZh9rH3yQii"
116
+ },
117
+ "execution_count": null,
118
+ "outputs": []
119
+ },
120
+ {
121
+ "cell_type": "markdown",
122
+ "source": [
123
+ "## Run Once Per Runtime"
124
+ ],
125
+ "metadata": {
126
+ "id": "4o5HOLjc4wq7"
127
+ }
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "source": [
132
+ "import wandb\n",
133
+ "wandb.login()"
134
+ ],
135
+ "metadata": {
136
+ "id": "PCXa5tdS2qFX"
137
+ },
138
+ "execution_count": null,
139
+ "outputs": []
140
+ },
141
+ {
142
+ "cell_type": "markdown",
143
+ "source": [
144
+ "## Restart Session beteween runs"
145
+ ],
146
+ "metadata": {
147
+ "id": "AZBZfSUV43JQ"
148
+ }
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "source": [
153
+ "%%capture\n",
154
+ "from pyvirtualdisplay import Display\n",
155
+ "\n",
156
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
157
+ "virtual_display.start()"
158
+ ],
159
+ "metadata": {
160
+ "id": "VzemeQJP2NO9"
161
+ },
162
+ "execution_count": null,
163
+ "outputs": []
164
+ },
165
+ {
166
+ "cell_type": "markdown",
167
+ "source": [
168
+ "The below 5 bash scripts train agents on environments with 3 seeds each:\n",
169
+ "- colab_basic.sh and colab_pybullet.sh test on a set of basic gym environments and 4 PyBullet environments. Running both together will likely take about 18 hours. This is likely to run into runtime limits for free Colab and Colab Pro, but is fine for Colab Pro+.\n",
170
+ "- colab_carracing.sh only trains 3 seeds on CarRacing-v0, which takes almost 22 hours on Colab Pro+ on high-RAM, standard GPU.\n",
171
+ "- colab_atari1.sh and colab_atari2.sh likely need to be run separately because each takes about 19 hours on high-RAM, standard GPU."
172
+ ],
173
+ "metadata": {
174
+ "id": "nSHfna0hLlO1"
175
+ }
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "source": [
180
+ "%cd /content/rl-algo-impls\n",
181
+ "os.environ[\"BENCHMARK_MAX_PROCS\"] = str(1) # Can't reliably raise this to 2+, but would make it faster.\n",
182
+ "!./benchmarks/colab_basic.sh\n",
183
+ "!./benchmarks/colab_pybullet.sh\n",
184
+ "# !./benchmarks/colab_carracing.sh\n",
185
+ "# !./benchmarks/colab_atari1.sh\n",
186
+ "# !./benchmarks/colab_atari2.sh"
187
+ ],
188
+ "metadata": {
189
+ "id": "07aHYFH1zfXa"
190
+ },
191
+ "execution_count": null,
192
+ "outputs": []
193
+ }
194
+ ]
195
+ }
colab/colab_carracing.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ALGO="ppo"
2
+ ENVS="CarRacing-v0"
3
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
4
+ bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
colab/colab_enjoy.ipynb ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "machine_shape": "hm",
8
+ "authorship_tag": "ABX9TyN6S7kyJKrM5x0OOiN+CgTc",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ },
18
+ "gpuClass": "standard",
19
+ "accelerator": "GPU"
20
+ },
21
+ "cells": [
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "view-in-github",
26
+ "colab_type": "text"
27
+ },
28
+ "source": [
29
+ "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "source": [
35
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
+ "## Parameters\n",
37
+ "\n",
38
+ "\n",
39
+ "1. Wandb\n",
40
+ "\n"
41
+ ],
42
+ "metadata": {
43
+ "id": "S-tXDWP8WTLc"
44
+ }
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "source": [
49
+ "from getpass import getpass\n",
50
+ "import os\n",
51
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
+ ],
53
+ "metadata": {
54
+ "id": "1ZtdYgxWNGwZ"
55
+ },
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "source": [
62
+ "2. enjoy.py parameters"
63
+ ],
64
+ "metadata": {
65
+ "id": "ao0nAh3MOdN7"
66
+ }
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "source": [
71
+ "WANDB_RUN_PATH=\"sgoodfriend/rl-algo-impls-benchmarks/rd0lisee\""
72
+ ],
73
+ "metadata": {
74
+ "id": "jKL_NFhVOjSc"
75
+ },
76
+ "execution_count": 2,
77
+ "outputs": []
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "source": [
82
+ "## Setup\n",
83
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
84
+ ],
85
+ "metadata": {
86
+ "id": "bsG35Io0hmKG"
87
+ }
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "source": [
92
+ "%%capture\n",
93
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
94
+ ],
95
+ "metadata": {
96
+ "id": "k5ynTV25hdAf"
97
+ },
98
+ "execution_count": 3,
99
+ "outputs": []
100
+ },
101
+ {
102
+ "cell_type": "markdown",
103
+ "source": [
104
+ "Installing the correct packages:\n",
105
+ "\n",
106
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
107
+ ],
108
+ "metadata": {
109
+ "id": "jKxGok-ElYQ7"
110
+ }
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "source": [
115
+ "%%capture\n",
116
+ "!apt install python-opengl\n",
117
+ "!apt install ffmpeg\n",
118
+ "!apt install xvfb\n",
119
+ "!apt install swig"
120
+ ],
121
+ "metadata": {
122
+ "id": "nn6EETTc2Ewf"
123
+ },
124
+ "execution_count": 4,
125
+ "outputs": []
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "source": [
130
+ "%%capture\n",
131
+ "%cd /content/rl-algo-impls\n",
132
+ "python -m pip install ."
133
+ ],
134
+ "metadata": {
135
+ "id": "AfZh9rH3yQii"
136
+ },
137
+ "execution_count": 5,
138
+ "outputs": []
139
+ },
140
+ {
141
+ "cell_type": "markdown",
142
+ "source": [
143
+ "## Run Once Per Runtime"
144
+ ],
145
+ "metadata": {
146
+ "id": "4o5HOLjc4wq7"
147
+ }
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "source": [
152
+ "import wandb\n",
153
+ "wandb.login()"
154
+ ],
155
+ "metadata": {
156
+ "id": "PCXa5tdS2qFX"
157
+ },
158
+ "execution_count": null,
159
+ "outputs": []
160
+ },
161
+ {
162
+ "cell_type": "markdown",
163
+ "source": [
164
+ "## Restart Session beteween runs"
165
+ ],
166
+ "metadata": {
167
+ "id": "AZBZfSUV43JQ"
168
+ }
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "source": [
173
+ "%%capture\n",
174
+ "from pyvirtualdisplay import Display\n",
175
+ "\n",
176
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
177
+ "virtual_display.start()"
178
+ ],
179
+ "metadata": {
180
+ "id": "VzemeQJP2NO9"
181
+ },
182
+ "execution_count": 7,
183
+ "outputs": []
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "source": [
188
+ "%cd /content/rl-algo-impls\n",
189
+ "!python enjoy.py --wandb-run-path={WANDB_RUN_PATH}"
190
+ ],
191
+ "metadata": {
192
+ "id": "07aHYFH1zfXa"
193
+ },
194
+ "execution_count": null,
195
+ "outputs": []
196
+ }
197
+ ]
198
+ }
colab/colab_pybullet.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ALGO="ppo"
2
+ ENVS="HalfCheetahBulletEnv-v0 AntBulletEnv-v0 HopperBulletEnv-v0 Walker2DBulletEnv-v0"
3
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
4
+ bash scripts/train_loop.sh -a $ALGO -e "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
colab/colab_train.ipynb ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "machine_shape": "hm",
8
+ "authorship_tag": "ABX9TyMmemQnx6G7GOnn6XBdjgxY",
9
+ "include_colab_link": true
10
+ },
11
+ "kernelspec": {
12
+ "name": "python3",
13
+ "display_name": "Python 3"
14
+ },
15
+ "language_info": {
16
+ "name": "python"
17
+ },
18
+ "gpuClass": "standard",
19
+ "accelerator": "GPU"
20
+ },
21
+ "cells": [
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "view-in-github",
26
+ "colab_type": "text"
27
+ },
28
+ "source": [
29
+ "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "source": [
35
+ "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
36
+ "## Parameters\n",
37
+ "\n",
38
+ "\n",
39
+ "1. Wandb\n",
40
+ "\n"
41
+ ],
42
+ "metadata": {
43
+ "id": "S-tXDWP8WTLc"
44
+ }
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "source": [
49
+ "from getpass import getpass\n",
50
+ "import os\n",
51
+ "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
52
+ ],
53
+ "metadata": {
54
+ "id": "1ZtdYgxWNGwZ"
55
+ },
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "markdown",
61
+ "source": [
62
+ "2. train run parameters"
63
+ ],
64
+ "metadata": {
65
+ "id": "ao0nAh3MOdN7"
66
+ }
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "source": [
71
+ "ALGO = \"ppo\"\n",
72
+ "ENV = \"CartPole-v1\"\n",
73
+ "SEED = 1"
74
+ ],
75
+ "metadata": {
76
+ "id": "jKL_NFhVOjSc"
77
+ },
78
+ "execution_count": null,
79
+ "outputs": []
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "source": [
84
+ "## Setup\n",
85
+ "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
86
+ ],
87
+ "metadata": {
88
+ "id": "bsG35Io0hmKG"
89
+ }
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "source": [
94
+ "%%capture\n",
95
+ "!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
96
+ ],
97
+ "metadata": {
98
+ "id": "k5ynTV25hdAf"
99
+ },
100
+ "execution_count": null,
101
+ "outputs": []
102
+ },
103
+ {
104
+ "cell_type": "markdown",
105
+ "source": [
106
+ "Installing the correct packages:\n",
107
+ "\n",
108
+ "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
109
+ ],
110
+ "metadata": {
111
+ "id": "jKxGok-ElYQ7"
112
+ }
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "source": [
117
+ "%%capture\n",
118
+ "!apt install python-opengl\n",
119
+ "!apt install ffmpeg\n",
120
+ "!apt install xvfb\n",
121
+ "!apt install swig"
122
+ ],
123
+ "metadata": {
124
+ "id": "nn6EETTc2Ewf"
125
+ },
126
+ "execution_count": null,
127
+ "outputs": []
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "source": [
132
+ "%%capture\n",
133
+ "%cd /content/rl-algo-impls\n",
134
+ "python -m pip install ."
135
+ ],
136
+ "metadata": {
137
+ "id": "AfZh9rH3yQii"
138
+ },
139
+ "execution_count": null,
140
+ "outputs": []
141
+ },
142
+ {
143
+ "cell_type": "markdown",
144
+ "source": [
145
+ "## Run Once Per Runtime"
146
+ ],
147
+ "metadata": {
148
+ "id": "4o5HOLjc4wq7"
149
+ }
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "source": [
154
+ "import wandb\n",
155
+ "wandb.login()"
156
+ ],
157
+ "metadata": {
158
+ "id": "PCXa5tdS2qFX"
159
+ },
160
+ "execution_count": null,
161
+ "outputs": []
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "source": [
166
+ "## Restart Session beteween runs"
167
+ ],
168
+ "metadata": {
169
+ "id": "AZBZfSUV43JQ"
170
+ }
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "source": [
175
+ "%%capture\n",
176
+ "from pyvirtualdisplay import Display\n",
177
+ "\n",
178
+ "virtual_display = Display(visible=0, size=(1400, 900))\n",
179
+ "virtual_display.start()"
180
+ ],
181
+ "metadata": {
182
+ "id": "VzemeQJP2NO9"
183
+ },
184
+ "execution_count": null,
185
+ "outputs": []
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "source": [
190
+ "%cd /content/rl-algo-impls\n",
191
+ "!python train.py --algo {ALGO} --env {ENV} --seed {SEED}"
192
+ ],
193
+ "metadata": {
194
+ "id": "07aHYFH1zfXa"
195
+ },
196
+ "execution_count": null,
197
+ "outputs": []
198
+ }
199
+ ]
200
+ }
compare_runs.py CHANGED
@@ -1,187 +1,4 @@
1
- import argparse
2
- import itertools
3
- import numpy as np
4
- import pandas as pd
5
- import wandb
6
- import wandb.apis.public
7
-
8
- from collections import defaultdict
9
- from dataclasses import dataclass
10
- from typing import Dict, Iterable, List, TypeVar
11
-
12
- from benchmark_publish import RunGroup
13
-
14
-
15
- @dataclass
16
- class Comparison:
17
- control_values: List[float]
18
- experiment_values: List[float]
19
-
20
- def mean_diff_percentage(self) -> float:
21
- return self._diff_percentage(
22
- np.mean(self.control_values).item(), np.mean(self.experiment_values).item()
23
- )
24
-
25
- def median_diff_percentage(self) -> float:
26
- return self._diff_percentage(
27
- np.median(self.control_values).item(),
28
- np.median(self.experiment_values).item(),
29
- )
30
-
31
- def _diff_percentage(self, c: float, e: float) -> float:
32
- if c == e:
33
- return 0
34
- elif c == 0:
35
- return float("inf") if e > 0 else float("-inf")
36
- return 100 * (e - c) / c
37
-
38
- def score(self) -> float:
39
- return (
40
- np.sum(
41
- np.sign((self.mean_diff_percentage(), self.median_diff_percentage()))
42
- ).item()
43
- / 2
44
- )
45
-
46
-
47
- RunGroupRunsSelf = TypeVar("RunGroupRunsSelf", bound="RunGroupRuns")
48
-
49
-
50
- class RunGroupRuns:
51
- def __init__(
52
- self,
53
- run_group: RunGroup,
54
- control: List[str],
55
- experiment: List[str],
56
- summary_stats: List[str] = ["best_eval", "eval", "train_rolling"],
57
- summary_metrics: List[str] = ["mean", "result"],
58
- ) -> None:
59
- self.algo = run_group.algo
60
- self.env = run_group.env_id
61
- self.control = set(control)
62
- self.experiment = set(experiment)
63
-
64
- self.summary_stats = summary_stats
65
- self.summary_metrics = summary_metrics
66
-
67
- self.control_runs = []
68
- self.experiment_runs = []
69
-
70
- def add_run(self, run: wandb.apis.public.Run) -> None:
71
- wandb_tags = set(run.config.get("wandb_tags", []))
72
- if self.control & wandb_tags:
73
- self.control_runs.append(run)
74
- elif self.experiment & wandb_tags:
75
- self.experiment_runs.append(run)
76
-
77
- def comparisons_by_metric(self) -> Dict[str, Comparison]:
78
- c_by_m = {}
79
- for metric in (
80
- f"{s}/{m}"
81
- for s, m in itertools.product(self.summary_stats, self.summary_metrics)
82
- ):
83
- c_by_m[metric] = Comparison(
84
- [c.summary[metric] for c in self.control_runs],
85
- [e.summary[metric] for e in self.experiment_runs],
86
- )
87
- return c_by_m
88
-
89
- @staticmethod
90
- def data_frame(rows: Iterable[RunGroupRunsSelf]) -> pd.DataFrame:
91
- results = defaultdict(list)
92
- for r in rows:
93
- if not r.control_runs or not r.experiment_runs:
94
- continue
95
- results["algo"].append(r.algo)
96
- results["env"].append(r.env)
97
- results["control"].append(r.control)
98
- results["expierment"].append(r.experiment)
99
- c_by_m = r.comparisons_by_metric()
100
- results["score"].append(
101
- sum(m.score() for m in c_by_m.values()) / len(c_by_m)
102
- )
103
- for m, c in c_by_m.items():
104
- results[f"{m}_mean"].append(c.mean_diff_percentage())
105
- results[f"{m}_median"].append(c.median_diff_percentage())
106
- return pd.DataFrame(results)
107
-
108
 
109
  if __name__ == "__main__":
110
- parser = argparse.ArgumentParser()
111
- parser.add_argument(
112
- "-p",
113
- "--wandb-project-name",
114
- type=str,
115
- default="rl-algo-impls-benchmarks",
116
- help="WandB project name to load runs from",
117
- )
118
- parser.add_argument(
119
- "--wandb-entity",
120
- type=str,
121
- default=None,
122
- help="WandB team. None uses default entity",
123
- )
124
- parser.add_argument(
125
- "-n",
126
- "--wandb-hostname-tag",
127
- type=str,
128
- nargs="*",
129
- help="WandB tags for hostname (i.e. host_192-9-145-26)",
130
- )
131
- parser.add_argument(
132
- "-c",
133
- "--wandb-control-tag",
134
- type=str,
135
- nargs="+",
136
- help="WandB tag for control commit (i.e. benchmark_5598ebc)",
137
- )
138
- parser.add_argument(
139
- "-e",
140
- "--wandb-experiment-tag",
141
- type=str,
142
- nargs="+",
143
- help="WandB tag for experiment commit (i.e. benchmark_5540e1f)",
144
- )
145
- parser.add_argument(
146
- "--exclude-envs",
147
- type=str,
148
- nargs="*",
149
- help="Environments to exclude from comparison",
150
- )
151
- # parser.set_defaults(
152
- # wandb_hostname_tag=["host_150-230-44-105", "host_155-248-214-128"],
153
- # wandb_control_tag=["benchmark_fbc943f"],
154
- # wandb_experiment_tag=["benchmark_f59bf74"],
155
- # exclude_envs=[],
156
- # )
157
- args = parser.parse_args()
158
- print(args)
159
-
160
- api = wandb.Api()
161
- all_runs = api.runs(
162
- path=f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}",
163
- order="+created_at",
164
- )
165
-
166
- runs_by_run_group: Dict[RunGroup, RunGroupRuns] = {}
167
- wandb_hostname_tags = set(args.wandb_hostname_tag)
168
- for r in all_runs:
169
- if r.state != "finished":
170
- continue
171
- wandb_tags = set(r.config.get("wandb_tags", []))
172
- if not wandb_tags or not wandb_hostname_tags & wandb_tags:
173
- continue
174
- rg = RunGroup(r.config["algo"], r.config.get("env_id") or r.config["env"])
175
- if args.exclude_envs and rg.env_id in args.exclude_envs:
176
- continue
177
- if rg not in runs_by_run_group:
178
- runs_by_run_group[rg] = RunGroupRuns(
179
- rg,
180
- args.wandb_control_tag,
181
- args.wandb_experiment_tag,
182
- )
183
- runs_by_run_group[rg].add_run(r)
184
- df = RunGroupRuns.data_frame(runs_by_run_group.values()).round(decimals=2)
185
- print(f"**Total Score: {sum(df.score)}**")
186
- df.loc["mean"] = df.mean(numeric_only=True)
187
- print(df.to_markdown())
 
1
+ from rl_algo_impls.compare_runs import compare_runs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  if __name__ == "__main__":
4
+ compare_runs()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
enjoy.py CHANGED
@@ -1,30 +1,4 @@
1
- # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
2
- import os
3
-
4
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
-
6
- from runner.evaluate import EvalArgs, evaluate_model
7
- from runner.running_utils import base_parser
8
-
9
 
10
  if __name__ == "__main__":
11
- parser = base_parser(multiple=False)
12
- parser.add_argument("--render", default=True, type=bool)
13
- parser.add_argument("--best", default=True, type=bool)
14
- parser.add_argument("--n_envs", default=1, type=int)
15
- parser.add_argument("--n_episodes", default=3, type=int)
16
- parser.add_argument("--deterministic-eval", default=None, type=bool)
17
- parser.add_argument(
18
- "--no-print-returns", action="store_true", help="Limit printing"
19
- )
20
- # wandb-run-path overrides base RunArgs
21
- parser.add_argument("--wandb-run-path", default=None, type=str)
22
- parser.set_defaults(
23
- algo=["ppo"],
24
- )
25
- args = parser.parse_args()
26
- args.algo = args.algo[0]
27
- args.env = args.env[0]
28
- args = EvalArgs(**vars(args))
29
-
30
- evaluate_model(args, os.path.dirname(__file__))
 
1
+ from rl_algo_impls.enjoy import enjoy
 
 
 
 
 
 
 
2
 
3
  if __name__ == "__main__":
4
+ enjoy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
environment.yml CHANGED
@@ -4,14 +4,9 @@ channels:
4
  - conda-forge
5
  - nodefaults
6
  dependencies:
7
- - python=3.10.*
8
  - mamba
9
  - pip
10
- - poetry
11
  - pytorch
12
  - torchvision
13
  - torchaudio
14
- - cmake
15
- - swig
16
- - ipywidgets
17
- - black
 
4
  - conda-forge
5
  - nodefaults
6
  dependencies:
7
+ - python>=3.8, <3.11
8
  - mamba
9
  - pip
 
10
  - pytorch
11
  - torchvision
12
  - torchaudio
 
 
 
 
huggingface_publish.py CHANGED
@@ -1,189 +1,4 @@
1
- import os
2
-
3
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
4
-
5
- import argparse
6
- import requests
7
- import shutil
8
- import subprocess
9
- import tempfile
10
- import wandb
11
- import wandb.apis.public
12
-
13
- from typing import List, Optional
14
-
15
- from huggingface_hub.hf_api import HfApi, upload_folder
16
- from huggingface_hub.repocard import metadata_save
17
- from pyvirtualdisplay.display import Display
18
-
19
- from publish.markdown_format import EvalTableData, model_card_text
20
- from runner.config import EnvHyperparams
21
- from runner.evaluate import EvalArgs, evaluate_model
22
- from runner.env import make_eval_env
23
- from shared.callbacks.eval_callback import evaluate
24
- from wrappers.vec_episode_recorder import VecEpisodeRecorder
25
-
26
-
27
- def publish(
28
- wandb_run_paths: List[str],
29
- wandb_report_url: str,
30
- huggingface_user: Optional[str] = None,
31
- huggingface_token: Optional[str] = None,
32
- virtual_display: bool = False,
33
- ) -> None:
34
- if virtual_display:
35
- display = Display(visible=False, size=(1400, 900))
36
- display.start()
37
-
38
- api = wandb.Api()
39
- runs = [api.run(rp) for rp in wandb_run_paths]
40
- algo = runs[0].config["algo"]
41
- hyperparam_id = runs[0].config["env"]
42
- evaluations = [
43
- evaluate_model(
44
- EvalArgs(
45
- algo,
46
- hyperparam_id,
47
- seed=r.config.get("seed", None),
48
- render=False,
49
- best=True,
50
- n_envs=None,
51
- n_episodes=10,
52
- no_print_returns=True,
53
- wandb_run_path="/".join(r.path),
54
- ),
55
- os.path.dirname(__file__),
56
- )
57
- for r in runs
58
- ]
59
- run_metadata = requests.get(runs[0].file("wandb-metadata.json").url).json()
60
- table_data = list(EvalTableData(r, e) for r, e in zip(runs, evaluations))
61
- best_eval = sorted(
62
- table_data, key=lambda d: d.evaluation.stats.score, reverse=True
63
- )[0]
64
-
65
- with tempfile.TemporaryDirectory() as tmpdirname:
66
- _, (policy, stats, config) = best_eval
67
-
68
- repo_name = config.model_name(include_seed=False)
69
- repo_dir_path = os.path.join(tmpdirname, repo_name)
70
- # Locally clone this repo to a temp directory
71
- subprocess.run(["git", "clone", ".", repo_dir_path])
72
- shutil.rmtree(os.path.join(repo_dir_path, ".git"))
73
- model_path = config.model_dir_path(best=True, downloaded=True)
74
- shutil.copytree(
75
- model_path,
76
- os.path.join(
77
- repo_dir_path, "saved_models", config.model_dir_name(best=True)
78
- ),
79
- )
80
-
81
- github_url = "https://github.com/sgoodfriend/rl-algo-impls"
82
- commit_hash = run_metadata.get("git", {}).get("commit", None)
83
- env_id = runs[0].config.get("env_id") or runs[0].config["env"]
84
- card_text = model_card_text(
85
- algo,
86
- env_id,
87
- github_url,
88
- commit_hash,
89
- wandb_report_url,
90
- table_data,
91
- best_eval,
92
- )
93
- readme_filepath = os.path.join(repo_dir_path, "README.md")
94
- os.remove(readme_filepath)
95
- with open(readme_filepath, "w") as f:
96
- f.write(card_text)
97
-
98
- metadata = {
99
- "library_name": "rl-algo-impls",
100
- "tags": [
101
- env_id,
102
- algo,
103
- "deep-reinforcement-learning",
104
- "reinforcement-learning",
105
- ],
106
- "model-index": [
107
- {
108
- "name": algo,
109
- "results": [
110
- {
111
- "metrics": [
112
- {
113
- "type": "mean_reward",
114
- "value": str(stats.score),
115
- "name": "mean_reward",
116
- }
117
- ],
118
- "task": {
119
- "type": "reinforcement-learning",
120
- "name": "reinforcement-learning",
121
- },
122
- "dataset": {
123
- "name": env_id,
124
- "type": env_id,
125
- },
126
- }
127
- ],
128
- }
129
- ],
130
- }
131
- metadata_save(readme_filepath, metadata)
132
-
133
- video_env = VecEpisodeRecorder(
134
- make_eval_env(
135
- config,
136
- EnvHyperparams(**config.env_hyperparams),
137
- override_n_envs=1,
138
- normalize_load_path=model_path,
139
- ),
140
- os.path.join(repo_dir_path, "replay"),
141
- max_video_length=3600,
142
- )
143
- evaluate(
144
- video_env,
145
- policy,
146
- 1,
147
- deterministic=config.eval_params.get("deterministic", True),
148
- )
149
-
150
- api = HfApi()
151
- huggingface_user = huggingface_user or api.whoami()["name"]
152
- huggingface_repo = f"{huggingface_user}/{repo_name}"
153
- api.create_repo(
154
- token=huggingface_token,
155
- repo_id=huggingface_repo,
156
- private=False,
157
- exist_ok=True,
158
- )
159
- repo_url = upload_folder(
160
- repo_id=huggingface_repo,
161
- folder_path=repo_dir_path,
162
- path_in_repo="",
163
- commit_message=f"{algo.upper()} playing {env_id} from {github_url}/tree/{commit_hash}",
164
- token=huggingface_token,
165
- )
166
- print(f"Pushed model to the hub: {repo_url}")
167
-
168
 
169
  if __name__ == "__main__":
170
- parser = argparse.ArgumentParser()
171
- parser.add_argument(
172
- "--wandb-run-paths",
173
- type=str,
174
- nargs="+",
175
- help="Run paths of the form entity/project/run_id",
176
- )
177
- parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
178
- parser.add_argument(
179
- "--huggingface-user",
180
- type=str,
181
- help="Huggingface user or team to upload model cards",
182
- default=None,
183
- )
184
- parser.add_argument(
185
- "--virtual-display", action="store_true", help="Use headless virtual display"
186
- )
187
- args = parser.parse_args()
188
- print(args)
189
- publish(**vars(args))
 
1
+ from rl_algo_impls.huggingface_publish import huggingface_publish
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  if __name__ == "__main__":
4
+ huggingface_publish()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
optimize.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from rl_algo_impls.optimize import optimize
2
+
3
+ if __name__ == "__main__":
4
+ optimize()
pyproject.toml CHANGED
@@ -1,35 +1,64 @@
1
- [tool.poetry]
2
- name = "rl-algo-impls"
3
- version = "0.1.0"
4
  description = "Implementations of reinforcement learning algorithms"
5
- authors = ["Scott Goodfriend <goodfriend.scott@gmail.com>"]
6
- license = "MIT License"
 
 
7
  readme = "README.md"
8
- packages = [{include = "rl_algo_impls"}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- [tool.poetry.dependencies]
11
- python = "~3.10"
12
- "AutoROM.accept-rom-license" = "^0.4.2"
13
- stable-baselines3 = {extras = ["extra"], version = "^1.7.0"}
14
- scipy = "^1.10.0"
15
- gym = {extras = ["box2d"], version = "^0.21.0"}
16
- pyglet = "1.5.27"
17
- PyYAML = "^6.0"
18
- tensorboard = "^2.11.0"
19
- pybullet = "^3.2.5"
20
- wandb = "^0.13.9"
21
- conda-lock = "^1.3.0"
22
- torch-tb-profiler = "^0.4.1"
23
- jupyter = "^1.0.0"
24
- tabulate = "^0.9.0"
25
- huggingface-hub = "^0.12.0"
26
- cryptography = "39.0.1"
27
- pyvirtualdisplay = "^3.0"
28
- numexpr = "^2.8.4"
29
- gym3 = "^0.3.3"
30
- glfw = "1.12.0"
31
- ipython = "^8.10.0"
32
 
33
  [build-system]
34
- requires = ["poetry-core"]
35
- build-backend = "poetry.core.masonry.api"
 
1
+ [project]
2
+ name = "rl_algo_impls"
3
+ version = "0.0.4"
4
  description = "Implementations of reinforcement learning algorithms"
5
+ authors = [
6
+ {name = "Scott Goodfriend", email = "goodfriend.scott@gmail.com"},
7
+ ]
8
+ license = {file = "LICENSE"}
9
  readme = "README.md"
10
+ requires-python = ">= 3.8"
11
+ classifiers = [
12
+ "License :: OSI Approved :: MIT License",
13
+ "Development Status :: 3 - Alpha",
14
+ "Programming Language :: Python :: 3.8",
15
+ "Programming Language :: Python :: 3.9",
16
+ "Programming Language :: Python :: 3.10",
17
+ ]
18
+ dependencies = [
19
+ "cmake",
20
+ "swig",
21
+ "scipy",
22
+ "torch",
23
+ "torchvision",
24
+ "tensorboard >= 2.11.2, < 2.12",
25
+ "AutoROM.accept-rom-license >= 0.4.2, < 0.5",
26
+ "stable-baselines3[extra] >= 1.7.0, < 1.8",
27
+ "gym[box2d] >= 0.21.0, < 0.22",
28
+ "pyglet == 1.5.27",
29
+ "wandb",
30
+ "pyvirtualdisplay",
31
+ "pybullet",
32
+ "tabulate",
33
+ "huggingface-hub",
34
+ "optuna",
35
+ "dash",
36
+ "kaleido",
37
+ "PyYAML",
38
+ ]
39
 
40
+ [tool.setuptools]
41
+ packages = ["rl_algo_impls"]
42
+
43
+ [project.optional-dependencies]
44
+ test = [
45
+ "pytest",
46
+ "black",
47
+ "mypy",
48
+ "flake8",
49
+ "flake8-bugbear",
50
+ "isort",
51
+ ]
52
+ procgen = [
53
+ "numexpr >= 2.8.4",
54
+ "gym3",
55
+ "glfw >= 1.12.0, < 1.13",
56
+ "procgen; platform_machine=='x86_64'",
57
+ ]
58
+
59
+ [project.urls]
60
+ "Homepage" = "https://github.com/sgoodfriend/rl-algo-impls"
 
61
 
62
  [build-system]
63
+ requires = ["setuptools==65.5.0", "setuptools-scm"]
64
+ build-backend = "setuptools.build_meta"
replay.meta.json CHANGED
@@ -1 +1 @@
1
- {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with clang version 14.0.6\\nconfiguration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-pic --enable-shared --disable-static --enable-version3 --enable-zlib --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/pkg-config\\nlibavutil 57. 28.100 / 57. 28.100\\nlibavcodec 59. 37.100 / 59. 37.100\\nlibavformat 59. 27.100 / 59. 27.100\\nlibavdevice 59. 7.100 / 59. 7.100\\nlibavfilter 8. 44.100 / 8. 44.100\\nlibswscale 6. 7.100 / 6. 7.100\\nlibswresample 4. 7.100 / 4. 7.100\\nlibpostproc 56. 6.100 / 56. 6.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "1200x800", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/var/folders/9g/my5557_91xddp6lx00nkzly80000gn/T/tmpgtamjs9u/a2c-CartPole-v1/replay.mp4"]}, "episode": {"r": 500.0, "l": 500, "t": 5.530564}}
 
1
+ {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/tmp/tmp7v_hmh4f/a2c-CartPole-v1/replay.mp4"]}, "episode": {"r": 500.0, "l": 500, "t": 3.177716}}
replay.mp4 CHANGED
Binary files a/replay.mp4 and b/replay.mp4 differ
 
rl_algo_impls/a2c/a2c.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from time import perf_counter
8
+ from torch.utils.tensorboard.writer import SummaryWriter
9
+ from typing import Optional, TypeVar
10
+
11
+ from rl_algo_impls.shared.algorithm import Algorithm
12
+ from rl_algo_impls.shared.callbacks.callback import Callback
13
+ from rl_algo_impls.shared.policy.on_policy import ActorCritic
14
+ from rl_algo_impls.shared.schedule import schedule, update_learning_rate
15
+ from rl_algo_impls.shared.stats import log_scalars
16
+ from rl_algo_impls.wrappers.vectorable_wrapper import (
17
+ VecEnv,
18
+ single_observation_space,
19
+ single_action_space,
20
+ )
21
+
22
+ A2CSelf = TypeVar("A2CSelf", bound="A2C")
23
+
24
+
25
+ class A2C(Algorithm):
26
+ def __init__(
27
+ self,
28
+ policy: ActorCritic,
29
+ env: VecEnv,
30
+ device: torch.device,
31
+ tb_writer: SummaryWriter,
32
+ learning_rate: float = 7e-4,
33
+ learning_rate_decay: str = "none",
34
+ n_steps: int = 5,
35
+ gamma: float = 0.99,
36
+ gae_lambda: float = 1.0,
37
+ ent_coef: float = 0.0,
38
+ ent_coef_decay: str = "none",
39
+ vf_coef: float = 0.5,
40
+ max_grad_norm: float = 0.5,
41
+ rms_prop_eps: float = 1e-5,
42
+ use_rms_prop: bool = True,
43
+ sde_sample_freq: int = -1,
44
+ normalize_advantage: bool = False,
45
+ ) -> None:
46
+ super().__init__(policy, env, device, tb_writer)
47
+ self.policy = policy
48
+
49
+ self.lr_schedule = schedule(learning_rate_decay, learning_rate)
50
+ if use_rms_prop:
51
+ self.optimizer = torch.optim.RMSprop(
52
+ policy.parameters(), lr=learning_rate, eps=rms_prop_eps
53
+ )
54
+ else:
55
+ self.optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate)
56
+
57
+ self.n_steps = n_steps
58
+
59
+ self.gamma = gamma
60
+ self.gae_lambda = gae_lambda
61
+
62
+ self.vf_coef = vf_coef
63
+ self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
64
+ self.max_grad_norm = max_grad_norm
65
+
66
+ self.sde_sample_freq = sde_sample_freq
67
+ self.normalize_advantage = normalize_advantage
68
+
69
+ def learn(
70
+ self: A2CSelf,
71
+ train_timesteps: int,
72
+ callback: Optional[Callback] = None,
73
+ total_timesteps: Optional[int] = None,
74
+ start_timesteps: int = 0,
75
+ ) -> A2CSelf:
76
+ if total_timesteps is None:
77
+ total_timesteps = train_timesteps
78
+ assert start_timesteps + train_timesteps <= total_timesteps
79
+ epoch_dim = (self.n_steps, self.env.num_envs)
80
+ step_dim = (self.env.num_envs,)
81
+ obs_space = single_observation_space(self.env)
82
+ act_space = single_action_space(self.env)
83
+
84
+ obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)
85
+ actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)
86
+ rewards = np.zeros(epoch_dim, dtype=np.float32)
87
+ episode_starts = np.zeros(epoch_dim, dtype=np.byte)
88
+ values = np.zeros(epoch_dim, dtype=np.float32)
89
+ logprobs = np.zeros(epoch_dim, dtype=np.float32)
90
+
91
+ next_obs = self.env.reset()
92
+ next_episode_starts = np.ones(step_dim, dtype=np.byte)
93
+
94
+ timesteps_elapsed = start_timesteps
95
+ while timesteps_elapsed < start_timesteps + train_timesteps:
96
+ start_time = perf_counter()
97
+
98
+ progress = timesteps_elapsed / total_timesteps
99
+ ent_coef = self.ent_coef_schedule(progress)
100
+ learning_rate = self.lr_schedule(progress)
101
+ update_learning_rate(self.optimizer, learning_rate)
102
+ log_scalars(
103
+ self.tb_writer,
104
+ "charts",
105
+ {
106
+ "ent_coef": ent_coef,
107
+ "learning_rate": learning_rate,
108
+ },
109
+ timesteps_elapsed,
110
+ )
111
+
112
+ self.policy.eval()
113
+ self.policy.reset_noise()
114
+ for s in range(self.n_steps):
115
+ timesteps_elapsed += self.env.num_envs
116
+ if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
117
+ self.policy.reset_noise()
118
+
119
+ obs[s] = next_obs
120
+ episode_starts[s] = next_episode_starts
121
+
122
+ actions[s], values[s], logprobs[s], clamped_action = self.policy.step(
123
+ next_obs
124
+ )
125
+ next_obs, rewards[s], next_episode_starts, _ = self.env.step(
126
+ clamped_action
127
+ )
128
+
129
+ advantages = np.zeros(epoch_dim, dtype=np.float32)
130
+ last_gae_lam = 0
131
+ for t in reversed(range(self.n_steps)):
132
+ if t == self.n_steps - 1:
133
+ next_nonterminal = 1.0 - next_episode_starts
134
+ next_value = self.policy.value(next_obs)
135
+ else:
136
+ next_nonterminal = 1.0 - episode_starts[t + 1]
137
+ next_value = values[t + 1]
138
+ delta = (
139
+ rewards[t] + self.gamma * next_value * next_nonterminal - values[t]
140
+ )
141
+ last_gae_lam = (
142
+ delta
143
+ + self.gamma * self.gae_lambda * next_nonterminal * last_gae_lam
144
+ )
145
+ advantages[t] = last_gae_lam
146
+ returns = advantages + values
147
+
148
+ b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)
149
+ b_actions = torch.tensor(actions.reshape((-1,) + act_space.shape)).to(
150
+ self.device
151
+ )
152
+ b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
153
+ b_returns = torch.tensor(returns.reshape(-1)).to(self.device)
154
+
155
+ if self.normalize_advantage:
156
+ b_advantages = (b_advantages - b_advantages.mean()) / (
157
+ b_advantages.std() + 1e-8
158
+ )
159
+
160
+ self.policy.train()
161
+ logp_a, entropy, v = self.policy(b_obs, b_actions)
162
+
163
+ pi_loss = -(b_advantages * logp_a).mean()
164
+ value_loss = F.mse_loss(b_returns, v)
165
+ entropy_loss = -entropy.mean()
166
+
167
+ loss = pi_loss + self.vf_coef * value_loss + ent_coef * entropy_loss
168
+
169
+ self.optimizer.zero_grad()
170
+ loss.backward()
171
+ nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
172
+ self.optimizer.step()
173
+
174
+ y_pred = values.reshape(-1)
175
+ y_true = returns.reshape(-1)
176
+ var_y = np.var(y_true).item()
177
+ explained_var = (
178
+ np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
179
+ )
180
+
181
+ end_time = perf_counter()
182
+ rollout_steps = self.n_steps * self.env.num_envs
183
+ self.tb_writer.add_scalar(
184
+ "train/steps_per_second",
185
+ (rollout_steps) / (end_time - start_time),
186
+ timesteps_elapsed,
187
+ )
188
+
189
+ log_scalars(
190
+ self.tb_writer,
191
+ "losses",
192
+ {
193
+ "loss": loss.item(),
194
+ "pi_loss": pi_loss.item(),
195
+ "v_loss": value_loss.item(),
196
+ "entropy_loss": entropy_loss.item(),
197
+ "explained_var": explained_var,
198
+ },
199
+ timesteps_elapsed,
200
+ )
201
+
202
+ if callback:
203
+ if not callback.on_step(timesteps_elapsed=rollout_steps):
204
+ logging.info(
205
+ f"Callback terminated training at {timesteps_elapsed} timesteps"
206
+ )
207
+ break
208
+
209
+ return self
rl_algo_impls/a2c/optimize.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import optuna
2
+
3
+ from copy import deepcopy
4
+
5
+ from rl_algo_impls.runner.config import Config, Hyperparams, EnvHyperparams
6
+ from rl_algo_impls.runner.env import make_eval_env
7
+ from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
8
+ from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
9
+
10
+
11
+ def sample_params(
12
+ trial: optuna.Trial,
13
+ base_hyperparams: Hyperparams,
14
+ base_config: Config,
15
+ ) -> Hyperparams:
16
+ hyperparams = deepcopy(base_hyperparams)
17
+
18
+ base_env_hyperparams = EnvHyperparams(**hyperparams.env_hyperparams)
19
+ env = make_eval_env(base_config, base_env_hyperparams, override_n_envs=1)
20
+
21
+ # env_hyperparams
22
+ env_hyperparams = sample_env_hyperparams(trial, hyperparams.env_hyperparams, env)
23
+
24
+ # policy_hyperparams
25
+ policy_hyperparams = sample_on_policy_hyperparams(
26
+ trial, hyperparams.policy_hyperparams, env
27
+ )
28
+
29
+ # algo_hyperparams
30
+ algo_hyperparams = hyperparams.algo_hyperparams
31
+
32
+ learning_rate = trial.suggest_float("learning_rate", 1e-5, 2e-3, log=True)
33
+ learning_rate_decay = trial.suggest_categorical(
34
+ "learning_rate_decay", ["none", "linear"]
35
+ )
36
+ n_steps_exp = trial.suggest_int("n_steps_exp", 1, 10)
37
+ n_steps = 2**n_steps_exp
38
+ trial.set_user_attr("n_steps", n_steps)
39
+ gamma = 1.0 - trial.suggest_float("gamma_om", 1e-4, 1e-1, log=True)
40
+ trial.set_user_attr("gamma", gamma)
41
+ gae_lambda = 1 - trial.suggest_float("gae_lambda_om", 1e-4, 1e-1)
42
+ trial.set_user_attr("gae_lambda", gae_lambda)
43
+ ent_coef = trial.suggest_float("ent_coef", 1e-8, 2.5e-2, log=True)
44
+ ent_coef_decay = trial.suggest_categorical("ent_coef_decay", ["none", "linear"])
45
+ vf_coef = trial.suggest_float("vf_coef", 0.1, 0.7)
46
+ max_grad_norm = trial.suggest_float("max_grad_norm", 1e-1, 1e1, log=True)
47
+ use_rms_prop = trial.suggest_categorical("use_rms_prop", [True, False])
48
+ normalize_advantage = trial.suggest_categorical(
49
+ "normalize_advantage", [True, False]
50
+ )
51
+
52
+ algo_hyperparams.update(
53
+ {
54
+ "learning_rate": learning_rate,
55
+ "learning_rate_decay": learning_rate_decay,
56
+ "n_steps": n_steps,
57
+ "gamma": gamma,
58
+ "gae_lambda": gae_lambda,
59
+ "ent_coef": ent_coef,
60
+ "ent_coef_decay": ent_coef_decay,
61
+ "vf_coef": vf_coef,
62
+ "max_grad_norm": max_grad_norm,
63
+ "use_rms_prop": use_rms_prop,
64
+ "normalize_advantage": normalize_advantage,
65
+ }
66
+ )
67
+
68
+ if policy_hyperparams.get("use_sde", False):
69
+ sde_sample_freq = 2 ** trial.suggest_int("sde_sample_freq_exp", 0, n_steps_exp)
70
+ trial.set_user_attr("sde_sample_freq", sde_sample_freq)
71
+ algo_hyperparams["sde_sample_freq"] = sde_sample_freq
72
+ elif "sde_sample_freq" in algo_hyperparams:
73
+ del algo_hyperparams["sde_sample_freq"]
74
+
75
+ env.close()
76
+
77
+ return hyperparams
rl_algo_impls/benchmark_publish.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import subprocess
3
+ import wandb
4
+ import wandb.apis.public
5
+
6
+ from collections import defaultdict
7
+ from multiprocessing.pool import ThreadPool
8
+ from typing import List, NamedTuple
9
+
10
+
11
+ class RunGroup(NamedTuple):
12
+ algo: str
13
+ env_id: str
14
+
15
+
16
+ def benchmark_publish() -> None:
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument(
19
+ "--wandb-project-name",
20
+ type=str,
21
+ default="rl-algo-impls-benchmarks",
22
+ help="WandB project name to load runs from",
23
+ )
24
+ parser.add_argument(
25
+ "--wandb-entity",
26
+ type=str,
27
+ default=None,
28
+ help="WandB team of project. None uses default entity",
29
+ )
30
+ parser.add_argument("--wandb-tags", type=str, nargs="+", help="WandB tags")
31
+ parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
32
+ parser.add_argument(
33
+ "--envs", type=str, nargs="*", help="Optional filter down to these envs"
34
+ )
35
+ parser.add_argument(
36
+ "--exclude-envs",
37
+ type=str,
38
+ nargs="*",
39
+ help="Environments to exclude from publishing",
40
+ )
41
+ parser.add_argument(
42
+ "--huggingface-user",
43
+ type=str,
44
+ default=None,
45
+ help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
46
+ )
47
+ parser.add_argument(
48
+ "--pool-size",
49
+ type=int,
50
+ default=3,
51
+ help="How many publish jobs can run in parallel",
52
+ )
53
+ parser.add_argument(
54
+ "--virtual-display", action="store_true", help="Use headless virtual display"
55
+ )
56
+ # parser.set_defaults(
57
+ # wandb_tags=["benchmark_e47a44c", "host_129-146-2-230"],
58
+ # wandb_report_url="https://api.wandb.ai/links/sgoodfriend/v4wd7cp5",
59
+ # envs=[],
60
+ # exclude_envs=[],
61
+ # )
62
+ args = parser.parse_args()
63
+ print(args)
64
+
65
+ api = wandb.Api()
66
+ all_runs = api.runs(
67
+ f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
68
+ )
69
+
70
+ required_tags = set(args.wandb_tags)
71
+ runs: List[wandb.apis.public.Run] = [
72
+ r
73
+ for r in all_runs
74
+ if required_tags.issubset(set(r.config.get("wandb_tags", [])))
75
+ ]
76
+
77
+ runs_paths_by_group = defaultdict(list)
78
+ for r in runs:
79
+ if r.state != "finished":
80
+ continue
81
+ algo = r.config["algo"]
82
+ env = r.config["env"]
83
+ if args.envs and env not in args.envs:
84
+ continue
85
+ if args.exclude_envs and env in args.exclude_envs:
86
+ continue
87
+ run_group = RunGroup(algo, env)
88
+ runs_paths_by_group[run_group].append("/".join(r.path))
89
+
90
+ def run(run_paths: List[str]) -> None:
91
+ publish_args = ["python", "huggingface_publish.py"]
92
+ publish_args.append("--wandb-run-paths")
93
+ publish_args.extend(run_paths)
94
+ publish_args.append("--wandb-report-url")
95
+ publish_args.append(args.wandb_report_url)
96
+ if args.huggingface_user:
97
+ publish_args.append("--huggingface-user")
98
+ publish_args.append(args.huggingface_user)
99
+ if args.virtual_display:
100
+ publish_args.append("--virtual-display")
101
+ subprocess.run(publish_args)
102
+
103
+ tp = ThreadPool(args.pool_size)
104
+ for run_paths in runs_paths_by_group.values():
105
+ tp.apply_async(run, (run_paths,))
106
+ tp.close()
107
+ tp.join()
108
+
109
+
110
+ if __name__ == "__main__":
111
+ benchmark_publish()
rl_algo_impls/compare_runs.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import itertools
3
+ import numpy as np
4
+ import pandas as pd
5
+ import wandb
6
+ import wandb.apis.public
7
+
8
+ from collections import defaultdict
9
+ from dataclasses import dataclass
10
+ from typing import Dict, Iterable, List, TypeVar
11
+
12
+ from rl_algo_impls.benchmark_publish import RunGroup
13
+
14
+
15
+ @dataclass
16
+ class Comparison:
17
+ control_values: List[float]
18
+ experiment_values: List[float]
19
+
20
+ def mean_diff_percentage(self) -> float:
21
+ return self._diff_percentage(
22
+ np.mean(self.control_values).item(), np.mean(self.experiment_values).item()
23
+ )
24
+
25
+ def median_diff_percentage(self) -> float:
26
+ return self._diff_percentage(
27
+ np.median(self.control_values).item(),
28
+ np.median(self.experiment_values).item(),
29
+ )
30
+
31
+ def _diff_percentage(self, c: float, e: float) -> float:
32
+ if c == e:
33
+ return 0
34
+ elif c == 0:
35
+ return float("inf") if e > 0 else float("-inf")
36
+ return 100 * (e - c) / c
37
+
38
+ def score(self) -> float:
39
+ return (
40
+ np.sum(
41
+ np.sign((self.mean_diff_percentage(), self.median_diff_percentage()))
42
+ ).item()
43
+ / 2
44
+ )
45
+
46
+
47
+ RunGroupRunsSelf = TypeVar("RunGroupRunsSelf", bound="RunGroupRuns")
48
+
49
+
50
+ class RunGroupRuns:
51
+ def __init__(
52
+ self,
53
+ run_group: RunGroup,
54
+ control: List[str],
55
+ experiment: List[str],
56
+ summary_stats: List[str] = ["best_eval", "eval", "train_rolling"],
57
+ summary_metrics: List[str] = ["mean", "result"],
58
+ ) -> None:
59
+ self.algo = run_group.algo
60
+ self.env = run_group.env_id
61
+ self.control = set(control)
62
+ self.experiment = set(experiment)
63
+
64
+ self.summary_stats = summary_stats
65
+ self.summary_metrics = summary_metrics
66
+
67
+ self.control_runs = []
68
+ self.experiment_runs = []
69
+
70
+ def add_run(self, run: wandb.apis.public.Run) -> None:
71
+ wandb_tags = set(run.config.get("wandb_tags", []))
72
+ if self.control & wandb_tags:
73
+ self.control_runs.append(run)
74
+ elif self.experiment & wandb_tags:
75
+ self.experiment_runs.append(run)
76
+
77
+ def comparisons_by_metric(self) -> Dict[str, Comparison]:
78
+ c_by_m = {}
79
+ for metric in (
80
+ f"{s}/{m}"
81
+ for s, m in itertools.product(self.summary_stats, self.summary_metrics)
82
+ ):
83
+ c_by_m[metric] = Comparison(
84
+ [c.summary[metric] for c in self.control_runs],
85
+ [e.summary[metric] for e in self.experiment_runs],
86
+ )
87
+ return c_by_m
88
+
89
+ @staticmethod
90
+ def data_frame(rows: Iterable[RunGroupRunsSelf]) -> pd.DataFrame:
91
+ results = defaultdict(list)
92
+ for r in rows:
93
+ if not r.control_runs or not r.experiment_runs:
94
+ continue
95
+ results["algo"].append(r.algo)
96
+ results["env"].append(r.env)
97
+ results["control"].append(r.control)
98
+ results["expierment"].append(r.experiment)
99
+ c_by_m = r.comparisons_by_metric()
100
+ results["score"].append(
101
+ sum(m.score() for m in c_by_m.values()) / len(c_by_m)
102
+ )
103
+ for m, c in c_by_m.items():
104
+ results[f"{m}_mean"].append(c.mean_diff_percentage())
105
+ results[f"{m}_median"].append(c.median_diff_percentage())
106
+ return pd.DataFrame(results)
107
+
108
+
109
+ def compare_runs() -> None:
110
+ parser = argparse.ArgumentParser()
111
+ parser.add_argument(
112
+ "-p",
113
+ "--wandb-project-name",
114
+ type=str,
115
+ default="rl-algo-impls-benchmarks",
116
+ help="WandB project name to load runs from",
117
+ )
118
+ parser.add_argument(
119
+ "--wandb-entity",
120
+ type=str,
121
+ default=None,
122
+ help="WandB team. None uses default entity",
123
+ )
124
+ parser.add_argument(
125
+ "-n",
126
+ "--wandb-hostname-tag",
127
+ type=str,
128
+ nargs="*",
129
+ help="WandB tags for hostname (i.e. host_192-9-145-26)",
130
+ )
131
+ parser.add_argument(
132
+ "-c",
133
+ "--wandb-control-tag",
134
+ type=str,
135
+ nargs="+",
136
+ help="WandB tag for control commit (i.e. benchmark_5598ebc)",
137
+ )
138
+ parser.add_argument(
139
+ "-e",
140
+ "--wandb-experiment-tag",
141
+ type=str,
142
+ nargs="+",
143
+ help="WandB tag for experiment commit (i.e. benchmark_5540e1f)",
144
+ )
145
+ parser.add_argument(
146
+ "--envs",
147
+ type=str,
148
+ nargs="*",
149
+ help="If specified, only compare these envs",
150
+ )
151
+ parser.add_argument(
152
+ "--exclude-envs",
153
+ type=str,
154
+ nargs="*",
155
+ help="Environments to exclude from comparison",
156
+ )
157
+ # parser.set_defaults(
158
+ # wandb_hostname_tag=["host_150-230-44-105", "host_155-248-214-128"],
159
+ # wandb_control_tag=["benchmark_fbc943f"],
160
+ # wandb_experiment_tag=["benchmark_f59bf74"],
161
+ # exclude_envs=[],
162
+ # )
163
+ args = parser.parse_args()
164
+ print(args)
165
+
166
+ api = wandb.Api()
167
+ all_runs = api.runs(
168
+ path=f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}",
169
+ order="+created_at",
170
+ )
171
+
172
+ runs_by_run_group: Dict[RunGroup, RunGroupRuns] = {}
173
+ wandb_hostname_tags = set(args.wandb_hostname_tag)
174
+ for r in all_runs:
175
+ if r.state != "finished":
176
+ continue
177
+ wandb_tags = set(r.config.get("wandb_tags", []))
178
+ if not wandb_tags or not wandb_hostname_tags & wandb_tags:
179
+ continue
180
+ rg = RunGroup(r.config["algo"], r.config.get("env_id") or r.config["env"])
181
+ if args.exclude_envs and rg.env_id in args.exclude_envs:
182
+ continue
183
+ if args.envs and rg.env_id not in args.envs:
184
+ continue
185
+ if rg not in runs_by_run_group:
186
+ runs_by_run_group[rg] = RunGroupRuns(
187
+ rg,
188
+ args.wandb_control_tag,
189
+ args.wandb_experiment_tag,
190
+ )
191
+ runs_by_run_group[rg].add_run(r)
192
+ df = RunGroupRuns.data_frame(runs_by_run_group.values()).round(decimals=2)
193
+ print(f"**Total Score: {sum(df.score)}**")
194
+ df.loc["mean"] = df.mean(numeric_only=True)
195
+ print(df.to_markdown())
196
+
197
+ if __name__ == "__main__":
198
+ compare_runs()
rl_algo_impls/dqn/dqn.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from collections import deque
9
+ from torch.optim import Adam
10
+ from torch.utils.tensorboard.writer import SummaryWriter
11
+ from typing import NamedTuple, Optional, TypeVar
12
+
13
+ from rl_algo_impls.dqn.policy import DQNPolicy
14
+ from rl_algo_impls.shared.algorithm import Algorithm
15
+ from rl_algo_impls.shared.callbacks.callback import Callback
16
+ from rl_algo_impls.shared.schedule import linear_schedule
17
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs
18
+
19
+
20
+ class Transition(NamedTuple):
21
+ obs: np.ndarray
22
+ action: np.ndarray
23
+ reward: float
24
+ done: bool
25
+ next_obs: np.ndarray
26
+
27
+
28
+ class Batch(NamedTuple):
29
+ obs: np.ndarray
30
+ actions: np.ndarray
31
+ rewards: np.ndarray
32
+ dones: np.ndarray
33
+ next_obs: np.ndarray
34
+
35
+
36
+ class ReplayBuffer:
37
+ def __init__(self, num_envs: int, maxlen: int) -> None:
38
+ self.num_envs = num_envs
39
+ self.buffer = deque(maxlen=maxlen)
40
+
41
+ def add(
42
+ self,
43
+ obs: VecEnvObs,
44
+ action: np.ndarray,
45
+ reward: np.ndarray,
46
+ done: np.ndarray,
47
+ next_obs: VecEnvObs,
48
+ ) -> None:
49
+ assert isinstance(obs, np.ndarray)
50
+ assert isinstance(next_obs, np.ndarray)
51
+ for i in range(self.num_envs):
52
+ self.buffer.append(
53
+ Transition(obs[i], action[i], reward[i], done[i], next_obs[i])
54
+ )
55
+
56
+ def sample(self, batch_size: int) -> Batch:
57
+ ts = random.sample(self.buffer, batch_size)
58
+ return Batch(
59
+ obs=np.array([t.obs for t in ts]),
60
+ actions=np.array([t.action for t in ts]),
61
+ rewards=np.array([t.reward for t in ts]),
62
+ dones=np.array([t.done for t in ts]),
63
+ next_obs=np.array([t.next_obs for t in ts]),
64
+ )
65
+
66
+ def __len__(self) -> int:
67
+ return len(self.buffer)
68
+
69
+
70
+ DQNSelf = TypeVar("DQNSelf", bound="DQN")
71
+
72
+
73
+ class DQN(Algorithm):
74
+ def __init__(
75
+ self,
76
+ policy: DQNPolicy,
77
+ env: VecEnv,
78
+ device: torch.device,
79
+ tb_writer: SummaryWriter,
80
+ learning_rate: float = 1e-4,
81
+ buffer_size: int = 1_000_000,
82
+ learning_starts: int = 50_000,
83
+ batch_size: int = 32,
84
+ tau: float = 1.0,
85
+ gamma: float = 0.99,
86
+ train_freq: int = 4,
87
+ gradient_steps: int = 1,
88
+ target_update_interval: int = 10_000,
89
+ exploration_fraction: float = 0.1,
90
+ exploration_initial_eps: float = 1.0,
91
+ exploration_final_eps: float = 0.05,
92
+ max_grad_norm: float = 10.0,
93
+ ) -> None:
94
+ super().__init__(policy, env, device, tb_writer)
95
+ self.policy = policy
96
+
97
+ self.optimizer = Adam(self.policy.q_net.parameters(), lr=learning_rate)
98
+
99
+ self.target_q_net = copy.deepcopy(self.policy.q_net).to(self.device)
100
+ self.target_q_net.train(False)
101
+ self.tau = tau
102
+ self.target_update_interval = target_update_interval
103
+
104
+ self.replay_buffer = ReplayBuffer(self.env.num_envs, buffer_size)
105
+ self.batch_size = batch_size
106
+
107
+ self.learning_starts = learning_starts
108
+ self.train_freq = train_freq
109
+ self.gradient_steps = gradient_steps
110
+
111
+ self.gamma = gamma
112
+ self.exploration_eps_schedule = linear_schedule(
113
+ exploration_initial_eps,
114
+ exploration_final_eps,
115
+ end_fraction=exploration_fraction,
116
+ )
117
+
118
+ self.max_grad_norm = max_grad_norm
119
+
120
+ def learn(
121
+ self: DQNSelf, total_timesteps: int, callback: Optional[Callback] = None
122
+ ) -> DQNSelf:
123
+ self.policy.train(True)
124
+ obs = self.env.reset()
125
+ obs = self._collect_rollout(self.learning_starts, obs, 1)
126
+ learning_steps = total_timesteps - self.learning_starts
127
+ timesteps_elapsed = 0
128
+ steps_since_target_update = 0
129
+ while timesteps_elapsed < learning_steps:
130
+ progress = timesteps_elapsed / learning_steps
131
+ eps = self.exploration_eps_schedule(progress)
132
+ obs = self._collect_rollout(self.train_freq, obs, eps)
133
+ rollout_steps = self.train_freq
134
+ timesteps_elapsed += rollout_steps
135
+ for _ in range(
136
+ self.gradient_steps if self.gradient_steps > 0 else self.train_freq
137
+ ):
138
+ self.train()
139
+ steps_since_target_update += rollout_steps
140
+ if steps_since_target_update >= self.target_update_interval:
141
+ self._update_target()
142
+ steps_since_target_update = 0
143
+ if callback:
144
+ callback.on_step(timesteps_elapsed=rollout_steps)
145
+ return self
146
+
147
+ def train(self) -> None:
148
+ if len(self.replay_buffer) < self.batch_size:
149
+ return
150
+ o, a, r, d, next_o = self.replay_buffer.sample(self.batch_size)
151
+ o = torch.as_tensor(o, device=self.device)
152
+ a = torch.as_tensor(a, device=self.device).unsqueeze(1)
153
+ r = torch.as_tensor(r, dtype=torch.float32, device=self.device)
154
+ d = torch.as_tensor(d, dtype=torch.long, device=self.device)
155
+ next_o = torch.as_tensor(next_o, device=self.device)
156
+
157
+ with torch.no_grad():
158
+ target = r + (1 - d) * self.gamma * self.target_q_net(next_o).max(1).values
159
+ current = self.policy.q_net(o).gather(dim=1, index=a).squeeze(1)
160
+ loss = F.smooth_l1_loss(current, target)
161
+
162
+ self.optimizer.zero_grad()
163
+ loss.backward()
164
+ if self.max_grad_norm:
165
+ nn.utils.clip_grad_norm_(self.policy.q_net.parameters(), self.max_grad_norm)
166
+ self.optimizer.step()
167
+
168
+ def _collect_rollout(self, timesteps: int, obs: VecEnvObs, eps: float) -> VecEnvObs:
169
+ for _ in range(0, timesteps, self.env.num_envs):
170
+ action = self.policy.act(obs, eps, deterministic=False)
171
+ next_obs, reward, done, _ = self.env.step(action)
172
+ self.replay_buffer.add(obs, action, reward, done, next_obs)
173
+ obs = next_obs
174
+ return obs
175
+
176
+ def _update_target(self) -> None:
177
+ for target_param, param in zip(
178
+ self.target_q_net.parameters(), self.policy.q_net.parameters()
179
+ ):
180
+ target_param.data.copy_(
181
+ self.tau * param.data + (1 - self.tau) * target_param.data
182
+ )
rl_algo_impls/dqn/policy.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import torch
4
+
5
+ from typing import Optional, Sequence, TypeVar
6
+
7
+ from rl_algo_impls.dqn.q_net import QNetwork
8
+ from rl_algo_impls.shared.policy.policy import Policy
9
+ from rl_algo_impls.wrappers.vectorable_wrapper import (
10
+ VecEnv,
11
+ VecEnvObs,
12
+ single_observation_space,
13
+ single_action_space,
14
+ )
15
+
16
+ DQNPolicySelf = TypeVar("DQNPolicySelf", bound="DQNPolicy")
17
+
18
+
19
+ class DQNPolicy(Policy):
20
+ def __init__(
21
+ self,
22
+ env: VecEnv,
23
+ hidden_sizes: Sequence[int] = [],
24
+ cnn_feature_dim: int = 512,
25
+ cnn_style: str = "nature",
26
+ cnn_layers_init_orthogonal: Optional[bool] = None,
27
+ impala_channels: Sequence[int] = (16, 32, 32),
28
+ **kwargs,
29
+ ) -> None:
30
+ super().__init__(env, **kwargs)
31
+ self.q_net = QNetwork(
32
+ single_observation_space(env),
33
+ single_action_space(env),
34
+ hidden_sizes,
35
+ cnn_feature_dim=cnn_feature_dim,
36
+ cnn_style=cnn_style,
37
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
38
+ impala_channels=impala_channels,
39
+ )
40
+
41
+ def act(
42
+ self, obs: VecEnvObs, eps: float = 0, deterministic: bool = True
43
+ ) -> np.ndarray:
44
+ assert eps == 0 if deterministic else eps >= 0
45
+ if not deterministic and np.random.random() < eps:
46
+ return np.array(
47
+ [
48
+ single_action_space(self.env).sample()
49
+ for _ in range(self.env.num_envs)
50
+ ]
51
+ )
52
+ else:
53
+ o = self._as_tensor(obs)
54
+ with torch.no_grad():
55
+ return self.q_net(o).argmax(axis=1).cpu().numpy()
rl_algo_impls/dqn/q_net.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import torch as th
3
+ import torch.nn as nn
4
+
5
+ from gym.spaces import Discrete
6
+ from typing import Optional, Sequence, Type
7
+
8
+ from rl_algo_impls.shared.module.feature_extractor import FeatureExtractor
9
+ from rl_algo_impls.shared.module.module import mlp
10
+
11
+
12
+ class QNetwork(nn.Module):
13
+ def __init__(
14
+ self,
15
+ observation_space: gym.Space,
16
+ action_space: gym.Space,
17
+ hidden_sizes: Sequence[int] = [],
18
+ activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
19
+ cnn_feature_dim: int = 512,
20
+ cnn_style: str = "nature",
21
+ cnn_layers_init_orthogonal: Optional[bool] = None,
22
+ impala_channels: Sequence[int] = (16, 32, 32),
23
+ ) -> None:
24
+ super().__init__()
25
+ assert isinstance(action_space, Discrete)
26
+ self._feature_extractor = FeatureExtractor(
27
+ observation_space,
28
+ activation,
29
+ cnn_feature_dim=cnn_feature_dim,
30
+ cnn_style=cnn_style,
31
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
32
+ impala_channels=impala_channels,
33
+ )
34
+ layer_sizes = (
35
+ (self._feature_extractor.out_dim,) + tuple(hidden_sizes) + (action_space.n,)
36
+ )
37
+ self._fc = mlp(layer_sizes, activation)
38
+
39
+ def forward(self, obs: th.Tensor) -> th.Tensor:
40
+ x = self._feature_extractor(obs)
41
+ return self._fc(x)
rl_algo_impls/enjoy.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
2
+ import os
3
+
4
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
+
6
+ from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
7
+ from rl_algo_impls.runner.running_utils import base_parser
8
+
9
+
10
+ def enjoy() -> None:
11
+ parser = base_parser(multiple=False)
12
+ parser.add_argument("--render", default=True, type=bool)
13
+ parser.add_argument("--best", default=True, type=bool)
14
+ parser.add_argument("--n_envs", default=1, type=int)
15
+ parser.add_argument("--n_episodes", default=3, type=int)
16
+ parser.add_argument("--deterministic-eval", default=None, type=bool)
17
+ parser.add_argument(
18
+ "--no-print-returns", action="store_true", help="Limit printing"
19
+ )
20
+ # wandb-run-path overrides base RunArgs
21
+ parser.add_argument("--wandb-run-path", default=None, type=str)
22
+ parser.set_defaults(
23
+ algo=["ppo"],
24
+ wandb_run_path="sgoodfriend/rl-algo-impls/m5c1t7g5",
25
+ )
26
+ args = parser.parse_args()
27
+ args.algo = args.algo[0]
28
+ args.env = args.env[0]
29
+ args = EvalArgs(**vars(args))
30
+
31
+ evaluate_model(args, os.getcwd())
32
+
33
+
34
+ if __name__ == "__main__":
35
+ enjoy()
rl_algo_impls/huggingface_publish.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
4
+
5
+ import argparse
6
+ import requests
7
+ import shutil
8
+ import subprocess
9
+ import tempfile
10
+ import wandb
11
+ import wandb.apis.public
12
+
13
+ from typing import List, Optional
14
+
15
+ from huggingface_hub.hf_api import HfApi, upload_folder
16
+ from huggingface_hub.repocard import metadata_save
17
+ from pyvirtualdisplay.display import Display
18
+
19
+ from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
20
+ from rl_algo_impls.runner.config import EnvHyperparams
21
+ from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
22
+ from rl_algo_impls.runner.env import make_eval_env
23
+ from rl_algo_impls.shared.callbacks.eval_callback import evaluate
24
+ from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
25
+
26
+
27
+ def publish(
28
+ wandb_run_paths: List[str],
29
+ wandb_report_url: str,
30
+ huggingface_user: Optional[str] = None,
31
+ huggingface_token: Optional[str] = None,
32
+ virtual_display: bool = False,
33
+ ) -> None:
34
+ if virtual_display:
35
+ display = Display(visible=False, size=(1400, 900))
36
+ display.start()
37
+
38
+ api = wandb.Api()
39
+ runs = [api.run(rp) for rp in wandb_run_paths]
40
+ algo = runs[0].config["algo"]
41
+ hyperparam_id = runs[0].config["env"]
42
+ evaluations = [
43
+ evaluate_model(
44
+ EvalArgs(
45
+ algo,
46
+ hyperparam_id,
47
+ seed=r.config.get("seed", None),
48
+ render=False,
49
+ best=True,
50
+ n_envs=None,
51
+ n_episodes=10,
52
+ no_print_returns=True,
53
+ wandb_run_path="/".join(r.path),
54
+ ),
55
+ os.getcwd(),
56
+ )
57
+ for r in runs
58
+ ]
59
+ run_metadata = requests.get(runs[0].file("wandb-metadata.json").url).json()
60
+ table_data = list(EvalTableData(r, e) for r, e in zip(runs, evaluations))
61
+ best_eval = sorted(
62
+ table_data, key=lambda d: d.evaluation.stats.score, reverse=True
63
+ )[0]
64
+
65
+ with tempfile.TemporaryDirectory() as tmpdirname:
66
+ _, (policy, stats, config) = best_eval
67
+
68
+ repo_name = config.model_name(include_seed=False)
69
+ repo_dir_path = os.path.join(tmpdirname, repo_name)
70
+ # Locally clone this repo to a temp directory
71
+ subprocess.run(["git", "clone", ".", repo_dir_path])
72
+ shutil.rmtree(os.path.join(repo_dir_path, ".git"))
73
+ model_path = config.model_dir_path(best=True, downloaded=True)
74
+ shutil.copytree(
75
+ model_path,
76
+ os.path.join(
77
+ repo_dir_path, "saved_models", config.model_dir_name(best=True)
78
+ ),
79
+ )
80
+
81
+ github_url = "https://github.com/sgoodfriend/rl-algo-impls"
82
+ commit_hash = run_metadata.get("git", {}).get("commit", None)
83
+ env_id = runs[0].config.get("env_id") or runs[0].config["env"]
84
+ card_text = model_card_text(
85
+ algo,
86
+ env_id,
87
+ github_url,
88
+ commit_hash,
89
+ wandb_report_url,
90
+ table_data,
91
+ best_eval,
92
+ )
93
+ readme_filepath = os.path.join(repo_dir_path, "README.md")
94
+ os.remove(readme_filepath)
95
+ with open(readme_filepath, "w") as f:
96
+ f.write(card_text)
97
+
98
+ metadata = {
99
+ "library_name": "rl-algo-impls",
100
+ "tags": [
101
+ env_id,
102
+ algo,
103
+ "deep-reinforcement-learning",
104
+ "reinforcement-learning",
105
+ ],
106
+ "model-index": [
107
+ {
108
+ "name": algo,
109
+ "results": [
110
+ {
111
+ "metrics": [
112
+ {
113
+ "type": "mean_reward",
114
+ "value": str(stats.score),
115
+ "name": "mean_reward",
116
+ }
117
+ ],
118
+ "task": {
119
+ "type": "reinforcement-learning",
120
+ "name": "reinforcement-learning",
121
+ },
122
+ "dataset": {
123
+ "name": env_id,
124
+ "type": env_id,
125
+ },
126
+ }
127
+ ],
128
+ }
129
+ ],
130
+ }
131
+ metadata_save(readme_filepath, metadata)
132
+
133
+ video_env = VecEpisodeRecorder(
134
+ make_eval_env(
135
+ config,
136
+ EnvHyperparams(**config.env_hyperparams),
137
+ override_n_envs=1,
138
+ normalize_load_path=model_path,
139
+ ),
140
+ os.path.join(repo_dir_path, "replay"),
141
+ max_video_length=3600,
142
+ )
143
+ evaluate(
144
+ video_env,
145
+ policy,
146
+ 1,
147
+ deterministic=config.eval_params.get("deterministic", True),
148
+ )
149
+
150
+ api = HfApi()
151
+ huggingface_user = huggingface_user or api.whoami()["name"]
152
+ huggingface_repo = f"{huggingface_user}/{repo_name}"
153
+ api.create_repo(
154
+ token=huggingface_token,
155
+ repo_id=huggingface_repo,
156
+ private=False,
157
+ exist_ok=True,
158
+ )
159
+ repo_url = upload_folder(
160
+ repo_id=huggingface_repo,
161
+ folder_path=repo_dir_path,
162
+ path_in_repo="",
163
+ commit_message=f"{algo.upper()} playing {env_id} from {github_url}/tree/{commit_hash}",
164
+ token=huggingface_token,
165
+ )
166
+ print(f"Pushed model to the hub: {repo_url}")
167
+
168
+
169
+ def huggingface_publish():
170
+ parser = argparse.ArgumentParser()
171
+ parser.add_argument(
172
+ "--wandb-run-paths",
173
+ type=str,
174
+ nargs="+",
175
+ help="Run paths of the form entity/project/run_id",
176
+ )
177
+ parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
178
+ parser.add_argument(
179
+ "--huggingface-user",
180
+ type=str,
181
+ help="Huggingface user or team to upload model cards",
182
+ default=None,
183
+ )
184
+ parser.add_argument(
185
+ "--virtual-display", action="store_true", help="Use headless virtual display"
186
+ )
187
+ args = parser.parse_args()
188
+ print(args)
189
+ publish(**vars(args))
190
+
191
+
192
+ if __name__ == "__main__":
193
+ huggingface_publish()
rl_algo_impls/hyperparams/a2c.yml ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CartPole-v1: &cartpole-defaults
2
+ n_timesteps: !!float 5e5
3
+ env_hyperparams:
4
+ n_envs: 8
5
+
6
+ CartPole-v0:
7
+ <<: *cartpole-defaults
8
+
9
+ MountainCar-v0:
10
+ n_timesteps: !!float 1e6
11
+ env_hyperparams:
12
+ n_envs: 16
13
+ normalize: true
14
+
15
+ MountainCarContinuous-v0:
16
+ n_timesteps: !!float 1e5
17
+ env_hyperparams:
18
+ n_envs: 4
19
+ normalize: true
20
+ # policy_hyperparams:
21
+ # use_sde: true
22
+ # log_std_init: 0.0
23
+ # init_layers_orthogonal: false
24
+ algo_hyperparams:
25
+ n_steps: 100
26
+ sde_sample_freq: 16
27
+
28
+ Acrobot-v1:
29
+ n_timesteps: !!float 5e5
30
+ env_hyperparams:
31
+ normalize: true
32
+ n_envs: 16
33
+
34
+ # Tuned
35
+ LunarLander-v2:
36
+ device: cpu
37
+ n_timesteps: !!float 1e6
38
+ env_hyperparams:
39
+ n_envs: 4
40
+ normalize: true
41
+ algo_hyperparams:
42
+ n_steps: 2
43
+ gamma: 0.9955517404308908
44
+ gae_lambda: 0.9875340918797773
45
+ learning_rate: 0.0013814130817068916
46
+ learning_rate_decay: linear
47
+ ent_coef: !!float 3.388369146384422e-7
48
+ ent_coef_decay: none
49
+ max_grad_norm: 3.33982095073364
50
+ normalize_advantage: true
51
+ vf_coef: 0.1667838310548184
52
+
53
+ BipedalWalker-v3:
54
+ n_timesteps: !!float 5e6
55
+ env_hyperparams:
56
+ n_envs: 16
57
+ normalize: true
58
+ policy_hyperparams:
59
+ use_sde: true
60
+ log_std_init: -2
61
+ init_layers_orthogonal: false
62
+ algo_hyperparams:
63
+ ent_coef: 0
64
+ max_grad_norm: 0.5
65
+ n_steps: 8
66
+ gae_lambda: 0.9
67
+ vf_coef: 0.4
68
+ gamma: 0.99
69
+ learning_rate: !!float 9.6e-4
70
+ learning_rate_decay: linear
71
+
72
+ HalfCheetahBulletEnv-v0: &pybullet-defaults
73
+ n_timesteps: !!float 2e6
74
+ env_hyperparams:
75
+ n_envs: 4
76
+ normalize: true
77
+ policy_hyperparams:
78
+ use_sde: true
79
+ log_std_init: -2
80
+ init_layers_orthogonal: false
81
+ algo_hyperparams: &pybullet-algo-defaults
82
+ n_steps: 8
83
+ ent_coef: 0
84
+ max_grad_norm: 0.5
85
+ gae_lambda: 0.9
86
+ gamma: 0.99
87
+ vf_coef: 0.4
88
+ learning_rate: !!float 9.6e-4
89
+ learning_rate_decay: linear
90
+
91
+ AntBulletEnv-v0:
92
+ <<: *pybullet-defaults
93
+
94
+ Walker2DBulletEnv-v0:
95
+ <<: *pybullet-defaults
96
+
97
+ HopperBulletEnv-v0:
98
+ <<: *pybullet-defaults
99
+
100
+ CarRacing-v0:
101
+ n_timesteps: !!float 4e6
102
+ env_hyperparams:
103
+ n_envs: 8
104
+ frame_stack: 4
105
+ normalize: true
106
+ normalize_kwargs:
107
+ norm_obs: false
108
+ norm_reward: true
109
+ policy_hyperparams:
110
+ use_sde: true
111
+ log_std_init: -2
112
+ init_layers_orthogonal: false
113
+ activation_fn: relu
114
+ share_features_extractor: false
115
+ cnn_feature_dim: 256
116
+ hidden_sizes: [256]
117
+ algo_hyperparams:
118
+ n_steps: 512
119
+ learning_rate: !!float 1.62e-5
120
+ gamma: 0.997
121
+ gae_lambda: 0.975
122
+ ent_coef: 0
123
+ sde_sample_freq: 128
124
+ vf_coef: 0.64
125
+
126
+ _atari: &atari-defaults
127
+ n_timesteps: !!float 1e7
128
+ env_hyperparams: &atari-env-defaults
129
+ n_envs: 16
130
+ frame_stack: 4
131
+ no_reward_timeout_steps: 1000
132
+ no_reward_fire_steps: 500
133
+ vec_env_class: async
134
+ policy_hyperparams: &atari-policy-defaults
135
+ activation_fn: relu
136
+ algo_hyperparams:
137
+ ent_coef: 0.01
138
+ vf_coef: 0.25
rl_algo_impls/hyperparams/dqn.yml ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CartPole-v1: &cartpole-defaults
2
+ n_timesteps: !!float 5e4
3
+ env_hyperparams:
4
+ rolling_length: 50
5
+ policy_hyperparams:
6
+ hidden_sizes: [256, 256]
7
+ algo_hyperparams:
8
+ learning_rate: !!float 2.3e-3
9
+ batch_size: 64
10
+ buffer_size: 100000
11
+ learning_starts: 1000
12
+ gamma: 0.99
13
+ target_update_interval: 10
14
+ train_freq: 256
15
+ gradient_steps: 128
16
+ exploration_fraction: 0.16
17
+ exploration_final_eps: 0.04
18
+ eval_params:
19
+ step_freq: !!float 1e4
20
+
21
+ CartPole-v0:
22
+ <<: *cartpole-defaults
23
+ n_timesteps: !!float 4e4
24
+
25
+ MountainCar-v0:
26
+ n_timesteps: !!float 1.2e5
27
+ env_hyperparams:
28
+ rolling_length: 50
29
+ policy_hyperparams:
30
+ hidden_sizes: [256, 256]
31
+ algo_hyperparams:
32
+ learning_rate: !!float 4e-3
33
+ batch_size: 128
34
+ buffer_size: 10000
35
+ learning_starts: 1000
36
+ gamma: 0.98
37
+ target_update_interval: 600
38
+ train_freq: 16
39
+ gradient_steps: 8
40
+ exploration_fraction: 0.2
41
+ exploration_final_eps: 0.07
42
+
43
+ Acrobot-v1:
44
+ n_timesteps: !!float 1e5
45
+ env_hyperparams:
46
+ rolling_length: 50
47
+ policy_hyperparams:
48
+ hidden_sizes: [256, 256]
49
+ algo_hyperparams:
50
+ learning_rate: !!float 6.3e-4
51
+ batch_size: 128
52
+ buffer_size: 50000
53
+ learning_starts: 0
54
+ gamma: 0.99
55
+ target_update_interval: 250
56
+ train_freq: 4
57
+ gradient_steps: -1
58
+ exploration_fraction: 0.12
59
+ exploration_final_eps: 0.1
60
+
61
+ LunarLander-v2:
62
+ n_timesteps: !!float 5e5
63
+ env_hyperparams:
64
+ rolling_length: 50
65
+ policy_hyperparams:
66
+ hidden_sizes: [256, 256]
67
+ algo_hyperparams:
68
+ learning_rate: !!float 1e-4
69
+ batch_size: 256
70
+ buffer_size: 100000
71
+ learning_starts: 10000
72
+ gamma: 0.99
73
+ target_update_interval: 250
74
+ train_freq: 8
75
+ gradient_steps: -1
76
+ exploration_fraction: 0.12
77
+ exploration_final_eps: 0.1
78
+ max_grad_norm: 0.5
79
+ eval_params:
80
+ step_freq: 25_000
81
+
82
+ _atari: &atari-defaults
83
+ n_timesteps: !!float 1e7
84
+ env_hyperparams:
85
+ frame_stack: 4
86
+ no_reward_timeout_steps: 1_000
87
+ no_reward_fire_steps: 500
88
+ n_envs: 8
89
+ vec_env_class: async
90
+ algo_hyperparams:
91
+ buffer_size: 100000
92
+ learning_rate: !!float 1e-4
93
+ batch_size: 32
94
+ learning_starts: 100000
95
+ target_update_interval: 1000
96
+ train_freq: 8
97
+ gradient_steps: 2
98
+ exploration_fraction: 0.1
99
+ exploration_final_eps: 0.01
100
+ eval_params:
101
+ deterministic: false
102
+
103
+ PongNoFrameskip-v4:
104
+ <<: *atari-defaults
105
+ n_timesteps: !!float 2.5e6
106
+
107
+ _impala-atari: &impala-atari-defaults
108
+ <<: *atari-defaults
109
+ policy_hyperparams:
110
+ cnn_style: impala
111
+ cnn_feature_dim: 256
112
+ init_layers_orthogonal: true
113
+ cnn_layers_init_orthogonal: false
114
+
115
+ impala-PongNoFrameskip-v4:
116
+ <<: *impala-atari-defaults
117
+ env_id: PongNoFrameskip-v4
118
+ n_timesteps: !!float 2.5e6
119
+
120
+ impala-BreakoutNoFrameskip-v4:
121
+ <<: *impala-atari-defaults
122
+ env_id: BreakoutNoFrameskip-v4
123
+
124
+ impala-SpaceInvadersNoFrameskip-v4:
125
+ <<: *impala-atari-defaults
126
+ env_id: SpaceInvadersNoFrameskip-v4
127
+
128
+ impala-QbertNoFrameskip-v4:
129
+ <<: *impala-atari-defaults
130
+ env_id: QbertNoFrameskip-v4
rl_algo_impls/hyperparams/ppo.yml ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CartPole-v1: &cartpole-defaults
2
+ n_timesteps: !!float 1e5
3
+ env_hyperparams:
4
+ n_envs: 8
5
+ algo_hyperparams:
6
+ n_steps: 32
7
+ batch_size: 256
8
+ n_epochs: 20
9
+ gae_lambda: 0.8
10
+ gamma: 0.98
11
+ ent_coef: 0.0
12
+ learning_rate: 0.001
13
+ learning_rate_decay: linear
14
+ clip_range: 0.2
15
+ clip_range_decay: linear
16
+ eval_params:
17
+ step_freq: !!float 2.5e4
18
+
19
+ CartPole-v0:
20
+ <<: *cartpole-defaults
21
+ n_timesteps: !!float 5e4
22
+
23
+ MountainCar-v0:
24
+ n_timesteps: !!float 1e6
25
+ env_hyperparams:
26
+ normalize: true
27
+ n_envs: 16
28
+ algo_hyperparams:
29
+ n_steps: 16
30
+ n_epochs: 4
31
+ gae_lambda: 0.98
32
+ gamma: 0.99
33
+ ent_coef: 0.0
34
+
35
+ MountainCarContinuous-v0:
36
+ n_timesteps: !!float 1e5
37
+ env_hyperparams:
38
+ normalize: true
39
+ n_envs: 4
40
+ # policy_hyperparams:
41
+ # init_layers_orthogonal: false
42
+ # log_std_init: -3.29
43
+ # use_sde: true
44
+ algo_hyperparams:
45
+ n_steps: 512
46
+ batch_size: 256
47
+ n_epochs: 10
48
+ learning_rate: !!float 7.77e-5
49
+ ent_coef: 0.01 # 0.00429
50
+ ent_coef_decay: linear
51
+ clip_range: 0.1
52
+ gae_lambda: 0.9
53
+ max_grad_norm: 5
54
+ vf_coef: 0.19
55
+ eval_params:
56
+ step_freq: 5000
57
+
58
+ Acrobot-v1:
59
+ n_timesteps: !!float 1e6
60
+ env_hyperparams:
61
+ n_envs: 16
62
+ normalize: true
63
+ algo_hyperparams:
64
+ n_steps: 256
65
+ n_epochs: 4
66
+ gae_lambda: 0.94
67
+ gamma: 0.99
68
+ ent_coef: 0.0
69
+
70
+ LunarLander-v2:
71
+ n_timesteps: !!float 4e6
72
+ env_hyperparams:
73
+ n_envs: 16
74
+ algo_hyperparams:
75
+ n_steps: 1024
76
+ batch_size: 64
77
+ n_epochs: 4
78
+ gae_lambda: 0.98
79
+ gamma: 0.999
80
+ learning_rate: !!float 5e-4
81
+ learning_rate_decay: linear
82
+ clip_range: 0.2
83
+ clip_range_decay: linear
84
+ ent_coef: 0.01
85
+ normalize_advantage: false
86
+
87
+ BipedalWalker-v3:
88
+ n_timesteps: !!float 10e6
89
+ env_hyperparams:
90
+ n_envs: 16
91
+ normalize: true
92
+ algo_hyperparams:
93
+ n_steps: 2048
94
+ batch_size: 64
95
+ gae_lambda: 0.95
96
+ gamma: 0.99
97
+ n_epochs: 10
98
+ ent_coef: 0.001
99
+ learning_rate: !!float 2.5e-4
100
+ learning_rate_decay: linear
101
+ clip_range: 0.2
102
+ clip_range_decay: linear
103
+
104
+ CarRacing-v0: &carracing-defaults
105
+ n_timesteps: !!float 4e6
106
+ env_hyperparams:
107
+ n_envs: 8
108
+ frame_stack: 4
109
+ policy_hyperparams: &carracing-policy-defaults
110
+ use_sde: true
111
+ log_std_init: -2
112
+ init_layers_orthogonal: false
113
+ activation_fn: relu
114
+ share_features_extractor: false
115
+ cnn_feature_dim: 256
116
+ hidden_sizes: [256]
117
+ algo_hyperparams:
118
+ n_steps: 512
119
+ batch_size: 128
120
+ n_epochs: 10
121
+ learning_rate: !!float 1e-4
122
+ learning_rate_decay: linear
123
+ gamma: 0.99
124
+ gae_lambda: 0.95
125
+ ent_coef: 0.0
126
+ sde_sample_freq: 4
127
+ max_grad_norm: 0.5
128
+ vf_coef: 0.5
129
+ clip_range: 0.2
130
+
131
+ impala-CarRacing-v0:
132
+ <<: *carracing-defaults
133
+ env_id: CarRacing-v0
134
+ policy_hyperparams:
135
+ <<: *carracing-policy-defaults
136
+ cnn_style: impala
137
+ init_layers_orthogonal: true
138
+ cnn_layers_init_orthogonal: false
139
+ hidden_sizes: []
140
+
141
+ # BreakoutNoFrameskip-v4
142
+ # PongNoFrameskip-v4
143
+ # SpaceInvadersNoFrameskip-v4
144
+ # QbertNoFrameskip-v4
145
+ _atari: &atari-defaults
146
+ n_timesteps: !!float 1e7
147
+ env_hyperparams: &atari-env-defaults
148
+ n_envs: 8
149
+ frame_stack: 4
150
+ no_reward_timeout_steps: 1000
151
+ no_reward_fire_steps: 500
152
+ vec_env_class: async
153
+ policy_hyperparams: &atari-policy-defaults
154
+ activation_fn: relu
155
+ algo_hyperparams:
156
+ n_steps: 128
157
+ batch_size: 256
158
+ n_epochs: 4
159
+ learning_rate: !!float 2.5e-4
160
+ learning_rate_decay: linear
161
+ clip_range: 0.1
162
+ clip_range_decay: linear
163
+ vf_coef: 0.5
164
+ ent_coef: 0.01
165
+ eval_params:
166
+ deterministic: false
167
+
168
+ _norm-rewards-atari: &norm-rewards-atari-default
169
+ <<: *atari-defaults
170
+ env_hyperparams:
171
+ <<: *atari-env-defaults
172
+ clip_atari_rewards: false
173
+ normalize: true
174
+ normalize_kwargs:
175
+ norm_obs: false
176
+ norm_reward: true
177
+
178
+ norm-rewards-BreakoutNoFrameskip-v4:
179
+ <<: *norm-rewards-atari-default
180
+ env_id: BreakoutNoFrameskip-v4
181
+
182
+ debug-PongNoFrameskip-v4:
183
+ <<: *atari-defaults
184
+ device: cpu
185
+ env_id: PongNoFrameskip-v4
186
+ env_hyperparams:
187
+ <<: *atari-env-defaults
188
+ vec_env_class: sync
189
+
190
+ _impala-atari: &impala-atari-defaults
191
+ <<: *atari-defaults
192
+ policy_hyperparams:
193
+ <<: *atari-policy-defaults
194
+ cnn_style: impala
195
+ cnn_feature_dim: 256
196
+ init_layers_orthogonal: true
197
+ cnn_layers_init_orthogonal: false
198
+
199
+ impala-PongNoFrameskip-v4:
200
+ <<: *impala-atari-defaults
201
+ env_id: PongNoFrameskip-v4
202
+
203
+ impala-BreakoutNoFrameskip-v4:
204
+ <<: *impala-atari-defaults
205
+ env_id: BreakoutNoFrameskip-v4
206
+
207
+ impala-SpaceInvadersNoFrameskip-v4:
208
+ <<: *impala-atari-defaults
209
+ env_id: SpaceInvadersNoFrameskip-v4
210
+
211
+ impala-QbertNoFrameskip-v4:
212
+ <<: *impala-atari-defaults
213
+ env_id: QbertNoFrameskip-v4
214
+
215
+ HalfCheetahBulletEnv-v0: &pybullet-defaults
216
+ n_timesteps: !!float 2e6
217
+ env_hyperparams: &pybullet-env-defaults
218
+ n_envs: 16
219
+ normalize: true
220
+ policy_hyperparams: &pybullet-policy-defaults
221
+ pi_hidden_sizes: [256, 256]
222
+ v_hidden_sizes: [256, 256]
223
+ activation_fn: relu
224
+ algo_hyperparams: &pybullet-algo-defaults
225
+ n_steps: 512
226
+ batch_size: 128
227
+ n_epochs: 20
228
+ gamma: 0.99
229
+ gae_lambda: 0.9
230
+ ent_coef: 0.0
231
+ max_grad_norm: 0.5
232
+ vf_coef: 0.5
233
+ learning_rate: !!float 3e-5
234
+ clip_range: 0.4
235
+
236
+ AntBulletEnv-v0:
237
+ <<: *pybullet-defaults
238
+ policy_hyperparams:
239
+ <<: *pybullet-policy-defaults
240
+ algo_hyperparams:
241
+ <<: *pybullet-algo-defaults
242
+
243
+ Walker2DBulletEnv-v0:
244
+ <<: *pybullet-defaults
245
+ algo_hyperparams:
246
+ <<: *pybullet-algo-defaults
247
+ clip_range_decay: linear
248
+
249
+ HopperBulletEnv-v0:
250
+ <<: *pybullet-defaults
251
+ algo_hyperparams:
252
+ <<: *pybullet-algo-defaults
253
+ clip_range_decay: linear
254
+
255
+ HumanoidBulletEnv-v0:
256
+ <<: *pybullet-defaults
257
+ n_timesteps: !!float 1e7
258
+ env_hyperparams:
259
+ <<: *pybullet-env-defaults
260
+ n_envs: 8
261
+ policy_hyperparams:
262
+ <<: *pybullet-policy-defaults
263
+ # log_std_init: -1
264
+ algo_hyperparams:
265
+ <<: *pybullet-algo-defaults
266
+ n_steps: 2048
267
+ batch_size: 64
268
+ n_epochs: 10
269
+ gae_lambda: 0.95
270
+ learning_rate: !!float 2.5e-4
271
+ clip_range: 0.2
272
+
273
+ _procgen: &procgen-defaults
274
+ env_hyperparams: &procgen-env-defaults
275
+ env_type: procgen
276
+ n_envs: 64
277
+ # grayscale: false
278
+ # frame_stack: 4
279
+ normalize: true # procgen only normalizes reward
280
+ make_kwargs: &procgen-make-kwargs-defaults
281
+ num_threads: 8
282
+ policy_hyperparams: &procgen-policy-defaults
283
+ activation_fn: relu
284
+ cnn_style: impala
285
+ cnn_feature_dim: 256
286
+ init_layers_orthogonal: true
287
+ cnn_layers_init_orthogonal: false
288
+ algo_hyperparams: &procgen-algo-defaults
289
+ gamma: 0.999
290
+ gae_lambda: 0.95
291
+ n_steps: 256
292
+ batch_size: 2048
293
+ n_epochs: 3
294
+ ent_coef: 0.01
295
+ clip_range: 0.2
296
+ # clip_range_decay: linear
297
+ clip_range_vf: 0.2
298
+ learning_rate: !!float 5e-4
299
+ # learning_rate_decay: linear
300
+ vf_coef: 0.5
301
+ eval_params: &procgen-eval-defaults
302
+ ignore_first_episode: true
303
+ # deterministic: false
304
+ step_freq: !!float 1e5
305
+
306
+ _procgen-easy: &procgen-easy-defaults
307
+ <<: *procgen-defaults
308
+ n_timesteps: !!float 25e6
309
+ env_hyperparams: &procgen-easy-env-defaults
310
+ <<: *procgen-env-defaults
311
+ make_kwargs:
312
+ <<: *procgen-make-kwargs-defaults
313
+ distribution_mode: easy
314
+
315
+ procgen-coinrun-easy: &coinrun-easy-defaults
316
+ <<: *procgen-easy-defaults
317
+ env_id: coinrun
318
+
319
+ debug-procgen-coinrun:
320
+ <<: *coinrun-easy-defaults
321
+ device: cpu
322
+
323
+ procgen-starpilot-easy:
324
+ <<: *procgen-easy-defaults
325
+ env_id: starpilot
326
+
327
+ procgen-bossfight-easy:
328
+ <<: *procgen-easy-defaults
329
+ env_id: bossfight
330
+
331
+ procgen-bigfish-easy:
332
+ <<: *procgen-easy-defaults
333
+ env_id: bigfish
334
+
335
+ _procgen-hard: &procgen-hard-defaults
336
+ <<: *procgen-defaults
337
+ n_timesteps: !!float 200e6
338
+ env_hyperparams: &procgen-hard-env-defaults
339
+ <<: *procgen-env-defaults
340
+ n_envs: 256
341
+ make_kwargs:
342
+ <<: *procgen-make-kwargs-defaults
343
+ distribution_mode: hard
344
+ algo_hyperparams: &procgen-hard-algo-defaults
345
+ <<: *procgen-algo-defaults
346
+ batch_size: 8192
347
+ clip_range_decay: linear
348
+ learning_rate_decay: linear
349
+ eval_params:
350
+ <<: *procgen-eval-defaults
351
+ step_freq: !!float 5e5
352
+
353
+ procgen-starpilot-hard: &procgen-starpilot-hard-defaults
354
+ <<: *procgen-hard-defaults
355
+ env_id: starpilot
356
+
357
+ procgen-starpilot-hard-2xIMPALA:
358
+ <<: *procgen-starpilot-hard-defaults
359
+ policy_hyperparams:
360
+ <<: *procgen-policy-defaults
361
+ impala_channels: [32, 64, 64]
362
+ algo_hyperparams:
363
+ <<: *procgen-hard-algo-defaults
364
+ learning_rate: !!float 3.3e-4
365
+
366
+ procgen-starpilot-hard-2xIMPALA-fat:
367
+ <<: *procgen-starpilot-hard-defaults
368
+ policy_hyperparams:
369
+ <<: *procgen-policy-defaults
370
+ impala_channels: [32, 64, 64]
371
+ cnn_feature_dim: 512
372
+ algo_hyperparams:
373
+ <<: *procgen-hard-algo-defaults
374
+ learning_rate: !!float 2.5e-4
375
+
376
+ procgen-starpilot-hard-4xIMPALA:
377
+ <<: *procgen-starpilot-hard-defaults
378
+ policy_hyperparams:
379
+ <<: *procgen-policy-defaults
380
+ impala_channels: [64, 128, 128]
381
+ algo_hyperparams:
382
+ <<: *procgen-hard-algo-defaults
383
+ learning_rate: !!float 2.1e-4
rl_algo_impls/hyperparams/vpg.yml ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CartPole-v1: &cartpole-defaults
2
+ n_timesteps: !!float 4e5
3
+ algo_hyperparams:
4
+ n_steps: 4096
5
+ pi_lr: 0.01
6
+ gamma: 0.99
7
+ gae_lambda: 1
8
+ val_lr: 0.01
9
+ train_v_iters: 80
10
+ eval_params:
11
+ step_freq: !!float 2.5e4
12
+
13
+ CartPole-v0:
14
+ <<: *cartpole-defaults
15
+ n_timesteps: !!float 1e5
16
+ algo_hyperparams:
17
+ n_steps: 1024
18
+ pi_lr: 0.01
19
+ gamma: 0.99
20
+ gae_lambda: 1
21
+ val_lr: 0.01
22
+ train_v_iters: 80
23
+
24
+ MountainCar-v0:
25
+ n_timesteps: !!float 1e6
26
+ env_hyperparams:
27
+ normalize: true
28
+ n_envs: 16
29
+ algo_hyperparams:
30
+ n_steps: 200
31
+ pi_lr: 0.005
32
+ gamma: 0.99
33
+ gae_lambda: 0.97
34
+ val_lr: 0.01
35
+ train_v_iters: 80
36
+ max_grad_norm: 0.5
37
+
38
+ MountainCarContinuous-v0:
39
+ n_timesteps: !!float 3e5
40
+ env_hyperparams:
41
+ normalize: true
42
+ n_envs: 4
43
+ # policy_hyperparams:
44
+ # init_layers_orthogonal: false
45
+ # log_std_init: -3.29
46
+ # use_sde: true
47
+ algo_hyperparams:
48
+ n_steps: 1000
49
+ pi_lr: !!float 5e-4
50
+ gamma: 0.99
51
+ gae_lambda: 0.9
52
+ val_lr: !!float 1e-3
53
+ train_v_iters: 80
54
+ max_grad_norm: 5
55
+ eval_params:
56
+ step_freq: 5000
57
+
58
+ Acrobot-v1:
59
+ n_timesteps: !!float 2e5
60
+ algo_hyperparams:
61
+ n_steps: 2048
62
+ pi_lr: 0.005
63
+ gamma: 0.99
64
+ gae_lambda: 0.97
65
+ val_lr: 0.01
66
+ train_v_iters: 80
67
+ max_grad_norm: 0.5
68
+
69
+ LunarLander-v2:
70
+ n_timesteps: !!float 4e6
71
+ policy_hyperparams:
72
+ hidden_sizes: [256, 256]
73
+ algo_hyperparams:
74
+ n_steps: 2048
75
+ pi_lr: 0.0001
76
+ gamma: 0.999
77
+ gae_lambda: 0.97
78
+ val_lr: 0.0001
79
+ train_v_iters: 80
80
+ max_grad_norm: 0.5
81
+ eval_params:
82
+ deterministic: false
83
+
84
+ BipedalWalker-v3:
85
+ n_timesteps: !!float 10e6
86
+ env_hyperparams:
87
+ n_envs: 16
88
+ normalize: true
89
+ policy_hyperparams:
90
+ hidden_sizes: [256, 256]
91
+ algo_hyperparams:
92
+ n_steps: 1600
93
+ gae_lambda: 0.95
94
+ gamma: 0.99
95
+ pi_lr: !!float 1e-4
96
+ val_lr: !!float 1e-4
97
+ train_v_iters: 80
98
+ max_grad_norm: 0.5
99
+ eval_params:
100
+ deterministic: false
101
+
102
+ CarRacing-v0:
103
+ n_timesteps: !!float 4e6
104
+ env_hyperparams:
105
+ frame_stack: 4
106
+ n_envs: 4
107
+ vec_env_class: sync
108
+ policy_hyperparams:
109
+ use_sde: true
110
+ log_std_init: -2
111
+ init_layers_orthogonal: false
112
+ activation_fn: relu
113
+ cnn_feature_dim: 256
114
+ hidden_sizes: [256]
115
+ algo_hyperparams:
116
+ n_steps: 1000
117
+ pi_lr: !!float 5e-5
118
+ gamma: 0.99
119
+ gae_lambda: 0.95
120
+ val_lr: !!float 1e-4
121
+ train_v_iters: 40
122
+ max_grad_norm: 0.5
123
+ sde_sample_freq: 4
124
+
125
+ HalfCheetahBulletEnv-v0: &pybullet-defaults
126
+ n_timesteps: !!float 2e6
127
+ env_hyperparams: &pybullet-env-defaults
128
+ normalize: true
129
+ policy_hyperparams: &pybullet-policy-defaults
130
+ hidden_sizes: [256, 256]
131
+ algo_hyperparams: &pybullet-algo-defaults
132
+ n_steps: 4000
133
+ pi_lr: !!float 3e-4
134
+ gamma: 0.99
135
+ gae_lambda: 0.97
136
+ val_lr: !!float 1e-3
137
+ train_v_iters: 80
138
+ max_grad_norm: 0.5
139
+
140
+ AntBulletEnv-v0:
141
+ <<: *pybullet-defaults
142
+ policy_hyperparams:
143
+ <<: *pybullet-policy-defaults
144
+ hidden_sizes: [400, 300]
145
+ algo_hyperparams:
146
+ <<: *pybullet-algo-defaults
147
+ pi_lr: !!float 7e-4
148
+ val_lr: !!float 7e-3
149
+
150
+ HopperBulletEnv-v0:
151
+ <<: *pybullet-defaults
152
+
153
+ Walker2DBulletEnv-v0:
154
+ <<: *pybullet-defaults
155
+
156
+ FrozenLake-v1:
157
+ n_timesteps: !!float 8e5
158
+ env_params:
159
+ make_kwargs:
160
+ map_name: 8x8
161
+ is_slippery: true
162
+ policy_hyperparams:
163
+ hidden_sizes: [64]
164
+ algo_hyperparams:
165
+ n_steps: 2048
166
+ pi_lr: 0.01
167
+ gamma: 0.99
168
+ gae_lambda: 0.98
169
+ val_lr: 0.01
170
+ train_v_iters: 80
171
+ max_grad_norm: 0.5
172
+ eval_params:
173
+ step_freq: !!float 5e4
174
+ n_episodes: 10
175
+ save_best: true
176
+
177
+ _atari: &atari-defaults
178
+ n_timesteps: !!float 25e6
179
+ env_hyperparams:
180
+ n_envs: 4
181
+ frame_stack: 4
182
+ no_reward_timeout_steps: 1000
183
+ no_reward_fire_steps: 500
184
+ vec_env_class: async
185
+ policy_hyperparams:
186
+ activation_fn: relu
187
+ algo_hyperparams:
188
+ n_steps: 2048
189
+ pi_lr: !!float 5e-5
190
+ gamma: 0.99
191
+ gae_lambda: 0.95
192
+ val_lr: !!float 1e-4
193
+ train_v_iters: 80
194
+ max_grad_norm: 0.5
195
+ ent_coef: 0.01
196
+ eval_params:
197
+ deterministic: false
rl_algo_impls/optimize.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import gc
3
+ import inspect
4
+ import logging
5
+ import numpy as np
6
+ import optuna
7
+ import os
8
+ import torch
9
+ import wandb
10
+
11
+ from dataclasses import asdict, dataclass
12
+ from optuna.pruners import HyperbandPruner
13
+ from optuna.samplers import TPESampler
14
+ from optuna.visualization import plot_optimization_history, plot_param_importances
15
+ from torch.utils.tensorboard.writer import SummaryWriter
16
+ from typing import Callable, List, NamedTuple, Optional, Sequence, Union
17
+
18
+ from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
19
+ from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
20
+ from rl_algo_impls.runner.env import make_env, make_eval_env
21
+ from rl_algo_impls.runner.running_utils import (
22
+ base_parser,
23
+ load_hyperparams,
24
+ set_seeds,
25
+ get_device,
26
+ make_policy,
27
+ ALGOS,
28
+ hparam_dict,
29
+ )
30
+ from rl_algo_impls.shared.callbacks.optimize_callback import (
31
+ Evaluation,
32
+ OptimizeCallback,
33
+ evaluation,
34
+ )
35
+ from rl_algo_impls.shared.stats import EpisodesStats
36
+
37
+
38
+ @dataclass
39
+ class StudyArgs:
40
+ load_study: bool
41
+ study_name: Optional[str] = None
42
+ storage_path: Optional[str] = None
43
+ n_trials: int = 100
44
+ n_jobs: int = 1
45
+ n_evaluations: int = 4
46
+ n_eval_envs: int = 8
47
+ n_eval_episodes: int = 16
48
+ timeout: Union[int, float, None] = None
49
+ wandb_project_name: Optional[str] = None
50
+ wandb_entity: Optional[str] = None
51
+ wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
52
+ wandb_group: Optional[str] = None
53
+ virtual_display: bool = False
54
+
55
+
56
+ class Args(NamedTuple):
57
+ train_args: Sequence[RunArgs]
58
+ study_args: StudyArgs
59
+
60
+
61
+ def parse_args() -> Args:
62
+ parser = base_parser()
63
+ parser.add_argument(
64
+ "--load-study",
65
+ action="store_true",
66
+ help="Load a preexisting study, useful for parallelization",
67
+ )
68
+ parser.add_argument("--study-name", type=str, help="Optuna study name")
69
+ parser.add_argument(
70
+ "--storage-path",
71
+ type=str,
72
+ help="Path of database for Optuna to persist to",
73
+ )
74
+ parser.add_argument(
75
+ "--wandb-project-name",
76
+ type=str,
77
+ default="rl-algo-impls-tuning",
78
+ help="WandB project name to upload tuning data to. If none, won't upload",
79
+ )
80
+ parser.add_argument(
81
+ "--wandb-entity",
82
+ type=str,
83
+ help="WandB team. None uses the default entity",
84
+ )
85
+ parser.add_argument(
86
+ "--wandb-tags", type=str, nargs="*", help="WandB tags to add to run"
87
+ )
88
+ parser.add_argument(
89
+ "--wandb-group", type=str, help="WandB group to group trials under"
90
+ )
91
+ parser.add_argument(
92
+ "--n-trials", type=int, default=100, help="Maximum number of trials"
93
+ )
94
+ parser.add_argument(
95
+ "--n-jobs", type=int, default=1, help="Number of jobs to run in parallel"
96
+ )
97
+ parser.add_argument(
98
+ "--n-evaluations",
99
+ type=int,
100
+ default=4,
101
+ help="Number of evaluations during the training",
102
+ )
103
+ parser.add_argument(
104
+ "--n-eval-envs",
105
+ type=int,
106
+ default=8,
107
+ help="Number of envs in vectorized eval environment",
108
+ )
109
+ parser.add_argument(
110
+ "--n-eval-episodes",
111
+ type=int,
112
+ default=16,
113
+ help="Number of episodes to complete for evaluation",
114
+ )
115
+ parser.add_argument("--timeout", type=int, help="Seconds to timeout optimization")
116
+ parser.add_argument(
117
+ "--virtual-display", action="store_true", help="Use headless virtual display"
118
+ )
119
+ # parser.set_defaults(
120
+ # algo=["a2c"],
121
+ # env=["CartPole-v1"],
122
+ # seed=[100, 200, 300],
123
+ # n_trials=5,
124
+ # virtual_display=True,
125
+ # )
126
+ train_dict, study_dict = {}, {}
127
+ for k, v in vars(parser.parse_args()).items():
128
+ if k in inspect.signature(StudyArgs).parameters:
129
+ study_dict[k] = v
130
+ else:
131
+ train_dict[k] = v
132
+
133
+ study_args = StudyArgs(**study_dict)
134
+ # Hyperparameter tuning across algos and envs not supported
135
+ assert len(train_dict["algo"]) == 1
136
+ assert len(train_dict["env"]) == 1
137
+ train_args = RunArgs.expand_from_dict(train_dict)
138
+
139
+ if not all((study_args.study_name, study_args.storage_path)):
140
+ hyperparams = load_hyperparams(train_args[0].algo, train_args[0].env)
141
+ config = Config(train_args[0], hyperparams, os.getcwd())
142
+ if study_args.study_name is None:
143
+ study_args.study_name = config.run_name(include_seed=False)
144
+ if study_args.storage_path is None:
145
+ study_args.storage_path = (
146
+ f"sqlite:///{os.path.join(config.runs_dir, 'tuning.db')}"
147
+ )
148
+ # Default set group name to study name
149
+ study_args.wandb_group = study_args.wandb_group or study_args.study_name
150
+
151
+ return Args(train_args, study_args)
152
+
153
+
154
+ def objective_fn(
155
+ args: Sequence[RunArgs], study_args: StudyArgs
156
+ ) -> Callable[[optuna.Trial], float]:
157
+ def objective(trial: optuna.Trial) -> float:
158
+ if len(args) == 1:
159
+ return simple_optimize(trial, args[0], study_args)
160
+ else:
161
+ return stepwise_optimize(trial, args, study_args)
162
+
163
+ return objective
164
+
165
+
166
+ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -> float:
167
+ base_hyperparams = load_hyperparams(args.algo, args.env)
168
+ base_config = Config(args, base_hyperparams, os.getcwd())
169
+ if args.algo == "a2c":
170
+ hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
171
+ else:
172
+ raise ValueError(f"Optimizing {args.algo} isn't supported")
173
+ config = Config(args, hyperparams, os.getcwd())
174
+
175
+ wandb_enabled = bool(study_args.wandb_project_name)
176
+ if wandb_enabled:
177
+ wandb.init(
178
+ project=study_args.wandb_project_name,
179
+ entity=study_args.wandb_entity,
180
+ config=asdict(hyperparams),
181
+ name=f"{config.model_name()}-{str(trial.number)}",
182
+ tags=study_args.wandb_tags,
183
+ group=study_args.wandb_group,
184
+ sync_tensorboard=True,
185
+ monitor_gym=True,
186
+ save_code=True,
187
+ reinit=True,
188
+ )
189
+ wandb.config.update(args)
190
+
191
+ tb_writer = SummaryWriter(config.tensorboard_summary_path)
192
+ set_seeds(args.seed, args.use_deterministic_algorithms)
193
+
194
+ env = make_env(
195
+ config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
196
+ )
197
+ device = get_device(config.device, env)
198
+ policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
199
+ algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
200
+
201
+ eval_env = make_eval_env(
202
+ config,
203
+ EnvHyperparams(**config.env_hyperparams),
204
+ override_n_envs=study_args.n_eval_envs,
205
+ )
206
+ callback = OptimizeCallback(
207
+ policy,
208
+ eval_env,
209
+ trial,
210
+ tb_writer,
211
+ step_freq=config.n_timesteps // study_args.n_evaluations,
212
+ n_episodes=study_args.n_eval_episodes,
213
+ deterministic=config.eval_params.get("deterministic", True),
214
+ )
215
+ try:
216
+ algo.learn(config.n_timesteps, callback=callback)
217
+
218
+ if not callback.is_pruned:
219
+ callback.evaluate()
220
+ if not callback.is_pruned:
221
+ policy.save(config.model_dir_path(best=False))
222
+
223
+ eval_stat: EpisodesStats = callback.last_eval_stat # type: ignore
224
+ train_stat: EpisodesStats = callback.last_train_stat # type: ignore
225
+
226
+ tb_writer.add_hparams(
227
+ hparam_dict(hyperparams, vars(args)),
228
+ {
229
+ "hparam/last_mean": eval_stat.score.mean,
230
+ "hparam/last_result": eval_stat.score.mean - eval_stat.score.std,
231
+ "hparam/train_mean": train_stat.score.mean,
232
+ "hparam/train_result": train_stat.score.mean - train_stat.score.std,
233
+ "hparam/score": callback.last_score,
234
+ "hparam/is_pruned": callback.is_pruned,
235
+ },
236
+ None,
237
+ config.run_name(),
238
+ )
239
+ tb_writer.close()
240
+
241
+ if wandb_enabled:
242
+ wandb.run.summary["state"] = "Pruned" if callback.is_pruned else "Complete"
243
+ wandb.finish(quiet=True)
244
+
245
+ if callback.is_pruned:
246
+ raise optuna.exceptions.TrialPruned()
247
+
248
+ return callback.last_score
249
+ except AssertionError as e:
250
+ logging.warning(e)
251
+ return np.nan
252
+ finally:
253
+ env.close()
254
+ eval_env.close()
255
+ gc.collect()
256
+ torch.cuda.empty_cache()
257
+
258
+
259
+ def stepwise_optimize(
260
+ trial: optuna.Trial, args: Sequence[RunArgs], study_args: StudyArgs
261
+ ) -> float:
262
+ algo = args[0].algo
263
+ env_id = args[0].env
264
+ base_hyperparams = load_hyperparams(algo, env_id)
265
+ base_config = Config(args[0], base_hyperparams, os.getcwd())
266
+ if algo == "a2c":
267
+ hyperparams = a2c_sample_params(trial, base_hyperparams, base_config)
268
+ else:
269
+ raise ValueError(f"Optimizing {algo} isn't supported")
270
+
271
+ wandb_enabled = bool(study_args.wandb_project_name)
272
+ if wandb_enabled:
273
+ wandb.init(
274
+ project=study_args.wandb_project_name,
275
+ entity=study_args.wandb_entity,
276
+ config=asdict(hyperparams),
277
+ name=f"{study_args.study_name}-{str(trial.number)}",
278
+ tags=study_args.wandb_tags,
279
+ group=study_args.wandb_group,
280
+ save_code=True,
281
+ reinit=True,
282
+ )
283
+
284
+ score = -np.inf
285
+
286
+ for i in range(study_args.n_evaluations):
287
+ evaluations: List[Evaluation] = []
288
+
289
+ for arg in args:
290
+ config = Config(arg, hyperparams, os.getcwd())
291
+
292
+ tb_writer = SummaryWriter(config.tensorboard_summary_path)
293
+ set_seeds(arg.seed, arg.use_deterministic_algorithms)
294
+
295
+ env = make_env(
296
+ config,
297
+ EnvHyperparams(**config.env_hyperparams),
298
+ normalize_load_path=config.model_dir_path() if i > 0 else None,
299
+ tb_writer=tb_writer,
300
+ )
301
+ device = get_device(config.device, env)
302
+ policy = make_policy(arg.algo, env, device, **config.policy_hyperparams)
303
+ if i > 0:
304
+ policy.load(config.model_dir_path())
305
+ algo = ALGOS[arg.algo](
306
+ policy, env, device, tb_writer, **config.algo_hyperparams
307
+ )
308
+
309
+ eval_env = make_eval_env(
310
+ config,
311
+ EnvHyperparams(**config.env_hyperparams),
312
+ normalize_load_path=config.model_dir_path() if i > 0 else None,
313
+ override_n_envs=study_args.n_eval_envs,
314
+ )
315
+
316
+ start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations)
317
+ train_timesteps = (
318
+ int((i + 1) * config.n_timesteps / study_args.n_evaluations)
319
+ - start_timesteps
320
+ )
321
+
322
+ try:
323
+ algo.learn(
324
+ train_timesteps,
325
+ callback=None,
326
+ total_timesteps=config.n_timesteps,
327
+ start_timesteps=start_timesteps,
328
+ )
329
+
330
+ evaluations.append(
331
+ evaluation(
332
+ policy,
333
+ eval_env,
334
+ tb_writer,
335
+ study_args.n_eval_episodes,
336
+ config.eval_params.get("deterministic", True),
337
+ start_timesteps + train_timesteps,
338
+ )
339
+ )
340
+
341
+ policy.save(config.model_dir_path())
342
+
343
+ tb_writer.close()
344
+
345
+ except AssertionError as e:
346
+ logging.warning(e)
347
+ if wandb_enabled:
348
+ wandb_finish("Error")
349
+ return np.nan
350
+ finally:
351
+ env.close()
352
+ eval_env.close()
353
+ gc.collect()
354
+ torch.cuda.empty_cache()
355
+
356
+ d = {}
357
+ for idx, e in enumerate(evaluations):
358
+ d[f"{idx}/eval_mean"] = e.eval_stat.score.mean
359
+ d[f"{idx}/train_mean"] = e.train_stat.score.mean
360
+ d[f"{idx}/score"] = e.score
361
+ d["eval"] = np.mean([e.eval_stat.score.mean for e in evaluations]).item()
362
+ d["train"] = np.mean([e.train_stat.score.mean for e in evaluations]).item()
363
+ score = np.mean([e.score for e in evaluations]).item()
364
+ d["score"] = score
365
+
366
+ step = i + 1
367
+ wandb.log(d, step=step)
368
+
369
+ print(f"Trial #{trial.number} Step {step} Score: {round(score, 2)}")
370
+ trial.report(score, step)
371
+ if trial.should_prune():
372
+ if wandb_enabled:
373
+ wandb_finish("Pruned")
374
+ raise optuna.exceptions.TrialPruned()
375
+
376
+ if wandb_enabled:
377
+ wandb_finish("Complete")
378
+ return score
379
+
380
+
381
+ def wandb_finish(state: str) -> None:
382
+ wandb.run.summary["state"] = state
383
+ wandb.finish(quiet=True)
384
+
385
+
386
+ def optimize() -> None:
387
+ from pyvirtualdisplay.display import Display
388
+
389
+ train_args, study_args = parse_args()
390
+ if study_args.virtual_display:
391
+ virtual_display = Display(visible=False, size=(1400, 900))
392
+ virtual_display.start()
393
+
394
+ sampler = TPESampler(**TPESampler.hyperopt_parameters())
395
+ pruner = HyperbandPruner()
396
+ if study_args.load_study:
397
+ assert study_args.study_name
398
+ assert study_args.storage_path
399
+ study = optuna.load_study(
400
+ study_name=study_args.study_name,
401
+ storage=study_args.storage_path,
402
+ sampler=sampler,
403
+ pruner=pruner,
404
+ )
405
+ else:
406
+ study = optuna.create_study(
407
+ study_name=study_args.study_name,
408
+ storage=study_args.storage_path,
409
+ sampler=sampler,
410
+ pruner=pruner,
411
+ direction="maximize",
412
+ )
413
+
414
+ try:
415
+ study.optimize(
416
+ objective_fn(train_args, study_args),
417
+ n_trials=study_args.n_trials,
418
+ n_jobs=study_args.n_jobs,
419
+ timeout=study_args.timeout,
420
+ )
421
+ except KeyboardInterrupt:
422
+ pass
423
+
424
+ best = study.best_trial
425
+ print(f"Best Trial Value: {best.value}")
426
+ print("Attributes:")
427
+ for key, value in list(best.params.items()) + list(best.user_attrs.items()):
428
+ print(f" {key}: {value}")
429
+
430
+ df = study.trials_dataframe()
431
+ df = df[df.state == "COMPLETE"].sort_values(by=["value"], ascending=False)
432
+ print(df.to_markdown(index=False))
433
+
434
+ fig1 = plot_optimization_history(study)
435
+ fig1.write_image("opt_history.png")
436
+ fig2 = plot_param_importances(study)
437
+ fig2.write_image("param_importances.png")
438
+
439
+
440
+ if __name__ == "__main__":
441
+ optimize()
rl_algo_impls/ppo/ppo.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from dataclasses import asdict, dataclass, field
6
+ from time import perf_counter
7
+ from torch.optim import Adam
8
+ from torch.utils.tensorboard.writer import SummaryWriter
9
+ from typing import List, Optional, NamedTuple, TypeVar
10
+
11
+ from rl_algo_impls.shared.algorithm import Algorithm
12
+ from rl_algo_impls.shared.callbacks.callback import Callback
13
+ from rl_algo_impls.shared.gae import compute_advantage, compute_rtg_and_advantage
14
+ from rl_algo_impls.shared.policy.on_policy import ActorCritic
15
+ from rl_algo_impls.shared.schedule import (
16
+ constant_schedule,
17
+ linear_schedule,
18
+ update_learning_rate,
19
+ )
20
+ from rl_algo_impls.shared.trajectory import Trajectory, TrajectoryAccumulator
21
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs
22
+
23
+
24
+ @dataclass
25
+ class PPOTrajectory(Trajectory):
26
+ logp_a: List[float] = field(default_factory=list)
27
+
28
+ def add(
29
+ self,
30
+ obs: np.ndarray,
31
+ act: np.ndarray,
32
+ next_obs: np.ndarray,
33
+ rew: float,
34
+ terminated: bool,
35
+ v: float,
36
+ logp_a: float,
37
+ ):
38
+ super().add(obs, act, next_obs, rew, terminated, v)
39
+ self.logp_a.append(logp_a)
40
+
41
+
42
+ class PPOTrajectoryAccumulator(TrajectoryAccumulator):
43
+ def __init__(self, num_envs: int) -> None:
44
+ super().__init__(num_envs, PPOTrajectory)
45
+
46
+ def step(
47
+ self,
48
+ obs: VecEnvObs,
49
+ action: np.ndarray,
50
+ next_obs: VecEnvObs,
51
+ reward: np.ndarray,
52
+ done: np.ndarray,
53
+ val: np.ndarray,
54
+ logp_a: np.ndarray,
55
+ ) -> None:
56
+ super().step(obs, action, next_obs, reward, done, val, logp_a)
57
+
58
+
59
+ class TrainStepStats(NamedTuple):
60
+ loss: float
61
+ pi_loss: float
62
+ v_loss: float
63
+ entropy_loss: float
64
+ approx_kl: float
65
+ clipped_frac: float
66
+ val_clipped_frac: float
67
+
68
+
69
+ @dataclass
70
+ class TrainStats:
71
+ loss: float
72
+ pi_loss: float
73
+ v_loss: float
74
+ entropy_loss: float
75
+ approx_kl: float
76
+ clipped_frac: float
77
+ val_clipped_frac: float
78
+ explained_var: float
79
+
80
+ def __init__(self, step_stats: List[TrainStepStats], explained_var: float) -> None:
81
+ self.loss = np.mean([s.loss for s in step_stats]).item()
82
+ self.pi_loss = np.mean([s.pi_loss for s in step_stats]).item()
83
+ self.v_loss = np.mean([s.v_loss for s in step_stats]).item()
84
+ self.entropy_loss = np.mean([s.entropy_loss for s in step_stats]).item()
85
+ self.approx_kl = np.mean([s.approx_kl for s in step_stats]).item()
86
+ self.clipped_frac = np.mean([s.clipped_frac for s in step_stats]).item()
87
+ self.val_clipped_frac = np.mean([s.val_clipped_frac for s in step_stats]).item()
88
+ self.explained_var = explained_var
89
+
90
+ def write_to_tensorboard(self, tb_writer: SummaryWriter, global_step: int) -> None:
91
+ for name, value in asdict(self).items():
92
+ tb_writer.add_scalar(f"losses/{name}", value, global_step=global_step)
93
+
94
+ def __repr__(self) -> str:
95
+ return " | ".join(
96
+ [
97
+ f"Loss: {round(self.loss, 2)}",
98
+ f"Pi L: {round(self.pi_loss, 2)}",
99
+ f"V L: {round(self.v_loss, 2)}",
100
+ f"E L: {round(self.entropy_loss, 2)}",
101
+ f"Apx KL Div: {round(self.approx_kl, 2)}",
102
+ f"Clip Frac: {round(self.clipped_frac, 2)}",
103
+ f"Val Clip Frac: {round(self.val_clipped_frac, 2)}",
104
+ ]
105
+ )
106
+
107
+
108
+ PPOSelf = TypeVar("PPOSelf", bound="PPO")
109
+
110
+
111
+ class PPO(Algorithm):
112
+ def __init__(
113
+ self,
114
+ policy: ActorCritic,
115
+ env: VecEnv,
116
+ device: torch.device,
117
+ tb_writer: SummaryWriter,
118
+ learning_rate: float = 3e-4,
119
+ learning_rate_decay: str = "none",
120
+ n_steps: int = 2048,
121
+ batch_size: int = 64,
122
+ n_epochs: int = 10,
123
+ gamma: float = 0.99,
124
+ gae_lambda: float = 0.95,
125
+ clip_range: float = 0.2,
126
+ clip_range_decay: str = "none",
127
+ clip_range_vf: Optional[float] = None,
128
+ clip_range_vf_decay: str = "none",
129
+ normalize_advantage: bool = True,
130
+ ent_coef: float = 0.0,
131
+ ent_coef_decay: str = "none",
132
+ vf_coef: float = 0.5,
133
+ ppo2_vf_coef_halving: bool = False,
134
+ max_grad_norm: float = 0.5,
135
+ update_rtg_between_epochs: bool = False,
136
+ sde_sample_freq: int = -1,
137
+ ) -> None:
138
+ super().__init__(policy, env, device, tb_writer)
139
+ self.policy = policy
140
+
141
+ self.gamma = gamma
142
+ self.gae_lambda = gae_lambda
143
+ self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
144
+ self.lr_schedule = (
145
+ linear_schedule(learning_rate, 0)
146
+ if learning_rate_decay == "linear"
147
+ else constant_schedule(learning_rate)
148
+ )
149
+ self.max_grad_norm = max_grad_norm
150
+ self.clip_range_schedule = (
151
+ linear_schedule(clip_range, 0)
152
+ if clip_range_decay == "linear"
153
+ else constant_schedule(clip_range)
154
+ )
155
+ self.clip_range_vf_schedule = None
156
+ if clip_range_vf:
157
+ self.clip_range_vf_schedule = (
158
+ linear_schedule(clip_range_vf, 0)
159
+ if clip_range_vf_decay == "linear"
160
+ else constant_schedule(clip_range_vf)
161
+ )
162
+ self.normalize_advantage = normalize_advantage
163
+ self.ent_coef_schedule = (
164
+ linear_schedule(ent_coef, 0)
165
+ if ent_coef_decay == "linear"
166
+ else constant_schedule(ent_coef)
167
+ )
168
+ self.vf_coef = vf_coef
169
+ self.ppo2_vf_coef_halving = ppo2_vf_coef_halving
170
+
171
+ self.n_steps = n_steps
172
+ self.batch_size = batch_size
173
+ self.n_epochs = n_epochs
174
+ self.sde_sample_freq = sde_sample_freq
175
+
176
+ self.update_rtg_between_epochs = update_rtg_between_epochs
177
+
178
+ def learn(
179
+ self: PPOSelf,
180
+ total_timesteps: int,
181
+ callback: Optional[Callback] = None,
182
+ ) -> PPOSelf:
183
+ obs = self.env.reset()
184
+ ts_elapsed = 0
185
+ while ts_elapsed < total_timesteps:
186
+ start_time = perf_counter()
187
+ accumulator = self._collect_trajectories(obs)
188
+ rollout_steps = self.n_steps * self.env.num_envs
189
+ ts_elapsed += rollout_steps
190
+ progress = ts_elapsed / total_timesteps
191
+ train_stats = self.train(accumulator.all_trajectories, progress, ts_elapsed)
192
+ train_stats.write_to_tensorboard(self.tb_writer, ts_elapsed)
193
+ end_time = perf_counter()
194
+ self.tb_writer.add_scalar(
195
+ "train/steps_per_second",
196
+ rollout_steps / (end_time - start_time),
197
+ ts_elapsed,
198
+ )
199
+ if callback:
200
+ callback.on_step(timesteps_elapsed=rollout_steps)
201
+
202
+ return self
203
+
204
+ def _collect_trajectories(self, obs: VecEnvObs) -> PPOTrajectoryAccumulator:
205
+ self.policy.eval()
206
+ accumulator = PPOTrajectoryAccumulator(self.env.num_envs)
207
+ self.policy.reset_noise()
208
+ for i in range(self.n_steps):
209
+ if self.sde_sample_freq > 0 and i > 0 and i % self.sde_sample_freq == 0:
210
+ self.policy.reset_noise()
211
+ action, value, logp_a, clamped_action = self.policy.step(obs)
212
+ next_obs, reward, done, _ = self.env.step(clamped_action)
213
+ accumulator.step(obs, action, next_obs, reward, done, value, logp_a)
214
+ obs = next_obs
215
+ return accumulator
216
+
217
+ def train(
218
+ self, trajectories: List[PPOTrajectory], progress: float, timesteps_elapsed: int
219
+ ) -> TrainStats:
220
+ self.policy.train()
221
+ learning_rate = self.lr_schedule(progress)
222
+ update_learning_rate(self.optimizer, learning_rate)
223
+ self.tb_writer.add_scalar(
224
+ "charts/learning_rate",
225
+ self.optimizer.param_groups[0]["lr"],
226
+ timesteps_elapsed,
227
+ )
228
+
229
+ pi_clip = self.clip_range_schedule(progress)
230
+ self.tb_writer.add_scalar("charts/pi_clip", pi_clip, timesteps_elapsed)
231
+ if self.clip_range_vf_schedule:
232
+ v_clip = self.clip_range_vf_schedule(progress)
233
+ self.tb_writer.add_scalar("charts/v_clip", v_clip, timesteps_elapsed)
234
+ else:
235
+ v_clip = None
236
+ ent_coef = self.ent_coef_schedule(progress)
237
+ self.tb_writer.add_scalar("charts/ent_coef", ent_coef, timesteps_elapsed)
238
+
239
+ obs = torch.as_tensor(
240
+ np.concatenate([np.array(t.obs) for t in trajectories]), device=self.device
241
+ )
242
+ act = torch.as_tensor(
243
+ np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
244
+ )
245
+ rtg, adv = compute_rtg_and_advantage(
246
+ trajectories, self.policy, self.gamma, self.gae_lambda, self.device
247
+ )
248
+ orig_v = torch.as_tensor(
249
+ np.concatenate([np.array(t.v) for t in trajectories]), device=self.device
250
+ )
251
+ orig_logp_a = torch.as_tensor(
252
+ np.concatenate([np.array(t.logp_a) for t in trajectories]),
253
+ device=self.device,
254
+ )
255
+
256
+ step_stats = []
257
+ for _ in range(self.n_epochs):
258
+ step_stats.clear()
259
+ if self.update_rtg_between_epochs:
260
+ rtg, adv = compute_rtg_and_advantage(
261
+ trajectories, self.policy, self.gamma, self.gae_lambda, self.device
262
+ )
263
+ else:
264
+ adv = compute_advantage(
265
+ trajectories, self.policy, self.gamma, self.gae_lambda, self.device
266
+ )
267
+ idxs = torch.randperm(len(obs))
268
+ for i in range(0, len(obs), self.batch_size):
269
+ mb_idxs = idxs[i : i + self.batch_size]
270
+ mb_adv = adv[mb_idxs]
271
+ if self.normalize_advantage:
272
+ mb_adv = (mb_adv - mb_adv.mean(-1)) / (mb_adv.std(-1) + 1e-8)
273
+ self.policy.reset_noise(self.batch_size)
274
+ step_stats.append(
275
+ self._train_step(
276
+ pi_clip,
277
+ v_clip,
278
+ ent_coef,
279
+ obs[mb_idxs],
280
+ act[mb_idxs],
281
+ rtg[mb_idxs],
282
+ mb_adv,
283
+ orig_v[mb_idxs],
284
+ orig_logp_a[mb_idxs],
285
+ )
286
+ )
287
+
288
+ y_pred, y_true = orig_v.cpu().numpy(), rtg.cpu().numpy()
289
+ var_y = np.var(y_true).item()
290
+ explained_var = (
291
+ np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
292
+ )
293
+
294
+ return TrainStats(step_stats, explained_var)
295
+
296
+ def _train_step(
297
+ self,
298
+ pi_clip: float,
299
+ v_clip: Optional[float],
300
+ ent_coef: float,
301
+ obs: torch.Tensor,
302
+ act: torch.Tensor,
303
+ rtg: torch.Tensor,
304
+ adv: torch.Tensor,
305
+ orig_v: torch.Tensor,
306
+ orig_logp_a: torch.Tensor,
307
+ ) -> TrainStepStats:
308
+ logp_a, entropy, v = self.policy(obs, act)
309
+ logratio = logp_a - orig_logp_a
310
+ ratio = torch.exp(logratio)
311
+ clip_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
312
+ pi_loss = torch.maximum(-ratio * adv, -clip_ratio * adv).mean()
313
+
314
+ v_loss_unclipped = (v - rtg) ** 2
315
+ if v_clip:
316
+ v_loss_clipped = (
317
+ orig_v + torch.clamp(v - orig_v, -v_clip, v_clip) - rtg
318
+ ) ** 2
319
+ v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
320
+ else:
321
+ v_loss = v_loss_unclipped.mean()
322
+ if self.ppo2_vf_coef_halving:
323
+ v_loss *= 0.5
324
+
325
+ entropy_loss = -entropy.mean()
326
+
327
+ loss = pi_loss + ent_coef * entropy_loss + self.vf_coef * v_loss
328
+
329
+ self.optimizer.zero_grad()
330
+ loss.backward()
331
+ nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
332
+ self.optimizer.step()
333
+
334
+ with torch.no_grad():
335
+ approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
336
+ clipped_frac = (
337
+ ((ratio - 1).abs() > pi_clip).float().mean().cpu().numpy().item()
338
+ )
339
+ val_clipped_frac = (
340
+ (((v - orig_v).abs() > v_clip).float().mean().cpu().numpy().item())
341
+ if v_clip
342
+ else 0
343
+ )
344
+
345
+ return TrainStepStats(
346
+ loss.item(),
347
+ pi_loss.item(),
348
+ v_loss.item(),
349
+ entropy_loss.item(),
350
+ approx_kl,
351
+ clipped_frac,
352
+ val_clipped_frac,
353
+ )
rl_algo_impls/publish/markdown_format.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import wandb.apis.public
4
+ import yaml
5
+
6
+ from collections import defaultdict
7
+ from dataclasses import dataclass, asdict
8
+ from typing import Any, Dict, Iterable, List, NamedTuple, Optional, TypeVar
9
+ from urllib.parse import urlparse
10
+
11
+ from rl_algo_impls.runner.evaluate import Evaluation
12
+
13
+ EvaluationRowSelf = TypeVar("EvaluationRowSelf", bound="EvaluationRow")
14
+
15
+
16
+ @dataclass
17
+ class EvaluationRow:
18
+ algo: str
19
+ env: str
20
+ seed: Optional[int]
21
+ reward_mean: float
22
+ reward_std: float
23
+ eval_episodes: int
24
+ best: str
25
+ wandb_url: str
26
+
27
+ @staticmethod
28
+ def data_frame(rows: List[EvaluationRowSelf]) -> pd.DataFrame:
29
+ results = defaultdict(list)
30
+ for r in rows:
31
+ for k, v in asdict(r).items():
32
+ results[k].append(v)
33
+ return pd.DataFrame(results)
34
+
35
+
36
+ class EvalTableData(NamedTuple):
37
+ run: wandb.apis.public.Run
38
+ evaluation: Evaluation
39
+
40
+
41
+ def evaluation_table(table_data: Iterable[EvalTableData]) -> str:
42
+ best_stats = sorted(
43
+ [d.evaluation.stats for d in table_data], key=lambda r: r.score, reverse=True
44
+ )[0]
45
+ table_data = sorted(table_data, key=lambda d: d.evaluation.config.seed() or 0)
46
+ rows = [
47
+ EvaluationRow(
48
+ config.algo,
49
+ config.env_id,
50
+ config.seed(),
51
+ stats.score.mean,
52
+ stats.score.std,
53
+ len(stats),
54
+ "*" if stats == best_stats else "",
55
+ f"[wandb]({r.url})",
56
+ )
57
+ for (r, (_, stats, config)) in table_data
58
+ ]
59
+ df = EvaluationRow.data_frame(rows)
60
+ return df.to_markdown(index=False)
61
+
62
+
63
+ def github_project_link(github_url: str) -> str:
64
+ return f"[{urlparse(github_url).path}]({github_url})"
65
+
66
+
67
+ def header_section(algo: str, env: str, github_url: str, wandb_report_url: str) -> str:
68
+ algo_caps = algo.upper()
69
+ lines = [
70
+ f"# **{algo_caps}** Agent playing **{env}**",
71
+ f"This is a trained model of a **{algo_caps}** agent playing **{env}** using "
72
+ f"the {github_project_link(github_url)} repo.",
73
+ f"All models trained at this commit can be found at {wandb_report_url}.",
74
+ ]
75
+ return "\n\n".join(lines)
76
+
77
+
78
+ def github_tree_link(github_url: str, commit_hash: Optional[str]) -> str:
79
+ if not commit_hash:
80
+ return github_project_link(github_url)
81
+ return f"[{commit_hash[:7]}]({github_url}/tree/{commit_hash})"
82
+
83
+
84
+ def results_section(
85
+ table_data: List[EvalTableData], algo: str, github_url: str, commit_hash: str
86
+ ) -> str:
87
+ # type: ignore
88
+ lines = [
89
+ "## Training Results",
90
+ f"This model was trained from {len(table_data)} trainings of **{algo.upper()}** "
91
+ + "agents using different initial seeds. "
92
+ + f"These agents were trained by checking out "
93
+ + f"{github_tree_link(github_url, commit_hash)}. "
94
+ + "The best and last models were kept from each training. "
95
+ + "This submission has loaded the best models from each training, reevaluates "
96
+ + "them, and selects the best model from these latest evaluations (mean - std).",
97
+ ]
98
+ lines.append(evaluation_table(table_data))
99
+ return "\n\n".join(lines)
100
+
101
+
102
+ def prerequisites_section() -> str:
103
+ return """
104
+ ### Prerequisites: Weights & Biases (WandB)
105
+ Training and benchmarking assumes you have a Weights & Biases project to upload runs to.
106
+ By default training goes to a rl-algo-impls project while benchmarks go to
107
+ rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
108
+ models and the model weights are uploaded to WandB.
109
+
110
+ Before doing anything below, you'll need to create a wandb account and run `wandb
111
+ login`.
112
+ """
113
+
114
+
115
+ def usage_section(github_url: str, run_path: str, commit_hash: str) -> str:
116
+ return f"""
117
+ ## Usage
118
+ {urlparse(github_url).path}: {github_url}
119
+
120
+ Note: While the model state dictionary and hyperaparameters are saved, the latest
121
+ implementation could be sufficiently different to not be able to reproduce similar
122
+ results. You might need to checkout the commit the agent was trained on:
123
+ {github_tree_link(github_url, commit_hash)}.
124
+ ```
125
+ # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
126
+ python enjoy.py --wandb-run-path={run_path}
127
+ ```
128
+
129
+ Setup hasn't been completely worked out yet, so you might be best served by using Google
130
+ Colab starting from the
131
+ [colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
132
+ notebook.
133
+ """
134
+
135
+
136
+ def training_setion(
137
+ github_url: str, commit_hash: str, algo: str, env: str, seed: Optional[int]
138
+ ) -> str:
139
+ return f"""
140
+ ## Training
141
+ If you want the highest chance to reproduce these results, you'll want to checkout the
142
+ commit the agent was trained on: {github_tree_link(github_url, commit_hash)}. While
143
+ training is deterministic, different hardware will give different results.
144
+
145
+ ```
146
+ python train.py --algo {algo} --env {env} {'--seed ' + str(seed) if seed is not None else ''}
147
+ ```
148
+
149
+ Setup hasn't been completely worked out yet, so you might be best served by using Google
150
+ Colab starting from the
151
+ [colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
152
+ notebook.
153
+ """
154
+
155
+
156
+ def benchmarking_section(report_url: str) -> str:
157
+ return f"""
158
+ ## Benchmarking (with Lambda Labs instance)
159
+ This and other models from {report_url} were generated by running a script on a Lambda
160
+ Labs instance. In a Lambda Labs instance terminal:
161
+ ```
162
+ git clone git@github.com:sgoodfriend/rl-algo-impls.git
163
+ cd rl-algo-impls
164
+ bash ./lambda_labs/setup.sh
165
+ wandb login
166
+ bash ./lambda_labs/benchmark.sh [-a {{"ppo a2c dqn vpg"}}] [-e ENVS] [-j {{6}}] [-p {{rl-algo-impls-benchmarks}}] [-s {{"1 2 3"}}]
167
+ ```
168
+
169
+ ### Alternative: Google Colab Pro+
170
+ As an alternative,
171
+ [colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
172
+ can be used. However, this requires a Google Colab Pro+ subscription and running across
173
+ 4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
174
+ """
175
+
176
+
177
+ def hyperparams_section(run_config: Dict[str, Any]) -> str:
178
+ return f"""
179
+ ## Hyperparameters
180
+ This isn't exactly the format of hyperparams in {os.path.join("hyperparams",
181
+ run_config["algo"] + ".yml")}, but instead the Wandb Run Config. However, it's very
182
+ close and has some additional data:
183
+ ```
184
+ {yaml.dump(run_config)}
185
+ ```
186
+ """
187
+
188
+
189
+ def model_card_text(
190
+ algo: str,
191
+ env: str,
192
+ github_url: str,
193
+ commit_hash: str,
194
+ wandb_report_url: str,
195
+ table_data: List[EvalTableData],
196
+ best_eval: EvalTableData,
197
+ ) -> str:
198
+ run, (_, _, config) = best_eval
199
+ run_path = "/".join(run.path)
200
+ return "\n\n".join(
201
+ [
202
+ header_section(algo, env, github_url, wandb_report_url),
203
+ results_section(table_data, algo, github_url, commit_hash),
204
+ prerequisites_section(),
205
+ usage_section(github_url, run_path, commit_hash),
206
+ training_setion(github_url, commit_hash, algo, env, config.seed()),
207
+ benchmarking_section(wandb_report_url),
208
+ hyperparams_section(run.config),
209
+ ]
210
+ )
rl_algo_impls/runner/config.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import inspect
3
+ import itertools
4
+ import os
5
+
6
+ from datetime import datetime
7
+ from dataclasses import dataclass
8
+ from typing import Any, Dict, List, Optional, Type, TypeVar, Union
9
+
10
+
11
+ RunArgsSelf = TypeVar("RunArgsSelf", bound="RunArgs")
12
+
13
+
14
+ @dataclass
15
+ class RunArgs:
16
+ algo: str
17
+ env: str
18
+ seed: Optional[int] = None
19
+ use_deterministic_algorithms: bool = True
20
+
21
+ @classmethod
22
+ def expand_from_dict(
23
+ cls: Type[RunArgsSelf], d: Dict[str, Any]
24
+ ) -> List[RunArgsSelf]:
25
+ maybe_listify = lambda v: [v] if isinstance(v, str) or isinstance(v, int) else v
26
+ algos = maybe_listify(d["algo"])
27
+ envs = maybe_listify(d["env"])
28
+ seeds = maybe_listify(d["seed"])
29
+ args = []
30
+ for algo, env, seed in itertools.product(algos, envs, seeds):
31
+ _d = d.copy()
32
+ _d.update({"algo": algo, "env": env, "seed": seed})
33
+ args.append(cls(**_d))
34
+ return args
35
+
36
+
37
+ @dataclass
38
+ class EnvHyperparams:
39
+ env_type: str = "gymvec"
40
+ n_envs: int = 1
41
+ frame_stack: int = 1
42
+ make_kwargs: Optional[Dict[str, Any]] = None
43
+ no_reward_timeout_steps: Optional[int] = None
44
+ no_reward_fire_steps: Optional[int] = None
45
+ vec_env_class: str = "sync"
46
+ normalize: bool = False
47
+ normalize_kwargs: Optional[Dict[str, Any]] = None
48
+ rolling_length: int = 100
49
+ train_record_video: bool = False
50
+ video_step_interval: Union[int, float] = 1_000_000
51
+ initial_steps_to_truncate: Optional[int] = None
52
+ clip_atari_rewards: bool = True
53
+
54
+
55
+ HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
56
+
57
+
58
+ @dataclass
59
+ class Hyperparams:
60
+ device: str = "auto"
61
+ n_timesteps: Union[int, float] = 100_000
62
+ env_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
63
+ policy_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
64
+ algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
65
+ eval_params: Dict[str, Any] = dataclasses.field(default_factory=dict)
66
+ env_id: Optional[str] = None
67
+
68
+ @classmethod
69
+ def from_dict_with_extra_fields(
70
+ cls: Type[HyperparamsSelf], d: Dict[str, Any]
71
+ ) -> HyperparamsSelf:
72
+ return cls(
73
+ **{k: v for k, v in d.items() if k in inspect.signature(cls).parameters}
74
+ )
75
+
76
+
77
+ @dataclass
78
+ class Config:
79
+ args: RunArgs
80
+ hyperparams: Hyperparams
81
+ root_dir: str
82
+ run_id: str = datetime.now().isoformat()
83
+
84
+ def seed(self, training: bool = True) -> Optional[int]:
85
+ seed = self.args.seed
86
+ if training or seed is None:
87
+ return seed
88
+ return seed + self.env_hyperparams.get("n_envs", 1)
89
+
90
+ @property
91
+ def device(self) -> str:
92
+ return self.hyperparams.device
93
+
94
+ @property
95
+ def n_timesteps(self) -> int:
96
+ return int(self.hyperparams.n_timesteps)
97
+
98
+ @property
99
+ def env_hyperparams(self) -> Dict[str, Any]:
100
+ return self.hyperparams.env_hyperparams
101
+
102
+ @property
103
+ def policy_hyperparams(self) -> Dict[str, Any]:
104
+ return self.hyperparams.policy_hyperparams
105
+
106
+ @property
107
+ def algo_hyperparams(self) -> Dict[str, Any]:
108
+ return self.hyperparams.algo_hyperparams
109
+
110
+ @property
111
+ def eval_params(self) -> Dict[str, Any]:
112
+ return self.hyperparams.eval_params
113
+
114
+ @property
115
+ def algo(self) -> str:
116
+ return self.args.algo
117
+
118
+ @property
119
+ def env_id(self) -> str:
120
+ return self.hyperparams.env_id or self.args.env
121
+
122
+ def model_name(self, include_seed: bool = True) -> str:
123
+ # Use arg env name instead of environment name
124
+ parts = [self.algo, self.args.env]
125
+ if include_seed and self.args.seed is not None:
126
+ parts.append(f"S{self.args.seed}")
127
+
128
+ # Assume that the custom arg name already has the necessary information
129
+ if not self.hyperparams.env_id:
130
+ make_kwargs = self.env_hyperparams.get("make_kwargs", {})
131
+ if make_kwargs:
132
+ for k, v in make_kwargs.items():
133
+ if type(v) == bool and v:
134
+ parts.append(k)
135
+ elif type(v) == int and v:
136
+ parts.append(f"{k}{v}")
137
+ else:
138
+ parts.append(str(v))
139
+
140
+ return "-".join(parts)
141
+
142
+ def run_name(self, include_seed: bool = True) -> str:
143
+ parts = [self.model_name(include_seed=include_seed), self.run_id]
144
+ return "-".join(parts)
145
+
146
+ @property
147
+ def saved_models_dir(self) -> str:
148
+ return os.path.join(self.root_dir, "saved_models")
149
+
150
+ @property
151
+ def downloaded_models_dir(self) -> str:
152
+ return os.path.join(self.root_dir, "downloaded_models")
153
+
154
+ def model_dir_name(
155
+ self,
156
+ best: bool = False,
157
+ extension: str = "",
158
+ ) -> str:
159
+ return self.model_name() + ("-best" if best else "") + extension
160
+
161
+ def model_dir_path(self, best: bool = False, downloaded: bool = False) -> str:
162
+ return os.path.join(
163
+ self.saved_models_dir if not downloaded else self.downloaded_models_dir,
164
+ self.model_dir_name(best=best),
165
+ )
166
+
167
+ @property
168
+ def runs_dir(self) -> str:
169
+ return os.path.join(self.root_dir, "runs")
170
+
171
+ @property
172
+ def tensorboard_summary_path(self) -> str:
173
+ return os.path.join(self.runs_dir, self.run_name())
174
+
175
+ @property
176
+ def logs_path(self) -> str:
177
+ return os.path.join(self.runs_dir, f"log.yml")
178
+
179
+ @property
180
+ def videos_dir(self) -> str:
181
+ return os.path.join(self.root_dir, "videos")
182
+
183
+ @property
184
+ def video_prefix(self) -> str:
185
+ return os.path.join(self.videos_dir, self.model_name())
186
+
187
+ @property
188
+ def best_videos_dir(self) -> str:
189
+ return os.path.join(self.videos_dir, f"{self.model_name()}-best")
rl_algo_impls/runner/env.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import numpy as np
3
+ import os
4
+
5
+ from dataclasses import asdict, astuple
6
+ from gym.vector.async_vector_env import AsyncVectorEnv
7
+ from gym.vector.sync_vector_env import SyncVectorEnv
8
+ from gym.wrappers.resize_observation import ResizeObservation
9
+ from gym.wrappers.gray_scale_observation import GrayScaleObservation
10
+ from gym.wrappers.frame_stack import FrameStack
11
+ from stable_baselines3.common.atari_wrappers import (
12
+ MaxAndSkipEnv,
13
+ NoopResetEnv,
14
+ )
15
+ from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
16
+ from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
17
+ from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
18
+ from torch.utils.tensorboard.writer import SummaryWriter
19
+ from typing import Callable, Optional
20
+
21
+ from rl_algo_impls.runner.config import Config, EnvHyperparams
22
+ from rl_algo_impls.shared.policy.policy import VEC_NORMALIZE_FILENAME
23
+ from rl_algo_impls.wrappers.atari_wrappers import (
24
+ EpisodicLifeEnv,
25
+ FireOnLifeStarttEnv,
26
+ ClipRewardEnv,
27
+ )
28
+ from rl_algo_impls.wrappers.episode_record_video import EpisodeRecordVideo
29
+ from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
30
+ from rl_algo_impls.wrappers.initial_step_truncate_wrapper import (
31
+ InitialStepTruncateWrapper,
32
+ )
33
+ from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
34
+ from rl_algo_impls.wrappers.no_reward_timeout import NoRewardTimeout
35
+ from rl_algo_impls.wrappers.noop_env_seed import NoopEnvSeed
36
+ from rl_algo_impls.wrappers.normalize import NormalizeObservation, NormalizeReward
37
+ from rl_algo_impls.wrappers.sync_vector_env_render_compat import (
38
+ SyncVectorEnvRenderCompat,
39
+ )
40
+ from rl_algo_impls.wrappers.transpose_image_observation import TransposeImageObservation
41
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
42
+ from rl_algo_impls.wrappers.video_compat_wrapper import VideoCompatWrapper
43
+
44
+
45
+ def make_env(
46
+ config: Config,
47
+ hparams: EnvHyperparams,
48
+ training: bool = True,
49
+ render: bool = False,
50
+ normalize_load_path: Optional[str] = None,
51
+ tb_writer: Optional[SummaryWriter] = None,
52
+ ) -> VecEnv:
53
+ if hparams.env_type == "procgen":
54
+ return _make_procgen_env(
55
+ config,
56
+ hparams,
57
+ training=training,
58
+ render=render,
59
+ normalize_load_path=normalize_load_path,
60
+ tb_writer=tb_writer,
61
+ )
62
+ elif hparams.env_type in {"sb3vec", "gymvec"}:
63
+ return _make_vec_env(
64
+ config,
65
+ hparams,
66
+ training=training,
67
+ render=render,
68
+ normalize_load_path=normalize_load_path,
69
+ tb_writer=tb_writer,
70
+ )
71
+ else:
72
+ raise ValueError(f"env_type {hparams.env_type} not supported")
73
+
74
+
75
+ def make_eval_env(
76
+ config: Config,
77
+ hparams: EnvHyperparams,
78
+ override_n_envs: Optional[int] = None,
79
+ **kwargs,
80
+ ) -> VecEnv:
81
+ kwargs = kwargs.copy()
82
+ kwargs["training"] = False
83
+ if override_n_envs is not None:
84
+ hparams_kwargs = asdict(hparams)
85
+ hparams_kwargs["n_envs"] = override_n_envs
86
+ if override_n_envs == 1:
87
+ hparams_kwargs["vec_env_class"] = "sync"
88
+ hparams = EnvHyperparams(**hparams_kwargs)
89
+ return make_env(config, hparams, **kwargs)
90
+
91
+
92
+ def _make_vec_env(
93
+ config: Config,
94
+ hparams: EnvHyperparams,
95
+ training: bool = True,
96
+ render: bool = False,
97
+ normalize_load_path: Optional[str] = None,
98
+ tb_writer: Optional[SummaryWriter] = None,
99
+ ) -> VecEnv:
100
+ (
101
+ env_type,
102
+ n_envs,
103
+ frame_stack,
104
+ make_kwargs,
105
+ no_reward_timeout_steps,
106
+ no_reward_fire_steps,
107
+ vec_env_class,
108
+ normalize,
109
+ normalize_kwargs,
110
+ rolling_length,
111
+ train_record_video,
112
+ video_step_interval,
113
+ initial_steps_to_truncate,
114
+ clip_atari_rewards,
115
+ ) = astuple(hparams)
116
+
117
+ if "BulletEnv" in config.env_id:
118
+ import pybullet_envs
119
+
120
+ spec = gym.spec(config.env_id)
121
+ seed = config.seed(training=training)
122
+
123
+ make_kwargs = make_kwargs.copy() if make_kwargs is not None else {}
124
+ if "BulletEnv" in config.env_id and render:
125
+ make_kwargs["render"] = True
126
+ if "CarRacing" in config.env_id:
127
+ make_kwargs["verbose"] = 0
128
+ if "procgen" in config.env_id:
129
+ if not render:
130
+ make_kwargs["render_mode"] = "rgb_array"
131
+
132
+ def make(idx: int) -> Callable[[], gym.Env]:
133
+ def _make() -> gym.Env:
134
+ env = gym.make(config.env_id, **make_kwargs)
135
+ env = gym.wrappers.RecordEpisodeStatistics(env)
136
+ env = VideoCompatWrapper(env)
137
+ if training and train_record_video and idx == 0:
138
+ env = EpisodeRecordVideo(
139
+ env,
140
+ config.video_prefix,
141
+ step_increment=n_envs,
142
+ video_step_interval=int(video_step_interval),
143
+ )
144
+ if training and initial_steps_to_truncate:
145
+ env = InitialStepTruncateWrapper(
146
+ env, idx * initial_steps_to_truncate // n_envs
147
+ )
148
+ if "AtariEnv" in spec.entry_point: # type: ignore
149
+ env = NoopResetEnv(env, noop_max=30)
150
+ env = MaxAndSkipEnv(env, skip=4)
151
+ env = EpisodicLifeEnv(env, training=training)
152
+ action_meanings = env.unwrapped.get_action_meanings()
153
+ if "FIRE" in action_meanings: # type: ignore
154
+ env = FireOnLifeStarttEnv(env, action_meanings.index("FIRE"))
155
+ if clip_atari_rewards:
156
+ env = ClipRewardEnv(env, training=training)
157
+ env = ResizeObservation(env, (84, 84))
158
+ env = GrayScaleObservation(env, keep_dim=False)
159
+ env = FrameStack(env, frame_stack)
160
+ elif "CarRacing" in config.env_id:
161
+ env = ResizeObservation(env, (64, 64))
162
+ env = GrayScaleObservation(env, keep_dim=False)
163
+ env = FrameStack(env, frame_stack)
164
+ elif "procgen" in config.env_id:
165
+ # env = GrayScaleObservation(env, keep_dim=False)
166
+ env = NoopEnvSeed(env)
167
+ env = TransposeImageObservation(env)
168
+ if frame_stack > 1:
169
+ env = FrameStack(env, frame_stack)
170
+
171
+ if no_reward_timeout_steps:
172
+ env = NoRewardTimeout(
173
+ env, no_reward_timeout_steps, n_fire_steps=no_reward_fire_steps
174
+ )
175
+
176
+ if seed is not None:
177
+ env.seed(seed + idx)
178
+ env.action_space.seed(seed + idx)
179
+ env.observation_space.seed(seed + idx)
180
+
181
+ return env
182
+
183
+ return _make
184
+
185
+ if env_type == "sb3vec":
186
+ VecEnvClass = {"sync": DummyVecEnv, "async": SubprocVecEnv}[vec_env_class]
187
+ elif env_type == "gymvec":
188
+ VecEnvClass = {"sync": SyncVectorEnv, "async": AsyncVectorEnv}[vec_env_class]
189
+ else:
190
+ raise ValueError(f"env_type {env_type} unsupported")
191
+ envs = VecEnvClass([make(i) for i in range(n_envs)])
192
+ if env_type == "gymvec" and vec_env_class == "sync":
193
+ envs = SyncVectorEnvRenderCompat(envs)
194
+ if training:
195
+ assert tb_writer
196
+ envs = EpisodeStatsWriter(
197
+ envs, tb_writer, training=training, rolling_length=rolling_length
198
+ )
199
+ if normalize:
200
+ normalize_kwargs = normalize_kwargs or {}
201
+ if env_type == "sb3vec":
202
+ if normalize_load_path:
203
+ envs = VecNormalize.load(
204
+ os.path.join(normalize_load_path, VEC_NORMALIZE_FILENAME),
205
+ envs, # type: ignore
206
+ )
207
+ else:
208
+ envs = VecNormalize(
209
+ envs, # type: ignore
210
+ training=training,
211
+ **normalize_kwargs,
212
+ )
213
+ if not training:
214
+ envs.norm_reward = False
215
+ else:
216
+ if normalize_kwargs.get("norm_obs", True):
217
+ envs = NormalizeObservation(
218
+ envs, training=training, clip=normalize_kwargs.get("clip_obs", 10.0)
219
+ )
220
+ if training and normalize_kwargs.get("norm_reward", True):
221
+ envs = NormalizeReward(
222
+ envs,
223
+ training=training,
224
+ clip=normalize_kwargs.get("clip_reward", 10.0),
225
+ )
226
+ return envs
227
+
228
+
229
+ def _make_procgen_env(
230
+ config: Config,
231
+ hparams: EnvHyperparams,
232
+ training: bool = True,
233
+ render: bool = False,
234
+ normalize_load_path: Optional[str] = None,
235
+ tb_writer: Optional[SummaryWriter] = None,
236
+ ) -> VecEnv:
237
+ from gym3 import ViewerWrapper, ExtractDictObWrapper
238
+ from procgen.env import ProcgenGym3Env, ToBaselinesVecEnv
239
+
240
+ (
241
+ _, # env_type
242
+ n_envs,
243
+ _, # frame_stack
244
+ make_kwargs,
245
+ _, # no_reward_timeout_steps
246
+ _, # no_reward_fire_steps
247
+ _, # vec_env_class
248
+ normalize,
249
+ normalize_kwargs,
250
+ rolling_length,
251
+ _, # train_record_video
252
+ _, # video_step_interval
253
+ _, # initial_steps_to_truncate
254
+ _, # clip_atari_rewards
255
+ ) = astuple(hparams)
256
+
257
+ seed = config.seed(training=training)
258
+
259
+ make_kwargs = make_kwargs or {}
260
+ make_kwargs["render_mode"] = "rgb_array"
261
+ if seed is not None:
262
+ make_kwargs["rand_seed"] = seed
263
+
264
+ envs = ProcgenGym3Env(n_envs, config.env_id, **make_kwargs)
265
+ envs = ExtractDictObWrapper(envs, key="rgb")
266
+ if render:
267
+ envs = ViewerWrapper(envs, info_key="rgb")
268
+ envs = ToBaselinesVecEnv(envs)
269
+ envs = IsVectorEnv(envs)
270
+ # TODO: Handle Grayscale and/or FrameStack
271
+ envs = TransposeImageObservation(envs)
272
+
273
+ envs = gym.wrappers.RecordEpisodeStatistics(envs)
274
+
275
+ if seed is not None:
276
+ envs.action_space.seed(seed)
277
+ envs.observation_space.seed(seed)
278
+
279
+ if training:
280
+ assert tb_writer
281
+ envs = EpisodeStatsWriter(
282
+ envs, tb_writer, training=training, rolling_length=rolling_length
283
+ )
284
+ if normalize and training:
285
+ normalize_kwargs = normalize_kwargs or {}
286
+ envs = gym.wrappers.NormalizeReward(envs)
287
+ clip_obs = normalize_kwargs.get("clip_reward", 10.0)
288
+ envs = gym.wrappers.TransformReward(
289
+ envs, lambda r: np.clip(r, -clip_obs, clip_obs)
290
+ )
291
+
292
+ return envs # type: ignore
rl_algo_impls/runner/evaluate.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+
4
+ from dataclasses import dataclass
5
+ from typing import NamedTuple, Optional
6
+
7
+ from rl_algo_impls.runner.env import make_eval_env
8
+ from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
9
+ from rl_algo_impls.runner.running_utils import (
10
+ load_hyperparams,
11
+ set_seeds,
12
+ get_device,
13
+ make_policy,
14
+ )
15
+ from rl_algo_impls.shared.callbacks.eval_callback import evaluate
16
+ from rl_algo_impls.shared.policy.policy import Policy
17
+ from rl_algo_impls.shared.stats import EpisodesStats
18
+
19
+
20
+ @dataclass
21
+ class EvalArgs(RunArgs):
22
+ render: bool = True
23
+ best: bool = True
24
+ n_envs: Optional[int] = 1
25
+ n_episodes: int = 3
26
+ deterministic_eval: Optional[bool] = None
27
+ no_print_returns: bool = False
28
+ wandb_run_path: Optional[str] = None
29
+
30
+
31
+ class Evaluation(NamedTuple):
32
+ policy: Policy
33
+ stats: EpisodesStats
34
+ config: Config
35
+
36
+
37
+ def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
38
+ if args.wandb_run_path:
39
+ import wandb
40
+
41
+ api = wandb.Api()
42
+ run = api.run(args.wandb_run_path)
43
+ params = run.config
44
+
45
+ args.algo = params["algo"]
46
+ args.env = params["env"]
47
+ args.seed = params.get("seed", None)
48
+ args.use_deterministic_algorithms = params.get(
49
+ "use_deterministic_algorithms", True
50
+ )
51
+
52
+ config = Config(args, Hyperparams.from_dict_with_extra_fields(params), root_dir)
53
+ model_path = config.model_dir_path(best=args.best, downloaded=True)
54
+
55
+ model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
56
+ run.file(model_archive_name).download()
57
+ if os.path.isdir(model_path):
58
+ shutil.rmtree(model_path)
59
+ shutil.unpack_archive(model_archive_name, model_path)
60
+ os.remove(model_archive_name)
61
+ else:
62
+ hyperparams = load_hyperparams(args.algo, args.env)
63
+
64
+ config = Config(args, hyperparams, root_dir)
65
+ model_path = config.model_dir_path(best=args.best)
66
+
67
+ print(args)
68
+
69
+ set_seeds(args.seed, args.use_deterministic_algorithms)
70
+
71
+ env = make_eval_env(
72
+ config,
73
+ EnvHyperparams(**config.env_hyperparams),
74
+ override_n_envs=args.n_envs,
75
+ render=args.render,
76
+ normalize_load_path=model_path,
77
+ )
78
+ device = get_device(config.device, env)
79
+ policy = make_policy(
80
+ args.algo,
81
+ env,
82
+ device,
83
+ load_path=model_path,
84
+ **config.policy_hyperparams,
85
+ ).eval()
86
+
87
+ deterministic = (
88
+ args.deterministic_eval
89
+ if args.deterministic_eval is not None
90
+ else config.eval_params.get("deterministic", True)
91
+ )
92
+ return Evaluation(
93
+ policy,
94
+ evaluate(
95
+ env,
96
+ policy,
97
+ args.n_episodes,
98
+ render=args.render,
99
+ deterministic=deterministic,
100
+ print_returns=not args.no_print_returns,
101
+ ),
102
+ config,
103
+ )
rl_algo_impls/runner/running_utils.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gym
3
+ import json
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import os
7
+ import random
8
+ import torch
9
+ import torch.backends.cudnn
10
+ import yaml
11
+
12
+ from dataclasses import asdict
13
+ from gym.spaces import Box, Discrete
14
+ from pathlib import Path
15
+ from torch.utils.tensorboard.writer import SummaryWriter
16
+ from typing import Dict, Optional, Type, Union
17
+
18
+ from rl_algo_impls.runner.config import Hyperparams
19
+ from rl_algo_impls.shared.algorithm import Algorithm
20
+ from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
21
+ from rl_algo_impls.shared.policy.on_policy import ActorCritic
22
+ from rl_algo_impls.shared.policy.policy import Policy
23
+
24
+ from rl_algo_impls.a2c.a2c import A2C
25
+ from rl_algo_impls.dqn.dqn import DQN
26
+ from rl_algo_impls.dqn.policy import DQNPolicy
27
+ from rl_algo_impls.ppo.ppo import PPO
28
+ from rl_algo_impls.vpg.vpg import VanillaPolicyGradient
29
+ from rl_algo_impls.vpg.policy import VPGActorCritic
30
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, single_observation_space
31
+
32
+ ALGOS: Dict[str, Type[Algorithm]] = {
33
+ "dqn": DQN,
34
+ "vpg": VanillaPolicyGradient,
35
+ "ppo": PPO,
36
+ "a2c": A2C,
37
+ }
38
+ POLICIES: Dict[str, Type[Policy]] = {
39
+ "dqn": DQNPolicy,
40
+ "vpg": VPGActorCritic,
41
+ "ppo": ActorCritic,
42
+ "a2c": ActorCritic,
43
+ }
44
+
45
+ HYPERPARAMS_PATH = "hyperparams"
46
+
47
+
48
+ def base_parser(multiple: bool = True) -> argparse.ArgumentParser:
49
+ parser = argparse.ArgumentParser()
50
+ parser.add_argument(
51
+ "--algo",
52
+ default=["dqn"],
53
+ type=str,
54
+ choices=list(ALGOS.keys()),
55
+ nargs="+" if multiple else 1,
56
+ help="Abbreviation(s) of algorithm(s)",
57
+ )
58
+ parser.add_argument(
59
+ "--env",
60
+ default=["CartPole-v1"],
61
+ type=str,
62
+ nargs="+" if multiple else 1,
63
+ help="Name of environment(s) in gym",
64
+ )
65
+ parser.add_argument(
66
+ "--seed",
67
+ default=[1],
68
+ type=int,
69
+ nargs="*" if multiple else "?",
70
+ help="Seeds to run experiment. Unset will do one run with no set seed",
71
+ )
72
+ return parser
73
+
74
+
75
+ def load_hyperparams(algo: str, env_id: str) -> Hyperparams:
76
+ root_path = Path(__file__).parent.parent
77
+ hyperparams_path = os.path.join(root_path, HYPERPARAMS_PATH, f"{algo}.yml")
78
+ with open(hyperparams_path, "r") as f:
79
+ hyperparams_dict = yaml.safe_load(f)
80
+
81
+ if env_id in hyperparams_dict:
82
+ return Hyperparams(**hyperparams_dict[env_id])
83
+
84
+ if "BulletEnv" in env_id:
85
+ import pybullet_envs
86
+ spec = gym.spec(env_id)
87
+ if "AtariEnv" in str(spec.entry_point) and "_atari" in hyperparams_dict:
88
+ return Hyperparams(**hyperparams_dict["_atari"])
89
+ else:
90
+ raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
91
+
92
+
93
+ def get_device(device: str, env: VecEnv) -> torch.device:
94
+ # cuda by default
95
+ if device == "auto":
96
+ device = "cuda"
97
+ # Apple MPS is a second choice (sometimes)
98
+ if device == "cuda" and not torch.cuda.is_available():
99
+ device = "mps"
100
+ # If no MPS, fallback to cpu
101
+ if device == "mps" and not torch.backends.mps.is_available():
102
+ device = "cpu"
103
+ # Simple environments like Discreet and 1-D Boxes might also be better
104
+ # served with the CPU.
105
+ if device == "mps":
106
+ obs_space = single_observation_space(env)
107
+ if isinstance(obs_space, Discrete):
108
+ device = "cpu"
109
+ elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
110
+ device = "cpu"
111
+ print(f"Device: {device}")
112
+ return torch.device(device)
113
+
114
+
115
+ def set_seeds(seed: Optional[int], use_deterministic_algorithms: bool) -> None:
116
+ if seed is None:
117
+ return
118
+ random.seed(seed)
119
+ np.random.seed(seed)
120
+ torch.manual_seed(seed)
121
+ torch.backends.cudnn.benchmark = False
122
+ torch.use_deterministic_algorithms(use_deterministic_algorithms)
123
+ os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
124
+ # Stop warning and it would introduce stochasticity if I was using TF
125
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
126
+
127
+
128
+ def make_policy(
129
+ algo: str,
130
+ env: VecEnv,
131
+ device: torch.device,
132
+ load_path: Optional[str] = None,
133
+ **kwargs,
134
+ ) -> Policy:
135
+ policy = POLICIES[algo](env, **kwargs).to(device)
136
+ if load_path:
137
+ policy.load(load_path)
138
+ return policy
139
+
140
+
141
+ def plot_eval_callback(callback: EvalCallback, tb_writer: SummaryWriter, run_name: str):
142
+ figure = plt.figure()
143
+ cumulative_steps = [
144
+ (idx + 1) * callback.step_freq for idx in range(len(callback.stats))
145
+ ]
146
+ plt.plot(
147
+ cumulative_steps,
148
+ [s.score.mean for s in callback.stats],
149
+ "b-",
150
+ label="mean",
151
+ )
152
+ plt.plot(
153
+ cumulative_steps,
154
+ [s.score.mean - s.score.std for s in callback.stats],
155
+ "g--",
156
+ label="mean-std",
157
+ )
158
+ plt.fill_between(
159
+ cumulative_steps,
160
+ [s.score.min for s in callback.stats], # type: ignore
161
+ [s.score.max for s in callback.stats], # type: ignore
162
+ facecolor="cyan",
163
+ label="range",
164
+ )
165
+ plt.xlabel("Steps")
166
+ plt.ylabel("Score")
167
+ plt.legend()
168
+ plt.title(f"Eval {run_name}")
169
+ tb_writer.add_figure("eval", figure)
170
+
171
+
172
+ Scalar = Union[bool, str, float, int, None]
173
+
174
+
175
+ def hparam_dict(
176
+ hyperparams: Hyperparams, args: Dict[str, Union[Scalar, list]]
177
+ ) -> Dict[str, Scalar]:
178
+ flattened = args.copy()
179
+ for k, v in flattened.items():
180
+ if isinstance(v, list):
181
+ flattened[k] = json.dumps(v)
182
+ for k, v in asdict(hyperparams).items():
183
+ if isinstance(v, dict):
184
+ for sk, sv in v.items():
185
+ key = f"{k}/{sk}"
186
+ if isinstance(sv, dict) or isinstance(sv, list):
187
+ flattened[key] = str(sv)
188
+ else:
189
+ flattened[key] = sv
190
+ else:
191
+ flattened[k] = v # type: ignore
192
+ return flattened # type: ignore
rl_algo_impls/runner/train.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
2
+ import os
3
+
4
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
+
6
+ import dataclasses
7
+ import shutil
8
+ import wandb
9
+ import yaml
10
+
11
+ from dataclasses import asdict, dataclass
12
+ from torch.utils.tensorboard.writer import SummaryWriter
13
+ from typing import Any, Dict, Optional, Sequence
14
+
15
+ from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
16
+ from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
17
+ from rl_algo_impls.runner.env import make_env, make_eval_env
18
+ from rl_algo_impls.runner.running_utils import (
19
+ ALGOS,
20
+ load_hyperparams,
21
+ set_seeds,
22
+ get_device,
23
+ make_policy,
24
+ plot_eval_callback,
25
+ hparam_dict,
26
+ )
27
+ from rl_algo_impls.shared.stats import EpisodesStats
28
+
29
+
30
+ @dataclass
31
+ class TrainArgs(RunArgs):
32
+ wandb_project_name: Optional[str] = None
33
+ wandb_entity: Optional[str] = None
34
+ wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
35
+ wandb_group: Optional[str] = None
36
+
37
+
38
+ def train(args: TrainArgs):
39
+ print(args)
40
+ hyperparams = load_hyperparams(args.algo, args.env)
41
+ print(hyperparams)
42
+ config = Config(args, hyperparams, os.getcwd())
43
+
44
+ wandb_enabled = args.wandb_project_name
45
+ if wandb_enabled:
46
+ wandb.tensorboard.patch(
47
+ root_logdir=config.tensorboard_summary_path, pytorch=True
48
+ )
49
+ wandb.init(
50
+ project=args.wandb_project_name,
51
+ entity=args.wandb_entity,
52
+ config=asdict(hyperparams),
53
+ name=config.run_name(),
54
+ monitor_gym=True,
55
+ save_code=True,
56
+ tags=args.wandb_tags,
57
+ group=args.wandb_group,
58
+ )
59
+ wandb.config.update(args)
60
+
61
+ tb_writer = SummaryWriter(config.tensorboard_summary_path)
62
+
63
+ set_seeds(args.seed, args.use_deterministic_algorithms)
64
+
65
+ env = make_env(
66
+ config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
67
+ )
68
+ device = get_device(config.device, env)
69
+ policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
70
+ algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
71
+
72
+ num_parameters = policy.num_parameters()
73
+ num_trainable_parameters = policy.num_trainable_parameters()
74
+ if wandb_enabled:
75
+ wandb.run.summary["num_parameters"] = num_parameters
76
+ wandb.run.summary["num_trainable_parameters"] = num_trainable_parameters
77
+ else:
78
+ print(
79
+ f"num_parameters = {num_parameters} ; "
80
+ f"num_trainable_parameters = {num_trainable_parameters}"
81
+ )
82
+
83
+ eval_env = make_eval_env(config, EnvHyperparams(**config.env_hyperparams))
84
+ record_best_videos = config.eval_params.get("record_best_videos", True)
85
+ callback = EvalCallback(
86
+ policy,
87
+ eval_env,
88
+ tb_writer,
89
+ best_model_path=config.model_dir_path(best=True),
90
+ **config.eval_params,
91
+ video_env=make_eval_env(
92
+ config, EnvHyperparams(**config.env_hyperparams), override_n_envs=1
93
+ )
94
+ if record_best_videos
95
+ else None,
96
+ best_video_dir=config.best_videos_dir,
97
+ )
98
+ algo.learn(config.n_timesteps, callback=callback)
99
+
100
+ policy.save(config.model_dir_path(best=False))
101
+
102
+ eval_stats = callback.evaluate(n_episodes=10, print_returns=True)
103
+
104
+ plot_eval_callback(callback, tb_writer, config.run_name())
105
+
106
+ log_dict: Dict[str, Any] = {
107
+ "eval": eval_stats._asdict(),
108
+ }
109
+ if callback.best:
110
+ log_dict["best_eval"] = callback.best._asdict()
111
+ log_dict.update(asdict(hyperparams))
112
+ log_dict.update(vars(args))
113
+ with open(config.logs_path, "a") as f:
114
+ yaml.dump({config.run_name(): log_dict}, f)
115
+
116
+ best_eval_stats: EpisodesStats = callback.best # type: ignore
117
+ tb_writer.add_hparams(
118
+ hparam_dict(hyperparams, vars(args)),
119
+ {
120
+ "hparam/best_mean": best_eval_stats.score.mean,
121
+ "hparam/best_result": best_eval_stats.score.mean
122
+ - best_eval_stats.score.std,
123
+ "hparam/last_mean": eval_stats.score.mean,
124
+ "hparam/last_result": eval_stats.score.mean - eval_stats.score.std,
125
+ },
126
+ None,
127
+ config.run_name(),
128
+ )
129
+
130
+ tb_writer.close()
131
+
132
+ if wandb_enabled:
133
+ shutil.make_archive(
134
+ os.path.join(wandb.run.dir, config.model_dir_name()),
135
+ "zip",
136
+ config.model_dir_path(),
137
+ )
138
+ shutil.make_archive(
139
+ os.path.join(wandb.run.dir, config.model_dir_name(best=True)),
140
+ "zip",
141
+ config.model_dir_path(best=True),
142
+ )
143
+ wandb.finish()
rl_algo_impls/shared/algorithm.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import torch
3
+
4
+ from abc import ABC, abstractmethod
5
+ from torch.utils.tensorboard.writer import SummaryWriter
6
+ from typing import Optional, TypeVar
7
+
8
+ from rl_algo_impls.shared.callbacks.callback import Callback
9
+ from rl_algo_impls.shared.policy.policy import Policy
10
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
11
+
12
+ AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")
13
+
14
+
15
+ class Algorithm(ABC):
16
+ @abstractmethod
17
+ def __init__(
18
+ self,
19
+ policy: Policy,
20
+ env: VecEnv,
21
+ device: torch.device,
22
+ tb_writer: SummaryWriter,
23
+ **kwargs,
24
+ ) -> None:
25
+ super().__init__()
26
+ self.policy = policy
27
+ self.env = env
28
+ self.device = device
29
+ self.tb_writer = tb_writer
30
+
31
+ @abstractmethod
32
+ def learn(
33
+ self: AlgorithmSelf,
34
+ train_timesteps: int,
35
+ callback: Optional[Callback] = None,
36
+ total_timesteps: Optional[int] = None,
37
+ start_timesteps: int = 0,
38
+ ) -> AlgorithmSelf:
39
+ ...
rl_algo_impls/shared/callbacks/callback.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+
3
+
4
+ class Callback(ABC):
5
+ def __init__(self) -> None:
6
+ super().__init__()
7
+ self.timesteps_elapsed = 0
8
+
9
+ def on_step(self, timesteps_elapsed: int = 1) -> bool:
10
+ self.timesteps_elapsed += timesteps_elapsed
11
+ return True
rl_algo_impls/shared/callbacks/eval_callback.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import numpy as np
3
+ import os
4
+
5
+ from time import perf_counter
6
+ from torch.utils.tensorboard.writer import SummaryWriter
7
+ from typing import List, Optional, Union
8
+
9
+ from rl_algo_impls.shared.callbacks.callback import Callback
10
+ from rl_algo_impls.shared.policy.policy import Policy
11
+ from rl_algo_impls.shared.stats import Episode, EpisodeAccumulator, EpisodesStats
12
+ from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
13
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
14
+
15
+
16
+ class EvaluateAccumulator(EpisodeAccumulator):
17
+ def __init__(
18
+ self,
19
+ num_envs: int,
20
+ goal_episodes: int,
21
+ print_returns: bool = True,
22
+ ignore_first_episode: bool = False,
23
+ ):
24
+ super().__init__(num_envs)
25
+ self.completed_episodes_by_env_idx = [[] for _ in range(num_envs)]
26
+ self.goal_episodes_per_env = int(np.ceil(goal_episodes / num_envs))
27
+ self.print_returns = print_returns
28
+ if ignore_first_episode:
29
+ first_done = set()
30
+
31
+ def should_record_done(idx: int) -> bool:
32
+ has_done_first_episode = idx in first_done
33
+ first_done.add(idx)
34
+ return has_done_first_episode
35
+
36
+ self.should_record_done = should_record_done
37
+ else:
38
+ self.should_record_done = lambda idx: True
39
+
40
+ def on_done(self, ep_idx: int, episode: Episode) -> None:
41
+ if (
42
+ self.should_record_done(ep_idx)
43
+ and len(self.completed_episodes_by_env_idx[ep_idx])
44
+ >= self.goal_episodes_per_env
45
+ ):
46
+ return
47
+ self.completed_episodes_by_env_idx[ep_idx].append(episode)
48
+ if self.print_returns:
49
+ print(
50
+ f"Episode {len(self)} | "
51
+ f"Score {episode.score} | "
52
+ f"Length {episode.length}"
53
+ )
54
+
55
+ def __len__(self) -> int:
56
+ return sum(len(ce) for ce in self.completed_episodes_by_env_idx)
57
+
58
+ @property
59
+ def episodes(self) -> List[Episode]:
60
+ return list(itertools.chain(*self.completed_episodes_by_env_idx))
61
+
62
+ def is_done(self) -> bool:
63
+ return all(
64
+ len(ce) == self.goal_episodes_per_env
65
+ for ce in self.completed_episodes_by_env_idx
66
+ )
67
+
68
+
69
+ def evaluate(
70
+ env: VecEnv,
71
+ policy: Policy,
72
+ n_episodes: int,
73
+ render: bool = False,
74
+ deterministic: bool = True,
75
+ print_returns: bool = True,
76
+ ignore_first_episode: bool = False,
77
+ ) -> EpisodesStats:
78
+ policy.sync_normalization(env)
79
+ policy.eval()
80
+
81
+ episodes = EvaluateAccumulator(
82
+ env.num_envs, n_episodes, print_returns, ignore_first_episode
83
+ )
84
+
85
+ obs = env.reset()
86
+ while not episodes.is_done():
87
+ act = policy.act(obs, deterministic=deterministic)
88
+ obs, rew, done, _ = env.step(act)
89
+ episodes.step(rew, done)
90
+ if render:
91
+ env.render()
92
+ stats = EpisodesStats(episodes.episodes)
93
+ if print_returns:
94
+ print(stats)
95
+ return stats
96
+
97
+
98
+ class EvalCallback(Callback):
99
+ def __init__(
100
+ self,
101
+ policy: Policy,
102
+ env: VecEnv,
103
+ tb_writer: SummaryWriter,
104
+ best_model_path: Optional[str] = None,
105
+ step_freq: Union[int, float] = 50_000,
106
+ n_episodes: int = 10,
107
+ save_best: bool = True,
108
+ deterministic: bool = True,
109
+ record_best_videos: bool = True,
110
+ video_env: Optional[VecEnv] = None,
111
+ best_video_dir: Optional[str] = None,
112
+ max_video_length: int = 3600,
113
+ ignore_first_episode: bool = False,
114
+ ) -> None:
115
+ super().__init__()
116
+ self.policy = policy
117
+ self.env = env
118
+ self.tb_writer = tb_writer
119
+ self.best_model_path = best_model_path
120
+ self.step_freq = int(step_freq)
121
+ self.n_episodes = n_episodes
122
+ self.save_best = save_best
123
+ self.deterministic = deterministic
124
+ self.stats: List[EpisodesStats] = []
125
+ self.best = None
126
+
127
+ self.record_best_videos = record_best_videos
128
+ assert video_env or not record_best_videos
129
+ self.video_env = video_env
130
+ assert best_video_dir or not record_best_videos
131
+ self.best_video_dir = best_video_dir
132
+ if best_video_dir:
133
+ os.makedirs(best_video_dir, exist_ok=True)
134
+ self.max_video_length = max_video_length
135
+ self.best_video_base_path = None
136
+
137
+ self.ignore_first_episode = ignore_first_episode
138
+
139
+ def on_step(self, timesteps_elapsed: int = 1) -> bool:
140
+ super().on_step(timesteps_elapsed)
141
+ if self.timesteps_elapsed // self.step_freq >= len(self.stats):
142
+ self.evaluate()
143
+ return True
144
+
145
+ def evaluate(
146
+ self, n_episodes: Optional[int] = None, print_returns: Optional[bool] = None
147
+ ) -> EpisodesStats:
148
+ start_time = perf_counter()
149
+ eval_stat = evaluate(
150
+ self.env,
151
+ self.policy,
152
+ n_episodes or self.n_episodes,
153
+ deterministic=self.deterministic,
154
+ print_returns=print_returns or False,
155
+ ignore_first_episode=self.ignore_first_episode,
156
+ )
157
+ end_time = perf_counter()
158
+ self.tb_writer.add_scalar(
159
+ "eval/steps_per_second",
160
+ eval_stat.length.sum() / (end_time - start_time),
161
+ self.timesteps_elapsed,
162
+ )
163
+ self.policy.train(True)
164
+ print(f"Eval Timesteps: {self.timesteps_elapsed} | {eval_stat}")
165
+
166
+ self.stats.append(eval_stat)
167
+
168
+ if not self.best or eval_stat >= self.best:
169
+ strictly_better = not self.best or eval_stat > self.best
170
+ self.best = eval_stat
171
+ if self.save_best:
172
+ assert self.best_model_path
173
+ self.policy.save(self.best_model_path)
174
+ print("Saved best model")
175
+ self.best.write_to_tensorboard(
176
+ self.tb_writer, "best_eval", self.timesteps_elapsed
177
+ )
178
+ if strictly_better and self.record_best_videos:
179
+ assert self.video_env and self.best_video_dir
180
+ self.best_video_base_path = os.path.join(
181
+ self.best_video_dir, str(self.timesteps_elapsed)
182
+ )
183
+ video_wrapped = VecEpisodeRecorder(
184
+ self.video_env,
185
+ self.best_video_base_path,
186
+ max_video_length=self.max_video_length,
187
+ )
188
+ video_stats = evaluate(
189
+ video_wrapped,
190
+ self.policy,
191
+ 1,
192
+ deterministic=self.deterministic,
193
+ print_returns=False,
194
+ )
195
+ print(f"Saved best video: {video_stats}")
196
+
197
+ eval_stat.write_to_tensorboard(self.tb_writer, "eval", self.timesteps_elapsed)
198
+
199
+ return eval_stat
rl_algo_impls/shared/callbacks/optimize_callback.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import optuna
3
+
4
+ from time import perf_counter
5
+ from torch.utils.tensorboard.writer import SummaryWriter
6
+ from typing import NamedTuple, Union
7
+
8
+ from rl_algo_impls.shared.callbacks.callback import Callback
9
+ from rl_algo_impls.shared.callbacks.eval_callback import evaluate
10
+ from rl_algo_impls.shared.policy.policy import Policy
11
+ from rl_algo_impls.shared.stats import EpisodesStats
12
+ from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
13
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, find_wrapper
14
+
15
+
16
+ class Evaluation(NamedTuple):
17
+ eval_stat: EpisodesStats
18
+ train_stat: EpisodesStats
19
+ score: float
20
+
21
+
22
+ class OptimizeCallback(Callback):
23
+ def __init__(
24
+ self,
25
+ policy: Policy,
26
+ env: VecEnv,
27
+ trial: optuna.Trial,
28
+ tb_writer: SummaryWriter,
29
+ step_freq: Union[int, float] = 50_000,
30
+ n_episodes: int = 10,
31
+ deterministic: bool = True,
32
+ ) -> None:
33
+ super().__init__()
34
+ self.policy = policy
35
+ self.env = env
36
+ self.trial = trial
37
+ self.tb_writer = tb_writer
38
+ self.step_freq = step_freq
39
+ self.n_episodes = n_episodes
40
+ self.deterministic = deterministic
41
+
42
+ stats_writer = find_wrapper(policy.env, EpisodeStatsWriter)
43
+ assert stats_writer
44
+ self.stats_writer = stats_writer
45
+
46
+ self.eval_step = 1
47
+ self.is_pruned = False
48
+ self.last_eval_stat = None
49
+ self.last_train_stat = None
50
+ self.last_score = -np.inf
51
+
52
+ def on_step(self, timesteps_elapsed: int = 1) -> bool:
53
+ super().on_step(timesteps_elapsed)
54
+ if self.timesteps_elapsed >= self.eval_step * self.step_freq:
55
+ self.evaluate()
56
+ return not self.is_pruned
57
+ return True
58
+
59
+ def evaluate(self) -> None:
60
+ self.last_eval_stat, self.last_train_stat, score = evaluation(
61
+ self.policy,
62
+ self.env,
63
+ self.tb_writer,
64
+ self.n_episodes,
65
+ self.deterministic,
66
+ self.timesteps_elapsed,
67
+ )
68
+ self.last_score = score
69
+
70
+ self.trial.report(score, self.eval_step)
71
+ if self.trial.should_prune():
72
+ self.is_pruned = True
73
+
74
+ self.eval_step += 1
75
+
76
+
77
+ def evaluation(
78
+ policy: Policy,
79
+ env: VecEnv,
80
+ tb_writer: SummaryWriter,
81
+ n_episodes: int,
82
+ deterministic: bool,
83
+ timesteps_elapsed: int,
84
+ ) -> Evaluation:
85
+ start_time = perf_counter()
86
+ eval_stat = evaluate(
87
+ env,
88
+ policy,
89
+ n_episodes,
90
+ deterministic=deterministic,
91
+ print_returns=False,
92
+ )
93
+ end_time = perf_counter()
94
+ tb_writer.add_scalar(
95
+ "eval/steps_per_second",
96
+ eval_stat.length.sum() / (end_time - start_time),
97
+ timesteps_elapsed,
98
+ )
99
+ policy.train()
100
+ print(f"Eval Timesteps: {timesteps_elapsed} | {eval_stat}")
101
+ eval_stat.write_to_tensorboard(tb_writer, "eval", timesteps_elapsed)
102
+
103
+ stats_writer = find_wrapper(policy.env, EpisodeStatsWriter)
104
+ assert stats_writer
105
+
106
+ train_stat = EpisodesStats(stats_writer.episodes)
107
+ print(f" Train Stat: {train_stat}")
108
+
109
+ score = (eval_stat.score.mean + train_stat.score.mean) / 2
110
+ print(f" Score: {round(score, 2)}")
111
+ tb_writer.add_scalar(
112
+ "eval/score",
113
+ score,
114
+ timesteps_elapsed,
115
+ )
116
+
117
+ return Evaluation(eval_stat, train_stat, score)
rl_algo_impls/shared/gae.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ from typing import NamedTuple, Sequence
5
+
6
+ from rl_algo_impls.shared.policy.on_policy import OnPolicy
7
+ from rl_algo_impls.shared.trajectory import Trajectory
8
+
9
+
10
+ class RtgAdvantage(NamedTuple):
11
+ rewards_to_go: torch.Tensor
12
+ advantage: torch.Tensor
13
+
14
+
15
+ def discounted_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
16
+ dc = x.copy()
17
+ for i in reversed(range(len(x) - 1)):
18
+ dc[i] += gamma * dc[i + 1]
19
+ return dc
20
+
21
+
22
+ def compute_advantage(
23
+ trajectories: Sequence[Trajectory],
24
+ policy: OnPolicy,
25
+ gamma: float,
26
+ gae_lambda: float,
27
+ device: torch.device,
28
+ ) -> torch.Tensor:
29
+ advantage = []
30
+ for traj in trajectories:
31
+ last_val = 0
32
+ if not traj.terminated and traj.next_obs is not None:
33
+ last_val = policy.value(traj.next_obs)
34
+ rew = np.append(np.array(traj.rew), last_val)
35
+ v = np.append(np.array(traj.v), last_val)
36
+ deltas = rew[:-1] + gamma * v[1:] - v[:-1]
37
+ advantage.append(discounted_cumsum(deltas, gamma * gae_lambda))
38
+ return torch.as_tensor(
39
+ np.concatenate(advantage), dtype=torch.float32, device=device
40
+ )
41
+
42
+
43
+ def compute_rtg_and_advantage(
44
+ trajectories: Sequence[Trajectory],
45
+ policy: OnPolicy,
46
+ gamma: float,
47
+ gae_lambda: float,
48
+ device: torch.device,
49
+ ) -> RtgAdvantage:
50
+ rewards_to_go = []
51
+ advantages = []
52
+ for traj in trajectories:
53
+ last_val = 0
54
+ if not traj.terminated and traj.next_obs is not None:
55
+ last_val = policy.value(traj.next_obs)
56
+ rew = np.append(np.array(traj.rew), last_val)
57
+ v = np.append(np.array(traj.v), last_val)
58
+ deltas = rew[:-1] + gamma * v[1:] - v[:-1]
59
+ adv = discounted_cumsum(deltas, gamma * gae_lambda)
60
+ advantages.append(adv)
61
+ rewards_to_go.append(v[:-1] + adv)
62
+ return RtgAdvantage(
63
+ torch.as_tensor(
64
+ np.concatenate(rewards_to_go), dtype=torch.float32, device=device
65
+ ),
66
+ torch.as_tensor(np.concatenate(advantages), dtype=torch.float32, device=device),
67
+ )
rl_algo_impls/shared/module/feature_extractor.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from abc import ABC, abstractmethod
7
+ from gym.spaces import Box, Discrete
8
+ from stable_baselines3.common.preprocessing import get_flattened_obs_dim
9
+ from typing import Dict, Optional, Sequence, Type
10
+
11
+ from rl_algo_impls.shared.module.module import layer_init
12
+
13
+
14
+ class CnnFeatureExtractor(nn.Module, ABC):
15
+ @abstractmethod
16
+ def __init__(
17
+ self,
18
+ in_channels: int,
19
+ activation: Type[nn.Module] = nn.ReLU,
20
+ init_layers_orthogonal: Optional[bool] = None,
21
+ **kwargs,
22
+ ) -> None:
23
+ super().__init__()
24
+
25
+
26
+ class NatureCnn(CnnFeatureExtractor):
27
+ """
28
+ CNN from DQN Nature paper: Mnih, Volodymyr, et al.
29
+ "Human-level control through deep reinforcement learning."
30
+ Nature 518.7540 (2015): 529-533.
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ in_channels: int,
36
+ activation: Type[nn.Module] = nn.ReLU,
37
+ init_layers_orthogonal: Optional[bool] = None,
38
+ **kwargs,
39
+ ) -> None:
40
+ if init_layers_orthogonal is None:
41
+ init_layers_orthogonal = True
42
+ super().__init__(in_channels, activation, init_layers_orthogonal)
43
+ self.cnn = nn.Sequential(
44
+ layer_init(
45
+ nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
46
+ init_layers_orthogonal,
47
+ ),
48
+ activation(),
49
+ layer_init(
50
+ nn.Conv2d(32, 64, kernel_size=4, stride=2),
51
+ init_layers_orthogonal,
52
+ ),
53
+ activation(),
54
+ layer_init(
55
+ nn.Conv2d(64, 64, kernel_size=3, stride=1),
56
+ init_layers_orthogonal,
57
+ ),
58
+ activation(),
59
+ nn.Flatten(),
60
+ )
61
+
62
+ def forward(self, obs: torch.Tensor) -> torch.Tensor:
63
+ return self.cnn(obs)
64
+
65
+
66
+ class ResidualBlock(nn.Module):
67
+ def __init__(
68
+ self,
69
+ channels: int,
70
+ activation: Type[nn.Module] = nn.ReLU,
71
+ init_layers_orthogonal: bool = False,
72
+ ) -> None:
73
+ super().__init__()
74
+ self.residual = nn.Sequential(
75
+ activation(),
76
+ layer_init(
77
+ nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
78
+ ),
79
+ activation(),
80
+ layer_init(
81
+ nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
82
+ ),
83
+ )
84
+
85
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
86
+ return x + self.residual(x)
87
+
88
+
89
+ class ConvSequence(nn.Module):
90
+ def __init__(
91
+ self,
92
+ in_channels: int,
93
+ out_channels: int,
94
+ activation: Type[nn.Module] = nn.ReLU,
95
+ init_layers_orthogonal: bool = False,
96
+ ) -> None:
97
+ super().__init__()
98
+ self.seq = nn.Sequential(
99
+ layer_init(
100
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
101
+ init_layers_orthogonal,
102
+ ),
103
+ nn.MaxPool2d(3, stride=2, padding=1),
104
+ ResidualBlock(out_channels, activation, init_layers_orthogonal),
105
+ ResidualBlock(out_channels, activation, init_layers_orthogonal),
106
+ )
107
+
108
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
109
+ return self.seq(x)
110
+
111
+
112
+ class ImpalaCnn(CnnFeatureExtractor):
113
+ """
114
+ IMPALA-style CNN architecture
115
+ """
116
+
117
+ def __init__(
118
+ self,
119
+ in_channels: int,
120
+ activation: Type[nn.Module] = nn.ReLU,
121
+ init_layers_orthogonal: Optional[bool] = None,
122
+ impala_channels: Sequence[int] = (16, 32, 32),
123
+ **kwargs,
124
+ ) -> None:
125
+ if init_layers_orthogonal is None:
126
+ init_layers_orthogonal = False
127
+ super().__init__(in_channels, activation, init_layers_orthogonal)
128
+ sequences = []
129
+ for out_channels in impala_channels:
130
+ sequences.append(
131
+ ConvSequence(
132
+ in_channels, out_channels, activation, init_layers_orthogonal
133
+ )
134
+ )
135
+ in_channels = out_channels
136
+ sequences.extend(
137
+ [
138
+ activation(),
139
+ nn.Flatten(),
140
+ ]
141
+ )
142
+ self.seq = nn.Sequential(*sequences)
143
+
144
+ def forward(self, obs: torch.Tensor) -> torch.Tensor:
145
+ return self.seq(obs)
146
+
147
+
148
+ CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnFeatureExtractor]] = {
149
+ "nature": NatureCnn,
150
+ "impala": ImpalaCnn,
151
+ }
152
+
153
+
154
+ class FeatureExtractor(nn.Module):
155
+ def __init__(
156
+ self,
157
+ obs_space: gym.Space,
158
+ activation: Type[nn.Module],
159
+ init_layers_orthogonal: bool = False,
160
+ cnn_feature_dim: int = 512,
161
+ cnn_style: str = "nature",
162
+ cnn_layers_init_orthogonal: Optional[bool] = None,
163
+ impala_channels: Sequence[int] = (16, 32, 32),
164
+ ) -> None:
165
+ super().__init__()
166
+ if isinstance(obs_space, Box):
167
+ # Conv2D: (channels, height, width)
168
+ if len(obs_space.shape) == 3:
169
+ cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
170
+ obs_space.shape[0],
171
+ activation,
172
+ init_layers_orthogonal=cnn_layers_init_orthogonal,
173
+ impala_channels=impala_channels,
174
+ )
175
+
176
+ def preprocess(obs: torch.Tensor) -> torch.Tensor:
177
+ if len(obs.shape) == 3:
178
+ obs = obs.unsqueeze(0)
179
+ return obs.float() / 255.0
180
+
181
+ with torch.no_grad():
182
+ cnn_out = cnn(preprocess(torch.as_tensor(obs_space.sample())))
183
+ self.preprocess = preprocess
184
+ self.feature_extractor = nn.Sequential(
185
+ cnn,
186
+ layer_init(
187
+ nn.Linear(cnn_out.shape[1], cnn_feature_dim),
188
+ init_layers_orthogonal,
189
+ ),
190
+ activation(),
191
+ )
192
+ self.out_dim = cnn_feature_dim
193
+ elif len(obs_space.shape) == 1:
194
+
195
+ def preprocess(obs: torch.Tensor) -> torch.Tensor:
196
+ if len(obs.shape) == 1:
197
+ obs = obs.unsqueeze(0)
198
+ return obs.float()
199
+
200
+ self.preprocess = preprocess
201
+ self.feature_extractor = nn.Flatten()
202
+ self.out_dim = get_flattened_obs_dim(obs_space)
203
+ else:
204
+ raise ValueError(f"Unsupported observation space: {obs_space}")
205
+ elif isinstance(obs_space, Discrete):
206
+ self.preprocess = lambda x: F.one_hot(x, obs_space.n).float()
207
+ self.feature_extractor = nn.Flatten()
208
+ self.out_dim = obs_space.n
209
+ else:
210
+ raise NotImplementedError
211
+
212
+ def forward(self, obs: torch.Tensor) -> torch.Tensor:
213
+ if self.preprocess:
214
+ obs = self.preprocess(obs)
215
+ return self.feature_extractor(obs)
rl_algo_impls/shared/module/module.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch.nn as nn
3
+
4
+ from typing import Sequence, Type
5
+
6
+
7
+ def mlp(
8
+ layer_sizes: Sequence[int],
9
+ activation: Type[nn.Module],
10
+ output_activation: Type[nn.Module] = nn.Identity,
11
+ init_layers_orthogonal: bool = False,
12
+ final_layer_gain: float = np.sqrt(2),
13
+ ) -> nn.Module:
14
+ layers = []
15
+ for i in range(len(layer_sizes) - 2):
16
+ layers.append(
17
+ layer_init(
18
+ nn.Linear(layer_sizes[i], layer_sizes[i + 1]), init_layers_orthogonal
19
+ )
20
+ )
21
+ layers.append(activation())
22
+ layers.append(
23
+ layer_init(
24
+ nn.Linear(layer_sizes[-2], layer_sizes[-1]),
25
+ init_layers_orthogonal,
26
+ std=final_layer_gain,
27
+ )
28
+ )
29
+ layers.append(output_activation())
30
+ return nn.Sequential(*layers)
31
+
32
+
33
+ def layer_init(
34
+ layer: nn.Module, init_layers_orthogonal: bool, std: float = np.sqrt(2)
35
+ ) -> nn.Module:
36
+ if not init_layers_orthogonal:
37
+ return layer
38
+ nn.init.orthogonal_(layer.weight, std) # type: ignore
39
+ nn.init.constant_(layer.bias, 0.0) # type: ignore
40
+ return layer
rl_algo_impls/shared/policy/actor.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from abc import ABC, abstractmethod
6
+ from gym.spaces import Box, Discrete
7
+ from torch.distributions import Categorical, Distribution, Normal
8
+ from typing import NamedTuple, Optional, Sequence, Type, TypeVar, Union
9
+
10
+ from rl_algo_impls.shared.module.module import mlp
11
+
12
+
13
+ class PiForward(NamedTuple):
14
+ pi: Distribution
15
+ logp_a: Optional[torch.Tensor]
16
+ entropy: Optional[torch.Tensor]
17
+
18
+
19
+ class Actor(nn.Module, ABC):
20
+ @abstractmethod
21
+ def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
22
+ ...
23
+
24
+
25
+ class CategoricalActorHead(Actor):
26
+ def __init__(
27
+ self,
28
+ act_dim: int,
29
+ hidden_sizes: Sequence[int] = (32,),
30
+ activation: Type[nn.Module] = nn.Tanh,
31
+ init_layers_orthogonal: bool = True,
32
+ ) -> None:
33
+ super().__init__()
34
+ layer_sizes = tuple(hidden_sizes) + (act_dim,)
35
+ self._fc = mlp(
36
+ layer_sizes,
37
+ activation,
38
+ init_layers_orthogonal=init_layers_orthogonal,
39
+ final_layer_gain=0.01,
40
+ )
41
+
42
+ def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
43
+ logits = self._fc(obs)
44
+ pi = Categorical(logits=logits)
45
+ logp_a = None
46
+ entropy = None
47
+ if a is not None:
48
+ logp_a = pi.log_prob(a)
49
+ entropy = pi.entropy()
50
+ return PiForward(pi, logp_a, entropy)
51
+
52
+
53
+ class GaussianDistribution(Normal):
54
+ def log_prob(self, a: torch.Tensor) -> torch.Tensor:
55
+ return super().log_prob(a).sum(axis=-1)
56
+
57
+ def sample(self) -> torch.Tensor:
58
+ return self.rsample()
59
+
60
+
61
+ class GaussianActorHead(Actor):
62
+ def __init__(
63
+ self,
64
+ act_dim: int,
65
+ hidden_sizes: Sequence[int] = (32,),
66
+ activation: Type[nn.Module] = nn.Tanh,
67
+ init_layers_orthogonal: bool = True,
68
+ log_std_init: float = -0.5,
69
+ ) -> None:
70
+ super().__init__()
71
+ layer_sizes = tuple(hidden_sizes) + (act_dim,)
72
+ self.mu_net = mlp(
73
+ layer_sizes,
74
+ activation,
75
+ init_layers_orthogonal=init_layers_orthogonal,
76
+ final_layer_gain=0.01,
77
+ )
78
+ self.log_std = nn.Parameter(
79
+ torch.ones(act_dim, dtype=torch.float32) * log_std_init
80
+ )
81
+
82
+ def _distribution(self, obs: torch.Tensor) -> Distribution:
83
+ mu = self.mu_net(obs)
84
+ std = torch.exp(self.log_std)
85
+ return GaussianDistribution(mu, std)
86
+
87
+ def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
88
+ pi = self._distribution(obs)
89
+ logp_a = None
90
+ entropy = None
91
+ if a is not None:
92
+ logp_a = pi.log_prob(a)
93
+ entropy = pi.entropy()
94
+ return PiForward(pi, logp_a, entropy)
95
+
96
+
97
+ class TanhBijector:
98
+ def __init__(self, epsilon: float = 1e-6) -> None:
99
+ self.epsilon = epsilon
100
+
101
+ @staticmethod
102
+ def forward(x: torch.Tensor) -> torch.Tensor:
103
+ return torch.tanh(x)
104
+
105
+ @staticmethod
106
+ def inverse(y: torch.Tensor) -> torch.Tensor:
107
+ eps = torch.finfo(y.dtype).eps
108
+ clamped_y = y.clamp(min=-1.0 + eps, max=1.0 - eps)
109
+ return torch.atanh(clamped_y)
110
+
111
+ def log_prob_correction(self, x: torch.Tensor) -> torch.Tensor:
112
+ return torch.log(1.0 - torch.tanh(x) ** 2 + self.epsilon)
113
+
114
+
115
+ def sum_independent_dims(tensor: torch.Tensor) -> torch.Tensor:
116
+ if len(tensor.shape) > 1:
117
+ return tensor.sum(dim=1)
118
+ return tensor.sum()
119
+
120
+
121
+ class StateDependentNoiseDistribution(Normal):
122
+ def __init__(
123
+ self,
124
+ loc,
125
+ scale,
126
+ latent_sde: torch.Tensor,
127
+ exploration_mat: torch.Tensor,
128
+ exploration_matrices: torch.Tensor,
129
+ bijector: Optional[TanhBijector] = None,
130
+ validate_args=None,
131
+ ):
132
+ super().__init__(loc, scale, validate_args)
133
+ self.latent_sde = latent_sde
134
+ self.exploration_mat = exploration_mat
135
+ self.exploration_matrices = exploration_matrices
136
+ self.bijector = bijector
137
+
138
+ def log_prob(self, a: torch.Tensor) -> torch.Tensor:
139
+ gaussian_a = self.bijector.inverse(a) if self.bijector else a
140
+ log_prob = sum_independent_dims(super().log_prob(gaussian_a))
141
+ if self.bijector:
142
+ log_prob -= torch.sum(self.bijector.log_prob_correction(gaussian_a), dim=1)
143
+ return log_prob
144
+
145
+ def sample(self) -> torch.Tensor:
146
+ noise = self._get_noise()
147
+ actions = self.mean + noise
148
+ return self.bijector.forward(actions) if self.bijector else actions
149
+
150
+ def _get_noise(self) -> torch.Tensor:
151
+ if len(self.latent_sde) == 1 or len(self.latent_sde) != len(
152
+ self.exploration_matrices
153
+ ):
154
+ return torch.mm(self.latent_sde, self.exploration_mat)
155
+ # (batch_size, n_features) -> (batch_size, 1, n_features)
156
+ latent_sde = self.latent_sde.unsqueeze(dim=1)
157
+ # (batch_size, 1, n_actions)
158
+ noise = torch.bmm(latent_sde, self.exploration_matrices)
159
+ return noise.squeeze(dim=1)
160
+
161
+ @property
162
+ def mode(self) -> torch.Tensor:
163
+ mean = super().mode
164
+ return self.bijector.forward(mean) if self.bijector else mean
165
+
166
+
167
+ StateDependentNoiseActorHeadSelf = TypeVar(
168
+ "StateDependentNoiseActorHeadSelf", bound="StateDependentNoiseActorHead"
169
+ )
170
+
171
+
172
+ class StateDependentNoiseActorHead(Actor):
173
+ def __init__(
174
+ self,
175
+ act_dim: int,
176
+ hidden_sizes: Sequence[int] = (32,),
177
+ activation: Type[nn.Module] = nn.Tanh,
178
+ init_layers_orthogonal: bool = True,
179
+ log_std_init: float = -0.5,
180
+ full_std: bool = True,
181
+ squash_output: bool = False,
182
+ learn_std: bool = False,
183
+ ) -> None:
184
+ super().__init__()
185
+ self.act_dim = act_dim
186
+ layer_sizes = tuple(hidden_sizes) + (self.act_dim,)
187
+ if len(layer_sizes) == 2:
188
+ self.latent_net = nn.Identity()
189
+ elif len(layer_sizes) > 2:
190
+ self.latent_net = mlp(
191
+ layer_sizes[:-1],
192
+ activation,
193
+ output_activation=activation,
194
+ init_layers_orthogonal=init_layers_orthogonal,
195
+ )
196
+ else:
197
+ raise ValueError("hidden_sizes must be of at least length 1")
198
+ self.mu_net = mlp(
199
+ layer_sizes[-2:],
200
+ activation,
201
+ init_layers_orthogonal=init_layers_orthogonal,
202
+ final_layer_gain=0.01,
203
+ )
204
+ self.full_std = full_std
205
+ std_dim = (hidden_sizes[-1], act_dim if self.full_std else 1)
206
+ self.log_std = nn.Parameter(
207
+ torch.ones(std_dim, dtype=torch.float32) * log_std_init
208
+ )
209
+ self.bijector = TanhBijector() if squash_output else None
210
+ self.learn_std = learn_std
211
+ self.device = None
212
+
213
+ self.exploration_mat = None
214
+ self.exploration_matrices = None
215
+ self.sample_weights()
216
+
217
+ def to(
218
+ self: StateDependentNoiseActorHeadSelf,
219
+ device: Optional[torch.device] = None,
220
+ dtype: Optional[Union[torch.dtype, str]] = None,
221
+ non_blocking: bool = False,
222
+ ) -> StateDependentNoiseActorHeadSelf:
223
+ super().to(device, dtype, non_blocking)
224
+ self.device = device
225
+ return self
226
+
227
+ def _distribution(self, obs: torch.Tensor) -> Distribution:
228
+ latent = self.latent_net(obs)
229
+ mu = self.mu_net(latent)
230
+ latent_sde = latent if self.learn_std else latent.detach()
231
+ variance = torch.mm(latent_sde**2, self._get_std() ** 2)
232
+ assert self.exploration_mat is not None
233
+ assert self.exploration_matrices is not None
234
+ return StateDependentNoiseDistribution(
235
+ mu,
236
+ torch.sqrt(variance + 1e-6),
237
+ latent_sde,
238
+ self.exploration_mat,
239
+ self.exploration_matrices,
240
+ self.bijector,
241
+ )
242
+
243
+ def _get_std(self) -> torch.Tensor:
244
+ std = torch.exp(self.log_std)
245
+ if self.full_std:
246
+ return std
247
+ ones = torch.ones(self.log_std.shape[0], self.act_dim)
248
+ if self.device:
249
+ ones = ones.to(self.device)
250
+ return ones * std
251
+
252
+ def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
253
+ pi = self._distribution(obs)
254
+ logp_a = None
255
+ entropy = None
256
+ if a is not None:
257
+ logp_a = pi.log_prob(a)
258
+ entropy = -logp_a if self.bijector else sum_independent_dims(pi.entropy())
259
+ return PiForward(pi, logp_a, entropy)
260
+
261
+ def sample_weights(self, batch_size: int = 1) -> None:
262
+ std = self._get_std()
263
+ weights_dist = Normal(torch.zeros_like(std), std)
264
+ # Reparametrization trick to pass gradients
265
+ self.exploration_mat = weights_dist.rsample()
266
+ self.exploration_matrices = weights_dist.rsample(torch.Size((batch_size,)))
267
+
268
+
269
+ def actor_head(
270
+ action_space: gym.Space,
271
+ hidden_sizes: Sequence[int],
272
+ init_layers_orthogonal: bool,
273
+ activation: Type[nn.Module],
274
+ log_std_init: float = -0.5,
275
+ use_sde: bool = False,
276
+ full_std: bool = True,
277
+ squash_output: bool = False,
278
+ ) -> Actor:
279
+ assert not use_sde or isinstance(
280
+ action_space, Box
281
+ ), "use_sde only valid if Box action_space"
282
+ assert not squash_output or use_sde, "squash_output only valid if use_sde"
283
+ if isinstance(action_space, Discrete):
284
+ return CategoricalActorHead(
285
+ action_space.n,
286
+ hidden_sizes=hidden_sizes,
287
+ activation=activation,
288
+ init_layers_orthogonal=init_layers_orthogonal,
289
+ )
290
+ elif isinstance(action_space, Box):
291
+ if use_sde:
292
+ return StateDependentNoiseActorHead(
293
+ action_space.shape[0],
294
+ hidden_sizes=hidden_sizes,
295
+ activation=activation,
296
+ init_layers_orthogonal=init_layers_orthogonal,
297
+ log_std_init=log_std_init,
298
+ full_std=full_std,
299
+ squash_output=squash_output,
300
+ )
301
+ else:
302
+ return GaussianActorHead(
303
+ action_space.shape[0],
304
+ hidden_sizes=hidden_sizes,
305
+ activation=activation,
306
+ init_layers_orthogonal=init_layers_orthogonal,
307
+ log_std_init=log_std_init,
308
+ )
309
+ else:
310
+ raise ValueError(f"Unsupported action space: {action_space}")
rl_algo_impls/shared/policy/critic.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from typing import Sequence, Type
6
+
7
+ from rl_algo_impls.shared.module.module import mlp
8
+
9
+
10
+ class CriticHead(nn.Module):
11
+ def __init__(
12
+ self,
13
+ hidden_sizes: Sequence[int] = (32,),
14
+ activation: Type[nn.Module] = nn.Tanh,
15
+ init_layers_orthogonal: bool = True,
16
+ ) -> None:
17
+ super().__init__()
18
+ layer_sizes = tuple(hidden_sizes) + (1,)
19
+ self._fc = mlp(
20
+ layer_sizes,
21
+ activation,
22
+ init_layers_orthogonal=init_layers_orthogonal,
23
+ final_layer_gain=1.0,
24
+ )
25
+
26
+ def forward(self, obs: torch.Tensor) -> torch.Tensor:
27
+ v = self._fc(obs)
28
+ return v.squeeze(-1)
rl_algo_impls/shared/policy/on_policy.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import numpy as np
3
+ import torch
4
+
5
+ from abc import abstractmethod
6
+ from gym.spaces import Box, Discrete, Space
7
+ from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
8
+
9
+ from rl_algo_impls.shared.module.feature_extractor import FeatureExtractor
10
+ from rl_algo_impls.shared.policy.actor import (
11
+ PiForward,
12
+ StateDependentNoiseActorHead,
13
+ actor_head,
14
+ )
15
+ from rl_algo_impls.shared.policy.critic import CriticHead
16
+ from rl_algo_impls.shared.policy.policy import ACTIVATION, Policy
17
+ from rl_algo_impls.wrappers.vectorable_wrapper import (
18
+ VecEnv,
19
+ VecEnvObs,
20
+ single_observation_space,
21
+ single_action_space,
22
+ )
23
+
24
+
25
+ class Step(NamedTuple):
26
+ a: np.ndarray
27
+ v: np.ndarray
28
+ logp_a: np.ndarray
29
+ clamped_a: np.ndarray
30
+
31
+
32
+ class ACForward(NamedTuple):
33
+ logp_a: torch.Tensor
34
+ entropy: torch.Tensor
35
+ v: torch.Tensor
36
+
37
+
38
+ FEAT_EXT_FILE_NAME = "feat_ext.pt"
39
+ V_FEAT_EXT_FILE_NAME = "v_feat_ext.pt"
40
+ PI_FILE_NAME = "pi.pt"
41
+ V_FILE_NAME = "v.pt"
42
+ ActorCriticSelf = TypeVar("ActorCriticSelf", bound="ActorCritic")
43
+
44
+
45
+ def clamp_actions(
46
+ actions: np.ndarray, action_space: gym.Space, squash_output: bool
47
+ ) -> np.ndarray:
48
+ if isinstance(action_space, Box):
49
+ low, high = action_space.low, action_space.high # type: ignore
50
+ if squash_output:
51
+ # Squashed output is already between -1 and 1. Rescale if the actual
52
+ # output needs to something other than -1 and 1
53
+ return low + 0.5 * (actions + 1) * (high - low)
54
+ else:
55
+ return np.clip(actions, low, high)
56
+ return actions
57
+
58
+
59
+ def default_hidden_sizes(obs_space: Space) -> Sequence[int]:
60
+ if isinstance(obs_space, Box):
61
+ if len(obs_space.shape) == 3:
62
+ # By default feature extractor to output has no hidden layers
63
+ return []
64
+ elif len(obs_space.shape) == 1:
65
+ return [64, 64]
66
+ else:
67
+ raise ValueError(f"Unsupported observation space: {obs_space}")
68
+ elif isinstance(obs_space, Discrete):
69
+ return [64]
70
+ else:
71
+ raise ValueError(f"Unsupported observation space: {obs_space}")
72
+
73
+
74
+ class OnPolicy(Policy):
75
+ @abstractmethod
76
+ def value(self, obs: VecEnvObs) -> np.ndarray:
77
+ ...
78
+
79
+ @abstractmethod
80
+ def step(self, obs: VecEnvObs) -> Step:
81
+ ...
82
+
83
+
84
+ class ActorCritic(OnPolicy):
85
+ def __init__(
86
+ self,
87
+ env: VecEnv,
88
+ pi_hidden_sizes: Optional[Sequence[int]] = None,
89
+ v_hidden_sizes: Optional[Sequence[int]] = None,
90
+ init_layers_orthogonal: bool = True,
91
+ activation_fn: str = "tanh",
92
+ log_std_init: float = -0.5,
93
+ use_sde: bool = False,
94
+ full_std: bool = True,
95
+ squash_output: bool = False,
96
+ share_features_extractor: bool = True,
97
+ cnn_feature_dim: int = 512,
98
+ cnn_style: str = "nature",
99
+ cnn_layers_init_orthogonal: Optional[bool] = None,
100
+ impala_channels: Sequence[int] = (16, 32, 32),
101
+ **kwargs,
102
+ ) -> None:
103
+ super().__init__(env, **kwargs)
104
+
105
+ observation_space = single_observation_space(env)
106
+ action_space = single_action_space(env)
107
+
108
+ pi_hidden_sizes = (
109
+ pi_hidden_sizes
110
+ if pi_hidden_sizes is not None
111
+ else default_hidden_sizes(observation_space)
112
+ )
113
+ v_hidden_sizes = (
114
+ v_hidden_sizes
115
+ if v_hidden_sizes is not None
116
+ else default_hidden_sizes(observation_space)
117
+ )
118
+
119
+ activation = ACTIVATION[activation_fn]
120
+ self.action_space = action_space
121
+ self.squash_output = squash_output
122
+ self.share_features_extractor = share_features_extractor
123
+ self._feature_extractor = FeatureExtractor(
124
+ observation_space,
125
+ activation,
126
+ init_layers_orthogonal=init_layers_orthogonal,
127
+ cnn_feature_dim=cnn_feature_dim,
128
+ cnn_style=cnn_style,
129
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
130
+ impala_channels=impala_channels,
131
+ )
132
+ self._pi = actor_head(
133
+ self.action_space,
134
+ (self._feature_extractor.out_dim,) + tuple(pi_hidden_sizes),
135
+ init_layers_orthogonal,
136
+ activation,
137
+ log_std_init=log_std_init,
138
+ use_sde=use_sde,
139
+ full_std=full_std,
140
+ squash_output=squash_output,
141
+ )
142
+
143
+ if not share_features_extractor:
144
+ self._v_feature_extractor = FeatureExtractor(
145
+ observation_space,
146
+ activation,
147
+ init_layers_orthogonal=init_layers_orthogonal,
148
+ cnn_feature_dim=cnn_feature_dim,
149
+ cnn_style=cnn_style,
150
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
151
+ )
152
+ v_hidden_sizes = (self._v_feature_extractor.out_dim,) + tuple(
153
+ v_hidden_sizes
154
+ )
155
+ else:
156
+ self._v_feature_extractor = None
157
+ v_hidden_sizes = (self._feature_extractor.out_dim,) + tuple(v_hidden_sizes)
158
+ self._v = CriticHead(
159
+ hidden_sizes=v_hidden_sizes,
160
+ activation=activation,
161
+ init_layers_orthogonal=init_layers_orthogonal,
162
+ )
163
+
164
+ def _pi_forward(
165
+ self, obs: torch.Tensor, action: Optional[torch.Tensor] = None
166
+ ) -> Tuple[PiForward, torch.Tensor]:
167
+ p_fe = self._feature_extractor(obs)
168
+ pi_forward = self._pi(p_fe, action)
169
+
170
+ return pi_forward, p_fe
171
+
172
+ def _v_forward(self, obs: torch.Tensor, p_fc: torch.Tensor) -> torch.Tensor:
173
+ v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc
174
+ return self._v(v_fe)
175
+
176
+ def forward(self, obs: torch.Tensor, action: torch.Tensor) -> ACForward:
177
+ (_, logp_a, entropy), p_fc = self._pi_forward(obs, action)
178
+ v = self._v_forward(obs, p_fc)
179
+
180
+ assert logp_a is not None
181
+ assert entropy is not None
182
+ return ACForward(logp_a, entropy, v)
183
+
184
+ def value(self, obs: VecEnvObs) -> np.ndarray:
185
+ o = self._as_tensor(obs)
186
+ with torch.no_grad():
187
+ fe = (
188
+ self._v_feature_extractor(o)
189
+ if self._v_feature_extractor
190
+ else self._feature_extractor(o)
191
+ )
192
+ v = self._v(fe)
193
+ return v.cpu().numpy()
194
+
195
+ def step(self, obs: VecEnvObs) -> Step:
196
+ o = self._as_tensor(obs)
197
+ with torch.no_grad():
198
+ (pi, _, _), p_fc = self._pi_forward(o)
199
+ a = pi.sample()
200
+ logp_a = pi.log_prob(a)
201
+
202
+ v = self._v_forward(o, p_fc)
203
+
204
+ a_np = a.cpu().numpy()
205
+ clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
206
+ return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
207
+
208
+ def act(self, obs: np.ndarray, deterministic: bool = True) -> np.ndarray:
209
+ if not deterministic:
210
+ return self.step(obs).clamped_a
211
+ else:
212
+ o = self._as_tensor(obs)
213
+ with torch.no_grad():
214
+ (pi, _, _), _ = self._pi_forward(o)
215
+ a = pi.mode
216
+ return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
217
+
218
+ def load(self, path: str) -> None:
219
+ super().load(path)
220
+ self.reset_noise()
221
+
222
+ def reset_noise(self, batch_size: Optional[int] = None) -> None:
223
+ if isinstance(self._pi, StateDependentNoiseActorHead):
224
+ self._pi.sample_weights(
225
+ batch_size=batch_size if batch_size else self.env.num_envs
226
+ )
rl_algo_impls/shared/policy/optimize_on_policy.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import optuna
2
+
3
+ from gym.spaces import Box
4
+ from typing import Any, Dict
5
+
6
+ from rl_algo_impls.wrappers.vectorable_wrapper import (
7
+ VecEnv,
8
+ single_action_space,
9
+ )
10
+
11
+
12
+ def sample_on_policy_hyperparams(
13
+ trial: optuna.Trial, policy_hparams: Dict[str, Any], env: VecEnv
14
+ ) -> Dict[str, Any]:
15
+ act_space = single_action_space(env)
16
+
17
+ policy_hparams["init_layers_orthogonal"] = trial.suggest_categorical(
18
+ "init_layers_orthogonal", [True, False]
19
+ )
20
+ policy_hparams["activation_fn"] = trial.suggest_categorical(
21
+ "activation_fn", ["tanh", "relu"]
22
+ )
23
+
24
+ if isinstance(act_space, Box):
25
+ policy_hparams["log_std_init"] = trial.suggest_float("log_std_init", -5, 0.5)
26
+ policy_hparams["use_sde"] = trial.suggest_categorical("use_sde", [False, True])
27
+
28
+ if policy_hparams.get("use_sde", False):
29
+ policy_hparams["squash_output"] = trial.suggest_categorical(
30
+ "squash_output", [False, True]
31
+ )
32
+ elif "squash_output" in policy_hparams:
33
+ del policy_hparams["squash_output"]
34
+
35
+ return policy_hparams