sgoodfriend
commited on
Commit
•
0c44e35
1
Parent(s):
7f09aac
A2C playing AntBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- a2c-AntBulletEnv-v0.zip +0 -3
- a2c-AntBulletEnv-v0/_stable_baselines3_version +0 -1
- a2c-AntBulletEnv-v0/data +0 -109
- a2c-AntBulletEnv-v0/policy.optimizer.pth +0 -3
- a2c-AntBulletEnv-v0/policy.pth +0 -3
- a2c-AntBulletEnv-v0/pytorch_variables.pth +0 -3
- a2c-AntBulletEnv-v0/system_info.txt +0 -7
- a2c/a2c.py +0 -201
- benchmarks/benchmark_test.sh +0 -32
- benchmarks/colab_atari1.sh +0 -5
- benchmarks/colab_atari2.sh +0 -5
- benchmarks/colab_basic.sh +0 -5
- benchmarks/colab_benchmark.ipynb +0 -195
- benchmarks/colab_carracing.sh +0 -5
- benchmarks/colab_pybullet.sh +0 -5
- benchmarks/train_loop.sh +0 -15
- colab_enjoy.ipynb +0 -198
- colab_requirements.txt +0 -14
- colab_train.ipynb +0 -200
- config.json +0 -1
- dqn/dqn.py +0 -182
- dqn/policy.py +0 -52
- dqn/q_net.py +0 -41
- hyperparams/a2c.yml +0 -127
- hyperparams/dqn.yml +0 -130
- hyperparams/ppo.yml +0 -383
- hyperparams/vpg.yml +0 -197
- lambda_labs/benchmark.sh +0 -34
- lambda_labs/impala_atari_benchmark.sh +0 -19
- lambda_labs/lambda_requirements.txt +0 -16
- lambda_labs/procgen_benchmark.sh +0 -18
- lambda_labs/setup.sh +0 -10
- lambda_labs/starpilot_hard_benchmark.sh +0 -16
- poetry.lock +0 -0
- ppo/ppo.py +0 -349
- publish/markdown_format.py +0 -210
- replay.meta.json +1 -1
- results.json +0 -1
- rl_algo_impls/benchmark_publish.py +2 -2
- rl_algo_impls/huggingface_publish.py +1 -0
- runner/config.py +0 -155
- runner/env.py +0 -284
- runner/evaluate.py +0 -103
- runner/running_utils.py +0 -195
- runner/train.py +0 -141
- shared/algorithm.py +0 -35
- shared/callbacks/callback.py +0 -12
- shared/callbacks/eval_callback.py +0 -199
- shared/gae.py +0 -67
- shared/module/feature_extractor.py +0 -215
a2c-AntBulletEnv-v0.zip
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:7fe167fa72a102631922c50e6c755640b932d69be9416b03234f512eeb672602
|
3 |
-
size 130079
|
|
|
|
|
|
|
|
a2c-AntBulletEnv-v0/_stable_baselines3_version
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
1.7.0
|
|
|
|
a2c-AntBulletEnv-v0/data
DELETED
@@ -1,109 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"policy_class": {
|
3 |
-
":type:": "<class 'abc.ABCMeta'>",
|
4 |
-
":serialized:": "gAWVOwAAAAAAAACMIXN0YWJsZV9iYXNlbGluZXMzLmNvbW1vbi5wb2xpY2llc5SMEUFjdG9yQ3JpdGljUG9saWN5lJOULg==",
|
5 |
-
"__module__": "stable_baselines3.common.policies",
|
6 |
-
"__doc__": "\n Policy class for actor-critic algorithms (has both policy and value prediction).\n Used by A2C, PPO and the likes.\n\n :param observation_space: Observation space\n :param action_space: Action space\n :param lr_schedule: Learning rate schedule (could be constant)\n :param net_arch: The specification of the policy and value networks.\n :param activation_fn: Activation function\n :param ortho_init: Whether to use or not orthogonal initialization\n :param use_sde: Whether to use State Dependent Exploration or not\n :param log_std_init: Initial value for the log standard deviation\n :param full_std: Whether to use (n_features x n_actions) parameters\n for the std instead of only (n_features,) when using gSDE\n :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure\n a positive standard deviation (cf paper). It allows to keep variance\n above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough.\n :param squash_output: Whether to squash the output using a tanh function,\n this allows to ensure boundaries when using gSDE.\n :param features_extractor_class: Features extractor to use.\n :param features_extractor_kwargs: Keyword arguments\n to pass to the features extractor.\n :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.\n :param normalize_images: Whether to normalize images or not,\n dividing by 255.0 (True by default)\n :param optimizer_class: The optimizer to use,\n ``th.optim.Adam`` by default\n :param optimizer_kwargs: Additional keyword arguments,\n excluding the learning rate, to pass to the optimizer\n ",
|
7 |
-
"__init__": "<function ActorCriticPolicy.__init__ at 0x7ff06ed7c3a0>",
|
8 |
-
"_get_constructor_parameters": "<function ActorCriticPolicy._get_constructor_parameters at 0x7ff06ed7c430>",
|
9 |
-
"reset_noise": "<function ActorCriticPolicy.reset_noise at 0x7ff06ed7c4c0>",
|
10 |
-
"_build_mlp_extractor": "<function ActorCriticPolicy._build_mlp_extractor at 0x7ff06ed7c550>",
|
11 |
-
"_build": "<function ActorCriticPolicy._build at 0x7ff06ed7c5e0>",
|
12 |
-
"forward": "<function ActorCriticPolicy.forward at 0x7ff06ed7c670>",
|
13 |
-
"extract_features": "<function ActorCriticPolicy.extract_features at 0x7ff06ed7c700>",
|
14 |
-
"_get_action_dist_from_latent": "<function ActorCriticPolicy._get_action_dist_from_latent at 0x7ff06ed7c790>",
|
15 |
-
"_predict": "<function ActorCriticPolicy._predict at 0x7ff06ed7c820>",
|
16 |
-
"evaluate_actions": "<function ActorCriticPolicy.evaluate_actions at 0x7ff06ed7c8b0>",
|
17 |
-
"get_distribution": "<function ActorCriticPolicy.get_distribution at 0x7ff06ed7c940>",
|
18 |
-
"predict_values": "<function ActorCriticPolicy.predict_values at 0x7ff06ed7c9d0>",
|
19 |
-
"__abstractmethods__": "frozenset()",
|
20 |
-
"_abc_impl": "<_abc_data object at 0x7ff06ed70f60>"
|
21 |
-
},
|
22 |
-
"verbose": 1,
|
23 |
-
"policy_kwargs": {
|
24 |
-
":type:": "<class 'dict'>",
|
25 |
-
":serialized:": "gAWVowAAAAAAAAB9lCiMDGxvZ19zdGRfaW5pdJRK/v///4wKb3J0aG9faW5pdJSJjA9vcHRpbWl6ZXJfY2xhc3OUjBN0b3JjaC5vcHRpbS5ybXNwcm9wlIwHUk1TcHJvcJSTlIwQb3B0aW1pemVyX2t3YXJnc5R9lCiMBWFscGhhlEc/764UeuFHrowDZXBzlEc+5Pi1iONo8YwMd2VpZ2h0X2RlY2F5lEsAdXUu",
|
26 |
-
"log_std_init": -2,
|
27 |
-
"ortho_init": false,
|
28 |
-
"optimizer_class": "<class 'torch.optim.rmsprop.RMSprop'>",
|
29 |
-
"optimizer_kwargs": {
|
30 |
-
"alpha": 0.99,
|
31 |
-
"eps": 1e-05,
|
32 |
-
"weight_decay": 0
|
33 |
-
}
|
34 |
-
},
|
35 |
-
"observation_space": {
|
36 |
-
":type:": "<class 'gym.spaces.box.Box'>",
|
37 |
-
":serialized:": "gAWVZwIAAAAAAACMDmd5bS5zcGFjZXMuYm94lIwDQm94lJOUKYGUfZQojAVkdHlwZZSMBW51bXB5lGgFk5SMAmY0lImIh5RSlChLA4wBPJROTk5K/////0r/////SwB0lGKMBl9zaGFwZZRLHIWUjANsb3eUjBJudW1weS5jb3JlLm51bWVyaWOUjAtfZnJvbWJ1ZmZlcpSTlCiWcAAAAAAAAAAAAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/lGgKSxyFlIwBQ5R0lFKUjARoaWdolGgSKJZwAAAAAAAAAAAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH+UaApLHIWUaBV0lFKUjA1ib3VuZGVkX2JlbG93lGgSKJYcAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACUaAeMAmIxlImIh5RSlChLA4wBfJROTk5K/////0r/////SwB0lGJLHIWUaBV0lFKUjA1ib3VuZGVkX2Fib3ZllGgSKJYcAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACUaCFLHIWUaBV0lFKUjApfbnBfcmFuZG9tlE51Yi4=",
|
38 |
-
"dtype": "float32",
|
39 |
-
"_shape": [
|
40 |
-
28
|
41 |
-
],
|
42 |
-
"low": "[-inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf\n -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf]",
|
43 |
-
"high": "[inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf\n inf inf inf inf inf inf inf inf inf inf]",
|
44 |
-
"bounded_below": "[False False False False False False False False False False False False\n False False False False False False False False False False False False\n False False False False]",
|
45 |
-
"bounded_above": "[False False False False False False False False False False False False\n False False False False False False False False False False False False\n False False False False]",
|
46 |
-
"_np_random": null
|
47 |
-
},
|
48 |
-
"action_space": {
|
49 |
-
":type:": "<class 'gym.spaces.box.Box'>",
|
50 |
-
":serialized:": "gAWVnwEAAAAAAACMDmd5bS5zcGFjZXMuYm94lIwDQm94lJOUKYGUfZQojAVkdHlwZZSMBW51bXB5lGgFk5SMAmY0lImIh5RSlChLA4wBPJROTk5K/////0r/////SwB0lGKMBl9zaGFwZZRLCIWUjANsb3eUjBJudW1weS5jb3JlLm51bWVyaWOUjAtfZnJvbWJ1ZmZlcpSTlCiWIAAAAAAAAAAAAIC/AACAvwAAgL8AAIC/AACAvwAAgL8AAIC/AACAv5RoCksIhZSMAUOUdJRSlIwEaGlnaJRoEiiWIAAAAAAAAAAAAIA/AACAPwAAgD8AAIA/AACAPwAAgD8AAIA/AACAP5RoCksIhZRoFXSUUpSMDWJvdW5kZWRfYmVsb3eUaBIolggAAAAAAAAAAQEBAQEBAQGUaAeMAmIxlImIh5RSlChLA4wBfJROTk5K/////0r/////SwB0lGJLCIWUaBV0lFKUjA1ib3VuZGVkX2Fib3ZllGgSKJYIAAAAAAAAAAEBAQEBAQEBlGghSwiFlGgVdJRSlIwKX25wX3JhbmRvbZROdWIu",
|
51 |
-
"dtype": "float32",
|
52 |
-
"_shape": [
|
53 |
-
8
|
54 |
-
],
|
55 |
-
"low": "[-1. -1. -1. -1. -1. -1. -1. -1.]",
|
56 |
-
"high": "[1. 1. 1. 1. 1. 1. 1. 1.]",
|
57 |
-
"bounded_below": "[ True True True True True True True True]",
|
58 |
-
"bounded_above": "[ True True True True True True True True]",
|
59 |
-
"_np_random": null
|
60 |
-
},
|
61 |
-
"n_envs": 4,
|
62 |
-
"num_timesteps": 2000000,
|
63 |
-
"_total_timesteps": 2000000,
|
64 |
-
"_num_timesteps_at_start": 0,
|
65 |
-
"seed": null,
|
66 |
-
"action_noise": null,
|
67 |
-
"start_time": 1674588743994659442,
|
68 |
-
"learning_rate": {
|
69 |
-
":type:": "<class 'function'>",
|
70 |
-
":serialized:": "gAWVdgIAAAAAAACMF2Nsb3VkcGlja2xlLmNsb3VkcGlja2xllIwOX21ha2VfZnVuY3Rpb26Uk5QoaACMDV9idWlsdGluX3R5cGWUk5SMCENvZGVUeXBllIWUUpQoSwFLAEsASwFLAksTQwiIAHwAFABTAJROhZQpjBJwcm9ncmVzc19yZW1haW5pbmeUhZSMHzxpcHl0aG9uLWlucHV0LTEzLWVhYTdkOGY5N2ZkNj6UjAhzY2hlZHVsZZRLBEMCAAGUjA1pbml0aWFsX3ZhbHVllIWUKXSUUpR9lCiMC19fcGFja2FnZV9flE6MCF9fbmFtZV9flIwIX19tYWluX1+UdU5OaACMEF9tYWtlX2VtcHR5X2NlbGyUk5QpUpSFlHSUUpSMHGNsb3VkcGlja2xlLmNsb3VkcGlja2xlX2Zhc3SUjBJfZnVuY3Rpb25fc2V0c3RhdGWUk5RoHH2UfZQoaBVoDYwMX19xdWFsbmFtZV9flIwhbGluZWFyX3NjaGVkdWxlLjxsb2NhbHM+LnNjaGVkdWxllIwPX19hbm5vdGF0aW9uc19flH2UKIwScHJvZ3Jlc3NfcmVtYWluaW5nlIwIYnVpbHRpbnOUjAVmbG9hdJSTlIwGcmV0dXJulGgpdYwOX19rd2RlZmF1bHRzX1+UTowMX19kZWZhdWx0c19flE6MCl9fbW9kdWxlX1+UaBaMB19fZG9jX1+UTowLX19jbG9zdXJlX1+UaACMCl9tYWtlX2NlbGyUk5RHP091EE1VHWmFlFKUhZSMF19jbG91ZHBpY2tsZV9zdWJtb2R1bGVzlF2UjAtfX2dsb2JhbHNfX5R9lHWGlIZSMC4="
|
71 |
-
},
|
72 |
-
"tensorboard_log": null,
|
73 |
-
"lr_schedule": {
|
74 |
-
":type:": "<class 'function'>",
|
75 |
-
":serialized:": "gAWVdgIAAAAAAACMF2Nsb3VkcGlja2xlLmNsb3VkcGlja2xllIwOX21ha2VfZnVuY3Rpb26Uk5QoaACMDV9idWlsdGluX3R5cGWUk5SMCENvZGVUeXBllIWUUpQoSwFLAEsASwFLAksTQwiIAHwAFABTAJROhZQpjBJwcm9ncmVzc19yZW1haW5pbmeUhZSMHzxpcHl0aG9uLWlucHV0LTEzLWVhYTdkOGY5N2ZkNj6UjAhzY2hlZHVsZZRLBEMCAAGUjA1pbml0aWFsX3ZhbHVllIWUKXSUUpR9lCiMC19fcGFja2FnZV9flE6MCF9fbmFtZV9flIwIX19tYWluX1+UdU5OaACMEF9tYWtlX2VtcHR5X2NlbGyUk5QpUpSFlHSUUpSMHGNsb3VkcGlja2xlLmNsb3VkcGlja2xlX2Zhc3SUjBJfZnVuY3Rpb25fc2V0c3RhdGWUk5RoHH2UfZQoaBVoDYwMX19xdWFsbmFtZV9flIwhbGluZWFyX3NjaGVkdWxlLjxsb2NhbHM+LnNjaGVkdWxllIwPX19hbm5vdGF0aW9uc19flH2UKIwScHJvZ3Jlc3NfcmVtYWluaW5nlIwIYnVpbHRpbnOUjAVmbG9hdJSTlIwGcmV0dXJulGgpdYwOX19rd2RlZmF1bHRzX1+UTowMX19kZWZhdWx0c19flE6MCl9fbW9kdWxlX1+UaBaMB19fZG9jX1+UTowLX19jbG9zdXJlX1+UaACMCl9tYWtlX2NlbGyUk5RHP091EE1VHWmFlFKUhZSMF19jbG91ZHBpY2tsZV9zdWJtb2R1bGVzlF2UjAtfX2dsb2JhbHNfX5R9lHWGlIZSMC4="
|
76 |
-
},
|
77 |
-
"_last_obs": {
|
78 |
-
":type:": "<class 'numpy.ndarray'>",
|
79 |
-
":serialized:": "gAWVNQIAAAAAAACMEm51bXB5LmNvcmUubnVtZXJpY5SMC19mcm9tYnVmZmVylJOUKJbAAQAAAAAAANkXHT/P+Bu/1A+1PjIvpT89Bh+/zjqjPhDbqT4Dqha/AjV/P+zYsr+uKGE/dj6Xv6Cxv7+R4wc8NxoKv6piD8CimWe/+TfIPi55aD9Zeju9gErvPhI+3L5IS6C/aYMmP8k7cr8P4ve/8OcFwEF3ab/IB8Q+iiKSPjT47T5cWJE/BK7Rvr5ZgD+8Kak900civ+X6jr4MOYU9Y5Shv5Jbtz6m2U++w9apP5QSCT7zfbk/P/O2PzWg8Lp/kpq+P/AOvwa3jb+DbFA/cyMdP3pKXj/JO3K//jAEP/DnBcDGWow/sCkSP3Xwnj5GF+s+0PqgP1LPhL70tAs/D5JdPdH8PDwVeaM+7Rabv2R/mz5gexrADzf+vtnLpj8/PMu+cPeXP8wagb4/bOU/gJVZPfvU8r+XP0C/uvbjvuIAMT9kUDM/yTtyv/4wBD9+tfQ+xlqMP1d1jT2Zgbe+1bHkPva9wj/bEYW/mrFZP9iSDD9GL/G+l7yMP/DjTz9xM7U/HN+OPqTRRr94K0XAr08LO38geb/41IG/tzC3v9ZobD/na868JuUBvnf+tr+WGXi/NfUov0BGhz8P4ve/frX0PkF3ab+UjAVudW1weZSMBWR0eXBllJOUjAJmNJSJiIeUUpQoSwOMATyUTk5OSv////9K/////0sAdJRiSwRLHIaUjAFDlHSUUpQu"
|
80 |
-
},
|
81 |
-
"_last_episode_starts": {
|
82 |
-
":type:": "<class 'numpy.ndarray'>",
|
83 |
-
":serialized:": "gAWVdwAAAAAAAACMEm51bXB5LmNvcmUubnVtZXJpY5SMC19mcm9tYnVmZmVylJOUKJYEAAAAAAAAAAAAAACUjAVudW1weZSMBWR0eXBllJOUjAJiMZSJiIeUUpQoSwOMAXyUTk5OSv////9K/////0sAdJRiSwSFlIwBQ5R0lFKULg=="
|
84 |
-
},
|
85 |
-
"_last_original_obs": {
|
86 |
-
":type:": "<class 'numpy.ndarray'>",
|
87 |
-
":serialized:": "gAWVNQIAAAAAAACMEm51bXB5LmNvcmUubnVtZXJpY5SMC19mcm9tYnVmZmVylJOUKJbAAQAAAAAAAAAAAAAtoGm1AACAPwAAAAAAAAAAAAAAAAAAAAAAAACA9o2EvQAAAACgRPm/AAAAAMwsiT0AAAAAyVXjPwAAAABeHa69AAAAAFdF5z8AAAAAAcmpPQAAAAApeea/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHEHLNAAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAgEwf2b0AAAAA2Gz1vwAAAADLltm6AAAAAM9Z7D8AAAAA3yVmvQAAAADn69s/AAAAADiBBz0AAAAAw1rrvwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAOBNMrYAAIA/AAAAAAAAAAAAAAAAAAAAAAAAAIAyWb08AAAAAEuk/L8AAAAAL1SzvQAAAAByavk/AAAAAP8V470AAAAA/m/iPwAAAAAtweW9AAAAAExT9L8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADP6ZA0AACAPwAAAAAAAAAAAAAAAAAAAAAAAACAfKcNPgAAAAAF3eK/AAAAACbLMT0AAAAARnIAQAAAAAArb0K7AAAAAMWz/z8AAAAANVvXvQAAAAAaNe2/AAAAAAAAAAAAAAAAAAAAAAAAAACUjAVudW1weZSMBWR0eXBllJOUjAJmNJSJiIeUUpQoSwOMATyUTk5OSv////9K/////0sAdJRiSwRLHIaUjAFDlHSUUpQu"
|
88 |
-
},
|
89 |
-
"_episode_num": 0,
|
90 |
-
"use_sde": true,
|
91 |
-
"sde_sample_freq": -1,
|
92 |
-
"_current_progress_remaining": 0.0,
|
93 |
-
"ep_info_buffer": {
|
94 |
-
":type:": "<class 'collections.deque'>",
|
95 |
-
":serialized:": "gAWVRAwAAAAAAACMC2NvbGxlY3Rpb25zlIwFZGVxdWWUk5QpS2SGlFKUKH2UKIwBcpRHQKaK8jRlYlqMAWyUTegDjAF0lEdArJafl8w6AHV9lChoBkdAproSsySFG2gHTegDaAhHQKyeau01IiF1fZQoaAZHQKW/SZpi7TVoB03oA2gIR0Csnr6jnFHbdX2UKGgGR0ClbyI/7iyZaAdN6ANoCEdArKJTlmvnsHV9lChoBkdApXSPAh0QsmgHTegDaAhHQKyiYGeMAFR1fZQoaAZHQKYeObMHKOloB03oA2gIR0Csqia/qPfbdX2UKGgGR0Cl25fAbhm5aAdN6ANoCEdArKp0rAgxJ3V9lChoBkdApGo8rmQr+mgHTegDaAhHQKyt3NGmUGF1fZQoaAZHQKW26Hk92X9oB03oA2gIR0CsreV76YVqdX2UKGgGR0Clpgs1KoQ4aAdN6ANoCEdArLXvPw/gSHV9lChoBkdApZRJfOUt7WgHTegDaAhHQKy2Q7jDKo11fZQoaAZHQKWQJ2alUIdoB03oA2gIR0Csubl7Uoa2dX2UKGgGR0CmLI2nsLOSaAdN6ANoCEdArLnCPU8V6HV9lChoBkdApVqDB9Cu2mgHTegDaAhHQKzBwo1k1/F1fZQoaAZHQKZuXCqp97ZoB03oA2gIR0CswhCEQGwBdX2UKGgGR0Cmlu0j9n9OaAdN6ANoCEdArMWdjbzshXV9lChoBkdApnp/RgJC0GgHTegDaAhHQKzFpoL5RCR1fZQoaAZHQKR8O5J9RaZoB03oA2gIR0CszVechC+ldX2UKGgGR0CltMF41P30aAdN6ANoCEdArM2rY/Vy3nV9lChoBkdApc4PphWo32gHTegDaAhHQKzRWe/5+H91fZQoaAZHQKZrGVUMoc9oB03oA2gIR0Cs0WMKTjebdX2UKGgGR0CmgCTvy9VWaAdN6ANoCEdArNj4dGRV63V9lChoBkdApfwql3yI6GgHTegDaAhHQKzZS31BdD91fZQoaAZHQKZoYRq46OpoB03oA2gIR0Cs3Nb8m8dxdX2UKGgGR0ClFq4sVclgaAdN6ANoCEdArNzf2PDHfnV9lChoBkdApdattVJcxGgHTegDaAhHQKzkmxptaZB1fZQoaAZHQKWUFjVhCt1oB03oA2gIR0Cs5O8LjPv8dX2UKGgGR0CmL6/N7jT8aAdN6ANoCEdArOhYIKMNt3V9lChoBkdApdR+nQ6ZIGgHTegDaAhHQKzoYBRQ7911fZQoaAZHQKYjF24d6s1oB03oA2gIR0Cs8AV6/qPfdX2UKGgGR0ClmIZ9Vmz0aAdN6ANoCEdArPBaKpDNQnV9lChoBkdApk9kpiI+GGgHTegDaAhHQKz0EkN4JNV1fZQoaAZHQKVyiL2HtWxoB03oA2gIR0Cs9Bv5pJwsdX2UKGgGR0CmUOSncclxaAdN6ANoCEdArPvNdVvMr3V9lChoBkdAprML/VAiV2gHTegDaAhHQKz8HtdiUgV1fZQoaAZHQKWHJzcRDkVoB03oA2gIR0Cs/4O6NEPUdX2UKGgGR0Clvp+QlruZaAdN6ANoCEdArP+Lvd/KAHV9lChoBkdAptM4bfgrH2gHTegDaAhHQK0HKjbBXS11fZQoaAZHQKTNTr/sE7poB03oA2gIR0CtB3/xtpEhdX2UKGgGR0ClpbtFjNILaAdN6ANoCEdArQsPj+717XV9lChoBkdApN2DeZXuE2gHTegDaAhHQK0LGJl8PWh1fZQoaAZHQKOCkuKXOW1oB03oA2gIR0CtEulT3qRmdX2UKGgGR0Cl+HYb83uNaAdN6ANoCEdArRM6aPS2IHV9lChoBkdApPy5dKNADGgHTegDaAhHQK0Wr5WzWwx1fZQoaAZHQKZWr4dIXj5oB03oA2gIR0CtFrfCAMDwdX2UKGgGR0CmKUQW3z+WaAdN6ANoCEdArR6OpQ1rI3V9lChoBkdAphbtDtw
71mgHTegDaAhHQK0e3Ktga3t1fZQoaAZHQKW7YPTXrdFoB03oA2gIR0CtIkclolD4dX2UKGgGR0Cl5MDps41haAdN6ANoCEdArSJQ/RmbsnV9lChoBkdApiBvW4EwFmgHTegDaAhHQK0p+HEdeY51fZQoaAZHQKQNrjc2zfJoB03oA2gIR0CtKkj28IzFdX2UKGgGR0Cckdajvd/KaAdNggNoCEdArSyY1pCa7XV9lChoBkdApomsPvrnkmgHTegDaAhHQK0tzR4yGi51fZQoaAZHQKZk+Fyq+8JoB03oA2gIR0CtNXu1OTJRdX2UKGgGR0CllYyxqwhXaAdN6ANoCEdArTXIMx46fnV9lChoBkdAphfyzRhMJ2gHTegDaAhHQK04BurIYFd1fZQoaAZHQKSM9KV6eGxoB03oA2gIR0CtOUYlpoK2dX2UKGgGR0CmiWAJTl1baAdN6ANoCEdArUEHCj1wpHV9lChoBkdApgHe0gKWs2gHTegDaAhHQK1BYPxQSBd1fZQoaAZHQKcAthnanJloB03oA2gIR0CtQ791U2k0dX2UKGgGR0CiDy96Tnq3aAdN6ANoCEdArUTnRPXTVnV9lChoBkdApFqvcSGrS2gHTegDaAhHQK1MZh/Aj6h1fZQoaAZHQKXAlP/rB0poB03oA2gIR0CtTLSQHRkVdX2UKGgGR0CjmJprDZUUaAdN6ANoCEdArU7xddE9dXV9lChoBkdAptUnsZ5zHWgHTegDaAhHQK1QI8VYZEV1fZQoaAZHQKLkzDlYEGJoB03oA2gIR0CtWCQob4rSdX2UKGgGR0CeM9MS9M9KaAdN6ANoCEdArVh42GZeA3V9lChoBkdApeqzwMH8j2gHTegDaAhHQK1a0TSsr/d1fZQoaAZHQKYbd2wmmchoB03oA2gIR0CtXAZeJHiFdX2UKGgGR0Cl3GUtRNypaAdN6ANoCEdArWO18qnWKHV9lChoBkdApWpvz+WGAWgHTegDaAhHQK1kBFn7Hhl1fZQoaAZHQKZSKHBUJfJoB03oA2gIR0CtZlCuMdcTdX2UKGgGR0CmpdCiAUcoaAdN6ANoCEdArWeCZnctXnV9lChoBkdAppM5Q53kgmgHTegDaAhHQK1vG8h9srN1fZQoaAZHQKYExGDtgKFoB03oA2gIR0Ctb22CVbA2dX2UKGgGR0CmtFdFWn0kaAdN6ANoCEdArXG/VurIYHV9lChoBkdApfNvitJWemgHTegDaAhHQK1zFY5DJEJ1fZQoaAZHQKXdBBSk0rNoB03oA2gIR0CteqxdY4hmdX2UKGgGR0CmhG8kdFOPaAdN6ANoCEdArXr6QvHtGHV9lChoBkdApU/SYPXkHWgHTegDaAhHQK19RnVXmvJ1fZQoaAZHQKWUeXk5p8FoB03oA2gIR0Ctfnhun/DMdX2UKGgGR0CkjzIt16mgaAdN6ANoCEdArYYPdCVrynV9lChoBkdApPtMZ75VO2gHTegDaAhHQK2GYNAkcCJ1fZQoaAZHQKZMGafjCHhoB03oA2gIR0CtiLaOYIBzdX2UKGgGR0Cl0eULMLWqaAdN6ANoCEdArYnyg7HQyHV9lChoBkdApd5MyN4qw2gHTegDaAhHQK2Rs0iQkop1fZQoaAZHQKVcAo5xR2toB03oA2gIR0CtkgYjB2wFdX2UKGgGR0ClzucPFvQ4aAdN6ANoCEdArZRV0NjLCHV9lChoBkdApVq6v7m+02gHTegDaAhHQK2VmQjD8+B1fZQoaAZHQKSTCldC3PRoB03oA2gIR0CtnXYwqRU4dX2UKGgGR0Cl5iNfw7T2aAdN6ANoCEdArZ3F+iJwbXV9lChoBkdApdJXDYRNAWgHTegDaAhHQK2gGQlKK511fZQoaAZHQKZjJvMKTjhoB03oA2gIR0CtoWfvWpZPdX2UKGgGR0Cllc9fsu3+aAdN6ANoCEdAralP+bVjJHV9lChoBkdApTLXMSsbN2gHTegDaAhHQK2
pqDU3GXJ1fZQoaAZHQKYQnPVNHpdoB03oA2gIR0Ctq/a9kBjndX2UKGgGR0Cmo4jRlYlqaAdN6ANoCEdAra09Tgl4T3V9lChoBkdApNBz6JqIrWgHTegDaAhHQK2051jAi3Z1fZQoaAZHQKXI+y5Zr59oB03oA2gIR0CttTxWkrPMdX2UKGgGR0CharHoPkJbaAdNVgNoCEdArbXjyUcGT3VlLg=="
|
96 |
-
},
|
97 |
-
"ep_success_buffer": {
|
98 |
-
":type:": "<class 'collections.deque'>",
|
99 |
-
":serialized:": "gAWVIAAAAAAAAACMC2NvbGxlY3Rpb25zlIwFZGVxdWWUk5QpS2SGlFKULg=="
|
100 |
-
},
|
101 |
-
"_n_updates": 62500,
|
102 |
-
"n_steps": 8,
|
103 |
-
"gamma": 0.99,
|
104 |
-
"gae_lambda": 0.9,
|
105 |
-
"ent_coef": 0.0,
|
106 |
-
"vf_coef": 0.4,
|
107 |
-
"max_grad_norm": 0.5,
|
108 |
-
"normalize_advantage": false
|
109 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
a2c-AntBulletEnv-v0/policy.optimizer.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:ae34a234a13e78e7764685d17296460fb4b030a1e520b8dddb06e1151b45c170
|
3 |
-
size 56190
|
|
|
|
|
|
|
|
a2c-AntBulletEnv-v0/policy.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:e8d436595959204150e7bb98d9534582e5b1f9994536c261ca00a695c7ce203c
|
3 |
-
size 56958
|
|
|
|
|
|
|
|
a2c-AntBulletEnv-v0/pytorch_variables.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:d030ad8db708280fcae77d87e973102039acd23a11bdecc3db8eb6c0ac940ee1
|
3 |
-
size 431
|
|
|
|
|
|
|
|
a2c-AntBulletEnv-v0/system_info.txt
DELETED
@@ -1,7 +0,0 @@
|
|
1 |
-
- OS: Linux-5.10.147+-x86_64-with-glibc2.29 # 1 SMP Sat Dec 10 16:00:40 UTC 2022
|
2 |
-
- Python: 3.8.10
|
3 |
-
- Stable-Baselines3: 1.7.0
|
4 |
-
- PyTorch: 1.13.1+cu116
|
5 |
-
- GPU Enabled: True
|
6 |
-
- Numpy: 1.21.6
|
7 |
-
- Gym: 0.21.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
a2c/a2c.py
DELETED
@@ -1,201 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import torch
|
3 |
-
import torch.nn as nn
|
4 |
-
import torch.nn.functional as F
|
5 |
-
|
6 |
-
from dataclasses import asdict, dataclass, field
|
7 |
-
from time import perf_counter
|
8 |
-
from torch.utils.tensorboard.writer import SummaryWriter
|
9 |
-
from typing import List, Optional, Sequence, NamedTuple, TypeVar
|
10 |
-
|
11 |
-
from shared.algorithm import Algorithm
|
12 |
-
from shared.callbacks.callback import Callback
|
13 |
-
from shared.gae import compute_advantage, compute_rtg_and_advantage
|
14 |
-
from shared.policy.on_policy import ActorCritic
|
15 |
-
from shared.schedule import schedule, update_learning_rate
|
16 |
-
from shared.stats import log_scalars
|
17 |
-
from shared.trajectory import Trajectory, TrajectoryAccumulator
|
18 |
-
from wrappers.vectorable_wrapper import (
|
19 |
-
VecEnv,
|
20 |
-
VecEnvObs,
|
21 |
-
single_observation_space,
|
22 |
-
single_action_space,
|
23 |
-
)
|
24 |
-
|
25 |
-
A2CSelf = TypeVar("A2CSelf", bound="A2C")
|
26 |
-
|
27 |
-
|
28 |
-
class A2C(Algorithm):
|
29 |
-
def __init__(
|
30 |
-
self,
|
31 |
-
policy: ActorCritic,
|
32 |
-
env: VecEnv,
|
33 |
-
device: torch.device,
|
34 |
-
tb_writer: SummaryWriter,
|
35 |
-
learning_rate: float = 7e-4,
|
36 |
-
learning_rate_decay: str = "none",
|
37 |
-
n_steps: int = 5,
|
38 |
-
gamma: float = 0.99,
|
39 |
-
gae_lambda: float = 1.0,
|
40 |
-
ent_coef: float = 0.0,
|
41 |
-
ent_coef_decay: str = "none",
|
42 |
-
vf_coef: float = 0.5,
|
43 |
-
max_grad_norm: float = 0.5,
|
44 |
-
rms_prop_eps: float = 1e-5,
|
45 |
-
use_rms_prop: bool = True,
|
46 |
-
sde_sample_freq: int = -1,
|
47 |
-
normalize_advantage: bool = False,
|
48 |
-
) -> None:
|
49 |
-
super().__init__(policy, env, device, tb_writer)
|
50 |
-
self.policy = policy
|
51 |
-
|
52 |
-
self.lr_schedule = schedule(learning_rate_decay, learning_rate)
|
53 |
-
if use_rms_prop:
|
54 |
-
self.optimizer = torch.optim.RMSprop(
|
55 |
-
policy.parameters(), lr=learning_rate, eps=rms_prop_eps
|
56 |
-
)
|
57 |
-
else:
|
58 |
-
self.optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate)
|
59 |
-
|
60 |
-
self.n_steps = n_steps
|
61 |
-
|
62 |
-
self.gamma = gamma
|
63 |
-
self.gae_lambda = gae_lambda
|
64 |
-
|
65 |
-
self.vf_coef = vf_coef
|
66 |
-
self.ent_coef_schedule = schedule(ent_coef_decay, ent_coef)
|
67 |
-
self.max_grad_norm = max_grad_norm
|
68 |
-
|
69 |
-
self.sde_sample_freq = sde_sample_freq
|
70 |
-
self.normalize_advantage = normalize_advantage
|
71 |
-
|
72 |
-
def learn(
|
73 |
-
self: A2CSelf, total_timesteps: int, callback: Optional[Callback] = None
|
74 |
-
) -> A2CSelf:
|
75 |
-
epoch_dim = (self.n_steps, self.env.num_envs)
|
76 |
-
step_dim = (self.env.num_envs,)
|
77 |
-
obs_space = single_observation_space(self.env)
|
78 |
-
act_space = single_action_space(self.env)
|
79 |
-
|
80 |
-
obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype)
|
81 |
-
actions = np.zeros(epoch_dim + act_space.shape, dtype=act_space.dtype)
|
82 |
-
rewards = np.zeros(epoch_dim, dtype=np.float32)
|
83 |
-
episode_starts = np.zeros(epoch_dim, dtype=np.byte)
|
84 |
-
values = np.zeros(epoch_dim, dtype=np.float32)
|
85 |
-
logprobs = np.zeros(epoch_dim, dtype=np.float32)
|
86 |
-
|
87 |
-
next_obs = self.env.reset()
|
88 |
-
next_episode_starts = np.ones(step_dim, dtype=np.byte)
|
89 |
-
|
90 |
-
timesteps_elapsed = 0
|
91 |
-
while timesteps_elapsed < total_timesteps:
|
92 |
-
start_time = perf_counter()
|
93 |
-
|
94 |
-
progress = timesteps_elapsed / total_timesteps
|
95 |
-
ent_coef = self.ent_coef_schedule(progress)
|
96 |
-
learning_rate = self.lr_schedule(progress)
|
97 |
-
update_learning_rate(self.optimizer, learning_rate)
|
98 |
-
log_scalars(
|
99 |
-
self.tb_writer,
|
100 |
-
"charts",
|
101 |
-
{
|
102 |
-
"ent_coef": ent_coef,
|
103 |
-
"learning_rate": learning_rate,
|
104 |
-
},
|
105 |
-
timesteps_elapsed,
|
106 |
-
)
|
107 |
-
|
108 |
-
self.policy.eval()
|
109 |
-
self.policy.reset_noise()
|
110 |
-
for s in range(self.n_steps):
|
111 |
-
timesteps_elapsed += self.env.num_envs
|
112 |
-
if self.sde_sample_freq > 0 and s > 0 and s % self.sde_sample_freq == 0:
|
113 |
-
self.policy.reset_noise()
|
114 |
-
|
115 |
-
obs[s] = next_obs
|
116 |
-
episode_starts[s] = next_episode_starts
|
117 |
-
|
118 |
-
actions[s], values[s], logprobs[s], clamped_action = self.policy.step(
|
119 |
-
next_obs
|
120 |
-
)
|
121 |
-
next_obs, rewards[s], next_episode_starts, _ = self.env.step(
|
122 |
-
clamped_action
|
123 |
-
)
|
124 |
-
|
125 |
-
advantages = np.zeros(epoch_dim, dtype=np.float32)
|
126 |
-
last_gae_lam = 0
|
127 |
-
for t in reversed(range(self.n_steps)):
|
128 |
-
if t == self.n_steps - 1:
|
129 |
-
next_nonterminal = 1.0 - next_episode_starts
|
130 |
-
next_value = self.policy.value(next_obs)
|
131 |
-
else:
|
132 |
-
next_nonterminal = 1.0 - episode_starts[t + 1]
|
133 |
-
next_value = values[t + 1]
|
134 |
-
delta = (
|
135 |
-
rewards[t] + self.gamma * next_value * next_nonterminal - values[t]
|
136 |
-
)
|
137 |
-
last_gae_lam = (
|
138 |
-
delta
|
139 |
-
+ self.gamma * self.gae_lambda * next_nonterminal * last_gae_lam
|
140 |
-
)
|
141 |
-
advantages[t] = last_gae_lam
|
142 |
-
returns = advantages + values
|
143 |
-
|
144 |
-
b_obs = torch.tensor(obs.reshape((-1,) + obs_space.shape)).to(self.device)
|
145 |
-
b_actions = torch.tensor(actions.reshape((-1,) + act_space.shape)).to(
|
146 |
-
self.device
|
147 |
-
)
|
148 |
-
b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
|
149 |
-
b_returns = torch.tensor(returns.reshape(-1)).to(self.device)
|
150 |
-
|
151 |
-
if self.normalize_advantage:
|
152 |
-
b_advantages = (b_advantages - b_advantages.mean()) / (
|
153 |
-
b_advantages.std() + 1e-8
|
154 |
-
)
|
155 |
-
|
156 |
-
self.policy.train()
|
157 |
-
logp_a, entropy, v = self.policy(b_obs, b_actions)
|
158 |
-
|
159 |
-
pi_loss = -(b_advantages * logp_a).mean()
|
160 |
-
value_loss = F.mse_loss(b_returns, v)
|
161 |
-
entropy_loss = -entropy.mean()
|
162 |
-
|
163 |
-
loss = pi_loss + self.vf_coef * value_loss + ent_coef * entropy_loss
|
164 |
-
|
165 |
-
self.optimizer.zero_grad()
|
166 |
-
loss.backward()
|
167 |
-
nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
|
168 |
-
self.optimizer.step()
|
169 |
-
|
170 |
-
y_pred = values.reshape(-1)
|
171 |
-
y_true = returns.reshape(-1)
|
172 |
-
var_y = np.var(y_true).item()
|
173 |
-
explained_var = (
|
174 |
-
np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
|
175 |
-
)
|
176 |
-
|
177 |
-
end_time = perf_counter()
|
178 |
-
rollout_steps = self.n_steps * self.env.num_envs
|
179 |
-
self.tb_writer.add_scalar(
|
180 |
-
"train/steps_per_second",
|
181 |
-
(rollout_steps) / (end_time - start_time),
|
182 |
-
timesteps_elapsed,
|
183 |
-
)
|
184 |
-
|
185 |
-
log_scalars(
|
186 |
-
self.tb_writer,
|
187 |
-
"losses",
|
188 |
-
{
|
189 |
-
"loss": loss.item(),
|
190 |
-
"pi_loss": pi_loss.item(),
|
191 |
-
"v_loss": value_loss.item(),
|
192 |
-
"entropy_loss": entropy_loss.item(),
|
193 |
-
"explained_var": explained_var,
|
194 |
-
},
|
195 |
-
timesteps_elapsed,
|
196 |
-
)
|
197 |
-
|
198 |
-
if callback:
|
199 |
-
callback.on_step(timesteps_elapsed=rollout_steps)
|
200 |
-
|
201 |
-
return self
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/benchmark_test.sh
DELETED
@@ -1,32 +0,0 @@
|
|
1 |
-
source benchmarks/train_loop.sh
|
2 |
-
|
3 |
-
export WANDB_PROJECT_NAME="rl-algo-impls"
|
4 |
-
|
5 |
-
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
|
6 |
-
|
7 |
-
ALGOS=(
|
8 |
-
# "vpg"
|
9 |
-
"dqn"
|
10 |
-
# "ppo"
|
11 |
-
)
|
12 |
-
ENVS=(
|
13 |
-
# Basic
|
14 |
-
"CartPole-v1"
|
15 |
-
"MountainCar-v0"
|
16 |
-
# "MountainCarContinuous-v0"
|
17 |
-
"Acrobot-v1"
|
18 |
-
"LunarLander-v2"
|
19 |
-
# # PyBullet
|
20 |
-
# "HalfCheetahBulletEnv-v0"
|
21 |
-
# "AntBulletEnv-v0"
|
22 |
-
# "HopperBulletEnv-v0"
|
23 |
-
# "Walker2DBulletEnv-v0"
|
24 |
-
# # CarRacing
|
25 |
-
# "CarRacing-v0"
|
26 |
-
# Atari
|
27 |
-
"PongNoFrameskip-v4"
|
28 |
-
"BreakoutNoFrameskip-v4"
|
29 |
-
"SpaceInvadersNoFrameskip-v4"
|
30 |
-
"QbertNoFrameskip-v4"
|
31 |
-
)
|
32 |
-
train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/colab_atari1.sh
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
source benchmarks/train_loop.sh
|
2 |
-
ALGOS="ppo"
|
3 |
-
ENVS="PongNoFrameskip-v4 BreakoutNoFrameskip-v4"
|
4 |
-
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
|
5 |
-
train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/colab_atari2.sh
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
source benchmarks/train_loop.sh
|
2 |
-
ALGOS="ppo"
|
3 |
-
ENVS="SpaceInvadersNoFrameskip-v4 QbertNoFrameskip-v4"
|
4 |
-
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
|
5 |
-
train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/colab_basic.sh
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
source benchmarks/train_loop.sh
|
2 |
-
ALGOS="ppo"
|
3 |
-
ENVS="CartPole-v1 MountainCar-v0 MountainCarContinuous-v0 Acrobot-v1 LunarLander-v2"
|
4 |
-
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
|
5 |
-
train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/colab_benchmark.ipynb
DELETED
@@ -1,195 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"nbformat": 4,
|
3 |
-
"nbformat_minor": 0,
|
4 |
-
"metadata": {
|
5 |
-
"colab": {
|
6 |
-
"provenance": [],
|
7 |
-
"machine_shape": "hm",
|
8 |
-
"authorship_tag": "ABX9TyOGIH7rqgasim3Sz7b1rpoE",
|
9 |
-
"include_colab_link": true
|
10 |
-
},
|
11 |
-
"kernelspec": {
|
12 |
-
"name": "python3",
|
13 |
-
"display_name": "Python 3"
|
14 |
-
},
|
15 |
-
"language_info": {
|
16 |
-
"name": "python"
|
17 |
-
},
|
18 |
-
"gpuClass": "standard",
|
19 |
-
"accelerator": "GPU"
|
20 |
-
},
|
21 |
-
"cells": [
|
22 |
-
{
|
23 |
-
"cell_type": "markdown",
|
24 |
-
"metadata": {
|
25 |
-
"id": "view-in-github",
|
26 |
-
"colab_type": "text"
|
27 |
-
},
|
28 |
-
"source": [
|
29 |
-
"<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/benchmarks/colab_benchmark.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
30 |
-
]
|
31 |
-
},
|
32 |
-
{
|
33 |
-
"cell_type": "markdown",
|
34 |
-
"source": [
|
35 |
-
"# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
|
36 |
-
"## Parameters\n",
|
37 |
-
"\n",
|
38 |
-
"\n",
|
39 |
-
"1. Wandb\n",
|
40 |
-
"\n"
|
41 |
-
],
|
42 |
-
"metadata": {
|
43 |
-
"id": "S-tXDWP8WTLc"
|
44 |
-
}
|
45 |
-
},
|
46 |
-
{
|
47 |
-
"cell_type": "code",
|
48 |
-
"source": [
|
49 |
-
"from getpass import getpass\n",
|
50 |
-
"import os\n",
|
51 |
-
"os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
|
52 |
-
],
|
53 |
-
"metadata": {
|
54 |
-
"id": "1ZtdYgxWNGwZ"
|
55 |
-
},
|
56 |
-
"execution_count": null,
|
57 |
-
"outputs": []
|
58 |
-
},
|
59 |
-
{
|
60 |
-
"cell_type": "markdown",
|
61 |
-
"source": [
|
62 |
-
"## Setup\n",
|
63 |
-
"Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
|
64 |
-
],
|
65 |
-
"metadata": {
|
66 |
-
"id": "bsG35Io0hmKG"
|
67 |
-
}
|
68 |
-
},
|
69 |
-
{
|
70 |
-
"cell_type": "code",
|
71 |
-
"source": [
|
72 |
-
"%%capture\n",
|
73 |
-
"!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
|
74 |
-
],
|
75 |
-
"metadata": {
|
76 |
-
"id": "k5ynTV25hdAf"
|
77 |
-
},
|
78 |
-
"execution_count": null,
|
79 |
-
"outputs": []
|
80 |
-
},
|
81 |
-
{
|
82 |
-
"cell_type": "markdown",
|
83 |
-
"source": [
|
84 |
-
"Installing the correct packages:\n",
|
85 |
-
"\n",
|
86 |
-
"While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes the package yml files difficult to use. For now, I'm instead going to specify the list of requirements manually below:"
|
87 |
-
],
|
88 |
-
"metadata": {
|
89 |
-
"id": "jKxGok-ElYQ7"
|
90 |
-
}
|
91 |
-
},
|
92 |
-
{
|
93 |
-
"cell_type": "code",
|
94 |
-
"source": [
|
95 |
-
"%%capture\n",
|
96 |
-
"!apt install python-opengl\n",
|
97 |
-
"!apt install ffmpeg\n",
|
98 |
-
"!apt install xvfb\n",
|
99 |
-
"!apt install swig"
|
100 |
-
],
|
101 |
-
"metadata": {
|
102 |
-
"id": "nn6EETTc2Ewf"
|
103 |
-
},
|
104 |
-
"execution_count": null,
|
105 |
-
"outputs": []
|
106 |
-
},
|
107 |
-
{
|
108 |
-
"cell_type": "code",
|
109 |
-
"source": [
|
110 |
-
"%%capture\n",
|
111 |
-
"%cd /content/rl-algo-impls\n",
|
112 |
-
"!pip install -r colab_requirements.txt"
|
113 |
-
],
|
114 |
-
"metadata": {
|
115 |
-
"id": "AfZh9rH3yQii"
|
116 |
-
},
|
117 |
-
"execution_count": null,
|
118 |
-
"outputs": []
|
119 |
-
},
|
120 |
-
{
|
121 |
-
"cell_type": "markdown",
|
122 |
-
"source": [
|
123 |
-
"## Run Once Per Runtime"
|
124 |
-
],
|
125 |
-
"metadata": {
|
126 |
-
"id": "4o5HOLjc4wq7"
|
127 |
-
}
|
128 |
-
},
|
129 |
-
{
|
130 |
-
"cell_type": "code",
|
131 |
-
"source": [
|
132 |
-
"import wandb\n",
|
133 |
-
"wandb.login()"
|
134 |
-
],
|
135 |
-
"metadata": {
|
136 |
-
"id": "PCXa5tdS2qFX"
|
137 |
-
},
|
138 |
-
"execution_count": null,
|
139 |
-
"outputs": []
|
140 |
-
},
|
141 |
-
{
|
142 |
-
"cell_type": "markdown",
|
143 |
-
"source": [
|
144 |
-
"## Restart Session between runs"
|
145 |
-
],
|
146 |
-
"metadata": {
|
147 |
-
"id": "AZBZfSUV43JQ"
|
148 |
-
}
|
149 |
-
},
|
150 |
-
{
|
151 |
-
"cell_type": "code",
|
152 |
-
"source": [
|
153 |
-
"%%capture\n",
|
154 |
-
"from pyvirtualdisplay import Display\n",
|
155 |
-
"\n",
|
156 |
-
"virtual_display = Display(visible=0, size=(1400, 900))\n",
|
157 |
-
"virtual_display.start()"
|
158 |
-
],
|
159 |
-
"metadata": {
|
160 |
-
"id": "VzemeQJP2NO9"
|
161 |
-
},
|
162 |
-
"execution_count": null,
|
163 |
-
"outputs": []
|
164 |
-
},
|
165 |
-
{
|
166 |
-
"cell_type": "markdown",
|
167 |
-
"source": [
|
168 |
-
"The below 5 bash scripts train agents on environments with 3 seeds each:\n",
|
169 |
-
"- colab_basic.sh and colab_pybullet.sh test on a set of basic gym environments and 4 PyBullet environments. Running both together will likely take about 18 hours. This is likely to run into runtime limits for free Colab and Colab Pro, but is fine for Colab Pro+.\n",
|
170 |
-
"- colab_carracing.sh only trains 3 seeds on CarRacing-v0, which takes almost 22 hours on Colab Pro+ on high-RAM, standard GPU.\n",
|
171 |
-
"- colab_atari1.sh and colab_atari2.sh likely need to be run separately because each takes about 19 hours on high-RAM, standard GPU."
|
172 |
-
],
|
173 |
-
"metadata": {
|
174 |
-
"id": "nSHfna0hLlO1"
|
175 |
-
}
|
176 |
-
},
|
177 |
-
{
|
178 |
-
"cell_type": "code",
|
179 |
-
"source": [
|
180 |
-
"%cd /content/rl-algo-impls\n",
|
181 |
-
"os.environ[\"BENCHMARK_MAX_PROCS\"] = str(1) # Can't reliably raise this to 2+, but would make it faster.\n",
|
182 |
-
"!./benchmarks/colab_basic.sh\n",
|
183 |
-
"!./benchmarks/colab_pybullet.sh\n",
|
184 |
-
"# !./benchmarks/colab_carracing.sh\n",
|
185 |
-
"# !./benchmarks/colab_atari1.sh\n",
|
186 |
-
"# !./benchmarks/colab_atari2.sh"
|
187 |
-
],
|
188 |
-
"metadata": {
|
189 |
-
"id": "07aHYFH1zfXa"
|
190 |
-
},
|
191 |
-
"execution_count": null,
|
192 |
-
"outputs": []
|
193 |
-
}
|
194 |
-
]
|
195 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/colab_carracing.sh
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
source benchmarks/train_loop.sh
|
2 |
-
ALGOS="ppo"
|
3 |
-
ENVS="CarRacing-v0"
|
4 |
-
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
|
5 |
-
train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/colab_pybullet.sh
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
source benchmarks/train_loop.sh
|
2 |
-
ALGOS="ppo"
|
3 |
-
ENVS="HalfCheetahBulletEnv-v0 AntBulletEnv-v0 HopperBulletEnv-v0 Walker2DBulletEnv-v0"
|
4 |
-
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
|
5 |
-
train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
|
|
|
|
|
|
|
|
|
|
|
benchmarks/train_loop.sh
DELETED
@@ -1,15 +0,0 @@
|
|
1 |
-
train_loop () {
|
2 |
-
local WANDB_TAGS="benchmark_$(git rev-parse --short HEAD) host_$(hostname)"
|
3 |
-
local algo
|
4 |
-
local env
|
5 |
-
local seed
|
6 |
-
local WANDB_PROJECT_NAME="${WANDB_PROJECT_NAME:-rl-algo-impls-benchmarks}"
|
7 |
-
local SEEDS="${SEEDS:-1 2 3}"
|
8 |
-
for algo in $(echo $1); do
|
9 |
-
for env in $(echo $2); do
|
10 |
-
for seed in $SEEDS; do
|
11 |
-
echo python train.py --algo $algo --env $env --seed $seed --pool-size 1 --wandb-tags $WANDB_TAGS --wandb-project-name $WANDB_PROJECT_NAME
|
12 |
-
done
|
13 |
-
done
|
14 |
-
done
|
15 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colab_enjoy.ipynb
DELETED
@@ -1,198 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"nbformat": 4,
|
3 |
-
"nbformat_minor": 0,
|
4 |
-
"metadata": {
|
5 |
-
"colab": {
|
6 |
-
"provenance": [],
|
7 |
-
"machine_shape": "hm",
|
8 |
-
"authorship_tag": "ABX9TyN6S7kyJKrM5x0OOiN+CgTc",
|
9 |
-
"include_colab_link": true
|
10 |
-
},
|
11 |
-
"kernelspec": {
|
12 |
-
"name": "python3",
|
13 |
-
"display_name": "Python 3"
|
14 |
-
},
|
15 |
-
"language_info": {
|
16 |
-
"name": "python"
|
17 |
-
},
|
18 |
-
"gpuClass": "standard",
|
19 |
-
"accelerator": "GPU"
|
20 |
-
},
|
21 |
-
"cells": [
|
22 |
-
{
|
23 |
-
"cell_type": "markdown",
|
24 |
-
"metadata": {
|
25 |
-
"id": "view-in-github",
|
26 |
-
"colab_type": "text"
|
27 |
-
},
|
28 |
-
"source": [
|
29 |
-
"<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
30 |
-
]
|
31 |
-
},
|
32 |
-
{
|
33 |
-
"cell_type": "markdown",
|
34 |
-
"source": [
|
35 |
-
"# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
|
36 |
-
"## Parameters\n",
|
37 |
-
"\n",
|
38 |
-
"\n",
|
39 |
-
"1. Wandb\n",
|
40 |
-
"\n"
|
41 |
-
],
|
42 |
-
"metadata": {
|
43 |
-
"id": "S-tXDWP8WTLc"
|
44 |
-
}
|
45 |
-
},
|
46 |
-
{
|
47 |
-
"cell_type": "code",
|
48 |
-
"source": [
|
49 |
-
"from getpass import getpass\n",
|
50 |
-
"import os\n",
|
51 |
-
"os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
|
52 |
-
],
|
53 |
-
"metadata": {
|
54 |
-
"id": "1ZtdYgxWNGwZ"
|
55 |
-
},
|
56 |
-
"execution_count": null,
|
57 |
-
"outputs": []
|
58 |
-
},
|
59 |
-
{
|
60 |
-
"cell_type": "markdown",
|
61 |
-
"source": [
|
62 |
-
"2. enjoy.py parameters"
|
63 |
-
],
|
64 |
-
"metadata": {
|
65 |
-
"id": "ao0nAh3MOdN7"
|
66 |
-
}
|
67 |
-
},
|
68 |
-
{
|
69 |
-
"cell_type": "code",
|
70 |
-
"source": [
|
71 |
-
"WANDB_RUN_PATH=\"sgoodfriend/rl-algo-impls-benchmarks/rd0lisee\""
|
72 |
-
],
|
73 |
-
"metadata": {
|
74 |
-
"id": "jKL_NFhVOjSc"
|
75 |
-
},
|
76 |
-
"execution_count": 2,
|
77 |
-
"outputs": []
|
78 |
-
},
|
79 |
-
{
|
80 |
-
"cell_type": "markdown",
|
81 |
-
"source": [
|
82 |
-
"## Setup\n",
|
83 |
-
"Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
|
84 |
-
],
|
85 |
-
"metadata": {
|
86 |
-
"id": "bsG35Io0hmKG"
|
87 |
-
}
|
88 |
-
},
|
89 |
-
{
|
90 |
-
"cell_type": "code",
|
91 |
-
"source": [
|
92 |
-
"%%capture\n",
|
93 |
-
"!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
|
94 |
-
],
|
95 |
-
"metadata": {
|
96 |
-
"id": "k5ynTV25hdAf"
|
97 |
-
},
|
98 |
-
"execution_count": 3,
|
99 |
-
"outputs": []
|
100 |
-
},
|
101 |
-
{
|
102 |
-
"cell_type": "markdown",
|
103 |
-
"source": [
|
104 |
-
"Installing the correct packages:\n",
|
105 |
-
"\n",
|
106 |
-
"While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes the package yml files difficult to use. For now, I'm instead going to specify the list of requirements manually below:"
|
107 |
-
],
|
108 |
-
"metadata": {
|
109 |
-
"id": "jKxGok-ElYQ7"
|
110 |
-
}
|
111 |
-
},
|
112 |
-
{
|
113 |
-
"cell_type": "code",
|
114 |
-
"source": [
|
115 |
-
"%%capture\n",
|
116 |
-
"!apt install python-opengl\n",
|
117 |
-
"!apt install ffmpeg\n",
|
118 |
-
"!apt install xvfb\n",
|
119 |
-
"!apt install swig"
|
120 |
-
],
|
121 |
-
"metadata": {
|
122 |
-
"id": "nn6EETTc2Ewf"
|
123 |
-
},
|
124 |
-
"execution_count": 4,
|
125 |
-
"outputs": []
|
126 |
-
},
|
127 |
-
{
|
128 |
-
"cell_type": "code",
|
129 |
-
"source": [
|
130 |
-
"%%capture\n",
|
131 |
-
"%cd /content/rl-algo-impls\n",
|
132 |
-
"!pip install -r colab_requirements.txt"
|
133 |
-
],
|
134 |
-
"metadata": {
|
135 |
-
"id": "AfZh9rH3yQii"
|
136 |
-
},
|
137 |
-
"execution_count": 5,
|
138 |
-
"outputs": []
|
139 |
-
},
|
140 |
-
{
|
141 |
-
"cell_type": "markdown",
|
142 |
-
"source": [
|
143 |
-
"## Run Once Per Runtime"
|
144 |
-
],
|
145 |
-
"metadata": {
|
146 |
-
"id": "4o5HOLjc4wq7"
|
147 |
-
}
|
148 |
-
},
|
149 |
-
{
|
150 |
-
"cell_type": "code",
|
151 |
-
"source": [
|
152 |
-
"import wandb\n",
|
153 |
-
"wandb.login()"
|
154 |
-
],
|
155 |
-
"metadata": {
|
156 |
-
"id": "PCXa5tdS2qFX"
|
157 |
-
},
|
158 |
-
"execution_count": null,
|
159 |
-
"outputs": []
|
160 |
-
},
|
161 |
-
{
|
162 |
-
"cell_type": "markdown",
|
163 |
-
"source": [
|
164 |
-
"## Restart Session between runs"
|
165 |
-
],
|
166 |
-
"metadata": {
|
167 |
-
"id": "AZBZfSUV43JQ"
|
168 |
-
}
|
169 |
-
},
|
170 |
-
{
|
171 |
-
"cell_type": "code",
|
172 |
-
"source": [
|
173 |
-
"%%capture\n",
|
174 |
-
"from pyvirtualdisplay import Display\n",
|
175 |
-
"\n",
|
176 |
-
"virtual_display = Display(visible=0, size=(1400, 900))\n",
|
177 |
-
"virtual_display.start()"
|
178 |
-
],
|
179 |
-
"metadata": {
|
180 |
-
"id": "VzemeQJP2NO9"
|
181 |
-
},
|
182 |
-
"execution_count": 7,
|
183 |
-
"outputs": []
|
184 |
-
},
|
185 |
-
{
|
186 |
-
"cell_type": "code",
|
187 |
-
"source": [
|
188 |
-
"%cd /content/rl-algo-impls\n",
|
189 |
-
"!python enjoy.py --wandb-run-path={WANDB_RUN_PATH}"
|
190 |
-
],
|
191 |
-
"metadata": {
|
192 |
-
"id": "07aHYFH1zfXa"
|
193 |
-
},
|
194 |
-
"execution_count": null,
|
195 |
-
"outputs": []
|
196 |
-
}
|
197 |
-
]
|
198 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colab_requirements.txt
DELETED
@@ -1,14 +0,0 @@
|
|
1 |
-
AutoROM.accept-rom-license >= 0.4.2, < 0.5
|
2 |
-
stable-baselines3[extra] >= 1.7.0, < 1.8
|
3 |
-
gym[box2d] >= 0.21.0, < 0.22
|
4 |
-
pyglet == 1.5.27
|
5 |
-
wandb >= 0.13.10, < 0.14
|
6 |
-
pyvirtualdisplay == 3.0
|
7 |
-
pybullet >= 3.2.5, < 3.3
|
8 |
-
tabulate >= 0.9.0, < 0.10
|
9 |
-
huggingface-hub >= 0.12.0, < 0.13
|
10 |
-
numexpr >= 2.8.4, < 2.9
|
11 |
-
gym3 >= 0.3.3, < 0.4
|
12 |
-
glfw >= 1.12.0, < 1.13
|
13 |
-
procgen >= 0.10.7, < 0.11
|
14 |
-
ipython >= 8.10.0, < 8.11
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
colab_train.ipynb
DELETED
@@ -1,200 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"nbformat": 4,
|
3 |
-
"nbformat_minor": 0,
|
4 |
-
"metadata": {
|
5 |
-
"colab": {
|
6 |
-
"provenance": [],
|
7 |
-
"machine_shape": "hm",
|
8 |
-
"authorship_tag": "ABX9TyMmemQnx6G7GOnn6XBdjgxY",
|
9 |
-
"include_colab_link": true
|
10 |
-
},
|
11 |
-
"kernelspec": {
|
12 |
-
"name": "python3",
|
13 |
-
"display_name": "Python 3"
|
14 |
-
},
|
15 |
-
"language_info": {
|
16 |
-
"name": "python"
|
17 |
-
},
|
18 |
-
"gpuClass": "standard",
|
19 |
-
"accelerator": "GPU"
|
20 |
-
},
|
21 |
-
"cells": [
|
22 |
-
{
|
23 |
-
"cell_type": "markdown",
|
24 |
-
"metadata": {
|
25 |
-
"id": "view-in-github",
|
26 |
-
"colab_type": "text"
|
27 |
-
},
|
28 |
-
"source": [
|
29 |
-
"<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
30 |
-
]
|
31 |
-
},
|
32 |
-
{
|
33 |
-
"cell_type": "markdown",
|
34 |
-
"source": [
|
35 |
-
"# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
|
36 |
-
"## Parameters\n",
|
37 |
-
"\n",
|
38 |
-
"\n",
|
39 |
-
"1. Wandb\n",
|
40 |
-
"\n"
|
41 |
-
],
|
42 |
-
"metadata": {
|
43 |
-
"id": "S-tXDWP8WTLc"
|
44 |
-
}
|
45 |
-
},
|
46 |
-
{
|
47 |
-
"cell_type": "code",
|
48 |
-
"source": [
|
49 |
-
"from getpass import getpass\n",
|
50 |
-
"import os\n",
|
51 |
-
"os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
|
52 |
-
],
|
53 |
-
"metadata": {
|
54 |
-
"id": "1ZtdYgxWNGwZ"
|
55 |
-
},
|
56 |
-
"execution_count": null,
|
57 |
-
"outputs": []
|
58 |
-
},
|
59 |
-
{
|
60 |
-
"cell_type": "markdown",
|
61 |
-
"source": [
|
62 |
-
"2. train run parameters"
|
63 |
-
],
|
64 |
-
"metadata": {
|
65 |
-
"id": "ao0nAh3MOdN7"
|
66 |
-
}
|
67 |
-
},
|
68 |
-
{
|
69 |
-
"cell_type": "code",
|
70 |
-
"source": [
|
71 |
-
"ALGO = \"ppo\"\n",
|
72 |
-
"ENV = \"CartPole-v1\"\n",
|
73 |
-
"SEED = 1"
|
74 |
-
],
|
75 |
-
"metadata": {
|
76 |
-
"id": "jKL_NFhVOjSc"
|
77 |
-
},
|
78 |
-
"execution_count": null,
|
79 |
-
"outputs": []
|
80 |
-
},
|
81 |
-
{
|
82 |
-
"cell_type": "markdown",
|
83 |
-
"source": [
|
84 |
-
"## Setup\n",
|
85 |
-
"Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
|
86 |
-
],
|
87 |
-
"metadata": {
|
88 |
-
"id": "bsG35Io0hmKG"
|
89 |
-
}
|
90 |
-
},
|
91 |
-
{
|
92 |
-
"cell_type": "code",
|
93 |
-
"source": [
|
94 |
-
"%%capture\n",
|
95 |
-
"!git clone https://github.com/sgoodfriend/rl-algo-impls.git"
|
96 |
-
],
|
97 |
-
"metadata": {
|
98 |
-
"id": "k5ynTV25hdAf"
|
99 |
-
},
|
100 |
-
"execution_count": null,
|
101 |
-
"outputs": []
|
102 |
-
},
|
103 |
-
{
|
104 |
-
"cell_type": "markdown",
|
105 |
-
"source": [
|
106 |
-
"Installing the correct packages:\n",
|
107 |
-
"\n",
|
108 |
-
"While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes the package yml files difficult to use. For now, I'm instead going to specify the list of requirements manually below:"
|
109 |
-
],
|
110 |
-
"metadata": {
|
111 |
-
"id": "jKxGok-ElYQ7"
|
112 |
-
}
|
113 |
-
},
|
114 |
-
{
|
115 |
-
"cell_type": "code",
|
116 |
-
"source": [
|
117 |
-
"%%capture\n",
|
118 |
-
"!apt install python-opengl\n",
|
119 |
-
"!apt install ffmpeg\n",
|
120 |
-
"!apt install xvfb\n",
|
121 |
-
"!apt install swig"
|
122 |
-
],
|
123 |
-
"metadata": {
|
124 |
-
"id": "nn6EETTc2Ewf"
|
125 |
-
},
|
126 |
-
"execution_count": null,
|
127 |
-
"outputs": []
|
128 |
-
},
|
129 |
-
{
|
130 |
-
"cell_type": "code",
|
131 |
-
"source": [
|
132 |
-
"%%capture\n",
|
133 |
-
"%cd /content/rl-algo-impls\n",
|
134 |
-
"!pip install -r colab_requirements.txt"
|
135 |
-
],
|
136 |
-
"metadata": {
|
137 |
-
"id": "AfZh9rH3yQii"
|
138 |
-
},
|
139 |
-
"execution_count": null,
|
140 |
-
"outputs": []
|
141 |
-
},
|
142 |
-
{
|
143 |
-
"cell_type": "markdown",
|
144 |
-
"source": [
|
145 |
-
"## Run Once Per Runtime"
|
146 |
-
],
|
147 |
-
"metadata": {
|
148 |
-
"id": "4o5HOLjc4wq7"
|
149 |
-
}
|
150 |
-
},
|
151 |
-
{
|
152 |
-
"cell_type": "code",
|
153 |
-
"source": [
|
154 |
-
"import wandb\n",
|
155 |
-
"wandb.login()"
|
156 |
-
],
|
157 |
-
"metadata": {
|
158 |
-
"id": "PCXa5tdS2qFX"
|
159 |
-
},
|
160 |
-
"execution_count": null,
|
161 |
-
"outputs": []
|
162 |
-
},
|
163 |
-
{
|
164 |
-
"cell_type": "markdown",
|
165 |
-
"source": [
|
166 |
-
"## Restart Session between runs"
|
167 |
-
],
|
168 |
-
"metadata": {
|
169 |
-
"id": "AZBZfSUV43JQ"
|
170 |
-
}
|
171 |
-
},
|
172 |
-
{
|
173 |
-
"cell_type": "code",
|
174 |
-
"source": [
|
175 |
-
"%%capture\n",
|
176 |
-
"from pyvirtualdisplay import Display\n",
|
177 |
-
"\n",
|
178 |
-
"virtual_display = Display(visible=0, size=(1400, 900))\n",
|
179 |
-
"virtual_display.start()"
|
180 |
-
],
|
181 |
-
"metadata": {
|
182 |
-
"id": "VzemeQJP2NO9"
|
183 |
-
},
|
184 |
-
"execution_count": null,
|
185 |
-
"outputs": []
|
186 |
-
},
|
187 |
-
{
|
188 |
-
"cell_type": "code",
|
189 |
-
"source": [
|
190 |
-
"%cd /content/rl-algo-impls\n",
|
191 |
-
"!python train.py --algo {ALGO} --env {ENV} --seed {SEED}"
|
192 |
-
],
|
193 |
-
"metadata": {
|
194 |
-
"id": "07aHYFH1zfXa"
|
195 |
-
},
|
196 |
-
"execution_count": null,
|
197 |
-
"outputs": []
|
198 |
-
}
|
199 |
-
]
|
200 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config.json
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
{"policy_class": {":type:": "<class 'abc.ABCMeta'>", ":serialized:": "gAWVOwAAAAAAAACMIXN0YWJsZV9iYXNlbGluZXMzLmNvbW1vbi5wb2xpY2llc5SMEUFjdG9yQ3JpdGljUG9saWN5lJOULg==", "__module__": "stable_baselines3.common.policies", "__doc__": "\n Policy class for actor-critic algorithms (has both policy and value prediction).\n Used by A2C, PPO and the likes.\n\n :param observation_space: Observation space\n :param action_space: Action space\n :param lr_schedule: Learning rate schedule (could be constant)\n :param net_arch: The specification of the policy and value networks.\n :param activation_fn: Activation function\n :param ortho_init: Whether to use or not orthogonal initialization\n :param use_sde: Whether to use State Dependent Exploration or not\n :param log_std_init: Initial value for the log standard deviation\n :param full_std: Whether to use (n_features x n_actions) parameters\n for the std instead of only (n_features,) when using gSDE\n :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure\n a positive standard deviation (cf paper). It allows to keep variance\n above zero and prevent it from growing too fast. 
In practice, ``exp()`` is usually enough.\n :param squash_output: Whether to squash the output using a tanh function,\n this allows to ensure boundaries when using gSDE.\n :param features_extractor_class: Features extractor to use.\n :param features_extractor_kwargs: Keyword arguments\n to pass to the features extractor.\n :param share_features_extractor: If True, the features extractor is shared between the policy and value networks.\n :param normalize_images: Whether to normalize images or not,\n dividing by 255.0 (True by default)\n :param optimizer_class: The optimizer to use,\n ``th.optim.Adam`` by default\n :param optimizer_kwargs: Additional keyword arguments,\n excluding the learning rate, to pass to the optimizer\n ", "__init__": "<function ActorCriticPolicy.__init__ at 0x7ff06ed7c3a0>", "_get_constructor_parameters": "<function ActorCriticPolicy._get_constructor_parameters at 0x7ff06ed7c430>", "reset_noise": "<function ActorCriticPolicy.reset_noise at 0x7ff06ed7c4c0>", "_build_mlp_extractor": "<function ActorCriticPolicy._build_mlp_extractor at 0x7ff06ed7c550>", "_build": "<function ActorCriticPolicy._build at 0x7ff06ed7c5e0>", "forward": "<function ActorCriticPolicy.forward at 0x7ff06ed7c670>", "extract_features": "<function ActorCriticPolicy.extract_features at 0x7ff06ed7c700>", "_get_action_dist_from_latent": "<function ActorCriticPolicy._get_action_dist_from_latent at 0x7ff06ed7c790>", "_predict": "<function ActorCriticPolicy._predict at 0x7ff06ed7c820>", "evaluate_actions": "<function ActorCriticPolicy.evaluate_actions at 0x7ff06ed7c8b0>", "get_distribution": "<function ActorCriticPolicy.get_distribution at 0x7ff06ed7c940>", "predict_values": "<function ActorCriticPolicy.predict_values at 0x7ff06ed7c9d0>", "__abstractmethods__": "frozenset()", "_abc_impl": "<_abc_data object at 0x7ff06ed70f60>"}, "verbose": 1, "policy_kwargs": {":type:": "<class 'dict'>", ":serialized:": 
"gAWVowAAAAAAAAB9lCiMDGxvZ19zdGRfaW5pdJRK/v///4wKb3J0aG9faW5pdJSJjA9vcHRpbWl6ZXJfY2xhc3OUjBN0b3JjaC5vcHRpbS5ybXNwcm9wlIwHUk1TcHJvcJSTlIwQb3B0aW1pemVyX2t3YXJnc5R9lCiMBWFscGhhlEc/764UeuFHrowDZXBzlEc+5Pi1iONo8YwMd2VpZ2h0X2RlY2F5lEsAdXUu", "log_std_init": -2, "ortho_init": false, "optimizer_class": "<class 'torch.optim.rmsprop.RMSprop'>", "optimizer_kwargs": {"alpha": 0.99, "eps": 1e-05, "weight_decay": 0}}, "observation_space": {":type:": "<class 'gym.spaces.box.Box'>", ":serialized:": "gAWVZwIAAAAAAACMDmd5bS5zcGFjZXMuYm94lIwDQm94lJOUKYGUfZQojAVkdHlwZZSMBW51bXB5lGgFk5SMAmY0lImIh5RSlChLA4wBPJROTk5K/////0r/////SwB0lGKMBl9zaGFwZZRLHIWUjANsb3eUjBJudW1weS5jb3JlLm51bWVyaWOUjAtfZnJvbWJ1ZmZlcpSTlCiWcAAAAAAAAAAAAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/AACA/wAAgP8AAID/lGgKSxyFlIwBQ5R0lFKUjARoaWdolGgSKJZwAAAAAAAAAAAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH8AAIB/AACAfwAAgH+UaApLHIWUaBV0lFKUjA1ib3VuZGVkX2JlbG93lGgSKJYcAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACUaAeMAmIxlImIh5RSlChLA4wBfJROTk5K/////0r/////SwB0lGJLHIWUaBV0lFKUjA1ib3VuZGVkX2Fib3ZllGgSKJYcAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAACUaCFLHIWUaBV0lFKUjApfbnBfcmFuZG9tlE51Yi4=", "dtype": "float32", "_shape": [28], "low": "[-inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf\n -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf -inf]", "high": "[inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf inf\n inf inf inf inf inf inf inf inf inf inf]", "bounded_below": "[False False False False False False False False False False False False\n False False False False False False False False False False False False\n False False False False]", "bounded_above": "[False False False False False False False False False False False False\n False False False False False False False False False False False 
False\n False False False False]", "_np_random": null}, "action_space": {":type:": "<class 'gym.spaces.box.Box'>", ":serialized:": "gAWVnwEAAAAAAACMDmd5bS5zcGFjZXMuYm94lIwDQm94lJOUKYGUfZQojAVkdHlwZZSMBW51bXB5lGgFk5SMAmY0lImIh5RSlChLA4wBPJROTk5K/////0r/////SwB0lGKMBl9zaGFwZZRLCIWUjANsb3eUjBJudW1weS5jb3JlLm51bWVyaWOUjAtfZnJvbWJ1ZmZlcpSTlCiWIAAAAAAAAAAAAIC/AACAvwAAgL8AAIC/AACAvwAAgL8AAIC/AACAv5RoCksIhZSMAUOUdJRSlIwEaGlnaJRoEiiWIAAAAAAAAAAAAIA/AACAPwAAgD8AAIA/AACAPwAAgD8AAIA/AACAP5RoCksIhZRoFXSUUpSMDWJvdW5kZWRfYmVsb3eUaBIolggAAAAAAAAAAQEBAQEBAQGUaAeMAmIxlImIh5RSlChLA4wBfJROTk5K/////0r/////SwB0lGJLCIWUaBV0lFKUjA1ib3VuZGVkX2Fib3ZllGgSKJYIAAAAAAAAAAEBAQEBAQEBlGghSwiFlGgVdJRSlIwKX25wX3JhbmRvbZROdWIu", "dtype": "float32", "_shape": [8], "low": "[-1. -1. -1. -1. -1. -1. -1. -1.]", "high": "[1. 1. 1. 1. 1. 1. 1. 1.]", "bounded_below": "[ True True True True True True True True]", "bounded_above": "[ True True True True True True True True]", "_np_random": null}, "n_envs": 4, "num_timesteps": 2000000, "_total_timesteps": 2000000, "_num_timesteps_at_start": 0, "seed": null, "action_noise": null, "start_time": 1674588743994659442, "learning_rate": {":type:": "<class 'function'>", ":serialized:": 
"gAWVdgIAAAAAAACMF2Nsb3VkcGlja2xlLmNsb3VkcGlja2xllIwOX21ha2VfZnVuY3Rpb26Uk5QoaACMDV9idWlsdGluX3R5cGWUk5SMCENvZGVUeXBllIWUUpQoSwFLAEsASwFLAksTQwiIAHwAFABTAJROhZQpjBJwcm9ncmVzc19yZW1haW5pbmeUhZSMHzxpcHl0aG9uLWlucHV0LTEzLWVhYTdkOGY5N2ZkNj6UjAhzY2hlZHVsZZRLBEMCAAGUjA1pbml0aWFsX3ZhbHVllIWUKXSUUpR9lCiMC19fcGFja2FnZV9flE6MCF9fbmFtZV9flIwIX19tYWluX1+UdU5OaACMEF9tYWtlX2VtcHR5X2NlbGyUk5QpUpSFlHSUUpSMHGNsb3VkcGlja2xlLmNsb3VkcGlja2xlX2Zhc3SUjBJfZnVuY3Rpb25fc2V0c3RhdGWUk5RoHH2UfZQoaBVoDYwMX19xdWFsbmFtZV9flIwhbGluZWFyX3NjaGVkdWxlLjxsb2NhbHM+LnNjaGVkdWxllIwPX19hbm5vdGF0aW9uc19flH2UKIwScHJvZ3Jlc3NfcmVtYWluaW5nlIwIYnVpbHRpbnOUjAVmbG9hdJSTlIwGcmV0dXJulGgpdYwOX19rd2RlZmF1bHRzX1+UTowMX19kZWZhdWx0c19flE6MCl9fbW9kdWxlX1+UaBaMB19fZG9jX1+UTowLX19jbG9zdXJlX1+UaACMCl9tYWtlX2NlbGyUk5RHP091EE1VHWmFlFKUhZSMF19jbG91ZHBpY2tsZV9zdWJtb2R1bGVzlF2UjAtfX2dsb2JhbHNfX5R9lHWGlIZSMC4="}, "tensorboard_log": null, "lr_schedule": {":type:": "<class 'function'>", ":serialized:": "gAWVdgIAAAAAAACMF2Nsb3VkcGlja2xlLmNsb3VkcGlja2xllIwOX21ha2VfZnVuY3Rpb26Uk5QoaACMDV9idWlsdGluX3R5cGWUk5SMCENvZGVUeXBllIWUUpQoSwFLAEsASwFLAksTQwiIAHwAFABTAJROhZQpjBJwcm9ncmVzc19yZW1haW5pbmeUhZSMHzxpcHl0aG9uLWlucHV0LTEzLWVhYTdkOGY5N2ZkNj6UjAhzY2hlZHVsZZRLBEMCAAGUjA1pbml0aWFsX3ZhbHVllIWUKXSUUpR9lCiMC19fcGFja2FnZV9flE6MCF9fbmFtZV9flIwIX19tYWluX1+UdU5OaACMEF9tYWtlX2VtcHR5X2NlbGyUk5QpUpSFlHSUUpSMHGNsb3VkcGlja2xlLmNsb3VkcGlja2xlX2Zhc3SUjBJfZnVuY3Rpb25fc2V0c3RhdGWUk5RoHH2UfZQoaBVoDYwMX19xdWFsbmFtZV9flIwhbGluZWFyX3NjaGVkdWxlLjxsb2NhbHM+LnNjaGVkdWxllIwPX19hbm5vdGF0aW9uc19flH2UKIwScHJvZ3Jlc3NfcmVtYWluaW5nlIwIYnVpbHRpbnOUjAVmbG9hdJSTlIwGcmV0dXJulGgpdYwOX19rd2RlZmF1bHRzX1+UTowMX19kZWZhdWx0c19flE6MCl9fbW9kdWxlX1+UaBaMB19fZG9jX1+UTowLX19jbG9zdXJlX1+UaACMCl9tYWtlX2NlbGyUk5RHP091EE1VHWmFlFKUhZSMF19jbG91ZHBpY2tsZV9zdWJtb2R1bGVzlF2UjAtfX2dsb2JhbHNfX5R9lHWGlIZSMC4="}, "_last_obs": {":type:": "<class 'numpy.ndarray'>", ":serialized:": 
"gAWVNQIAAAAAAACMEm51bXB5LmNvcmUubnVtZXJpY5SMC19mcm9tYnVmZmVylJOUKJbAAQAAAAAAANkXHT/P+Bu/1A+1PjIvpT89Bh+/zjqjPhDbqT4Dqha/AjV/P+zYsr+uKGE/dj6Xv6Cxv7+R4wc8NxoKv6piD8CimWe/+TfIPi55aD9Zeju9gErvPhI+3L5IS6C/aYMmP8k7cr8P4ve/8OcFwEF3ab/IB8Q+iiKSPjT47T5cWJE/BK7Rvr5ZgD+8Kak900civ+X6jr4MOYU9Y5Shv5Jbtz6m2U++w9apP5QSCT7zfbk/P/O2PzWg8Lp/kpq+P/AOvwa3jb+DbFA/cyMdP3pKXj/JO3K//jAEP/DnBcDGWow/sCkSP3Xwnj5GF+s+0PqgP1LPhL70tAs/D5JdPdH8PDwVeaM+7Rabv2R/mz5gexrADzf+vtnLpj8/PMu+cPeXP8wagb4/bOU/gJVZPfvU8r+XP0C/uvbjvuIAMT9kUDM/yTtyv/4wBD9+tfQ+xlqMP1d1jT2Zgbe+1bHkPva9wj/bEYW/mrFZP9iSDD9GL/G+l7yMP/DjTz9xM7U/HN+OPqTRRr94K0XAr08LO38geb/41IG/tzC3v9ZobD/na868JuUBvnf+tr+WGXi/NfUov0BGhz8P4ve/frX0PkF3ab+UjAVudW1weZSMBWR0eXBllJOUjAJmNJSJiIeUUpQoSwOMATyUTk5OSv////9K/////0sAdJRiSwRLHIaUjAFDlHSUUpQu"}, "_last_episode_starts": {":type:": "<class 'numpy.ndarray'>", ":serialized:": "gAWVdwAAAAAAAACMEm51bXB5LmNvcmUubnVtZXJpY5SMC19mcm9tYnVmZmVylJOUKJYEAAAAAAAAAAAAAACUjAVudW1weZSMBWR0eXBllJOUjAJiMZSJiIeUUpQoSwOMAXyUTk5OSv////9K/////0sAdJRiSwSFlIwBQ5R0lFKULg=="}, "_last_original_obs": {":type:": "<class 'numpy.ndarray'>", ":serialized:": "gAWVNQIAAAAAAACMEm51bXB5LmNvcmUubnVtZXJpY5SMC19mcm9tYnVmZmVylJOUKJbAAQAAAAAAAAAAAAAtoGm1AACAPwAAAAAAAAAAAAAAAAAAAAAAAACA9o2EvQAAAACgRPm/AAAAAMwsiT0AAAAAyVXjPwAAAABeHa69AAAAAFdF5z8AAAAAAcmpPQAAAAApeea/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAHEHLNAAAgD8AAAAAAAAAAAAAAAAAAAAAAAAAgEwf2b0AAAAA2Gz1vwAAAADLltm6AAAAAM9Z7D8AAAAA3yVmvQAAAADn69s/AAAAADiBBz0AAAAAw1rrvwAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAOBNMrYAAIA/AAAAAAAAAAAAAAAAAAAAAAAAAIAyWb08AAAAAEuk/L8AAAAAL1SzvQAAAAByavk/AAAAAP8V470AAAAA/m/iPwAAAAAtweW9AAAAAExT9L8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAADP6ZA0AACAPwAAAAAAAAAAAAAAAAAAAAAAAACAfKcNPgAAAAAF3eK/AAAAACbLMT0AAAAARnIAQAAAAAArb0K7AAAAAMWz/z8AAAAANVvXvQAAAAAaNe2/AAAAAAAAAAAAAAAAAAAAAAAAAACUjAVudW1weZSMBWR0eXBllJOUjAJmNJSJiIeUUpQoSwOMATyUTk5OSv////9K/////0sAdJRiSwRLHIaUjAFDlHSUUpQu"}, "_episode_num": 0, "use_sde": true, "sde_sample_freq": -1, "_current_progress_remaining": 0.0, "ep_info_buffer": 
{":type:": "<class 'collections.deque'>", ":serialized:": "gAWVRAwAAAAAAACMC2NvbGxlY3Rpb25zlIwFZGVxdWWUk5QpS2SGlFKUKH2UKIwBcpRHQKaK8jRlYlqMAWyUTegDjAF0lEdArJafl8w6AHV9lChoBkdAproSsySFG2gHTegDaAhHQKyeau01IiF1fZQoaAZHQKW/SZpi7TVoB03oA2gIR0Csnr6jnFHbdX2UKGgGR0ClbyI/7iyZaAdN6ANoCEdArKJTlmvnsHV9lChoBkdApXSPAh0QsmgHTegDaAhHQKyiYGeMAFR1fZQoaAZHQKYeObMHKOloB03oA2gIR0Csqia/qPfbdX2UKGgGR0Cl25fAbhm5aAdN6ANoCEdArKp0rAgxJ3V9lChoBkdApGo8rmQr+mgHTegDaAhHQKyt3NGmUGF1fZQoaAZHQKW26Hk92X9oB03oA2gIR0CsreV76YVqdX2UKGgGR0Clpgs1KoQ4aAdN6ANoCEdArLXvPw/gSHV9lChoBkdApZRJfOUt7WgHTegDaAhHQKy2Q7jDKo11fZQoaAZHQKWQJ2alUIdoB03oA2gIR0Csubl7Uoa2dX2UKGgGR0CmLI2nsLOSaAdN6ANoCEdArLnCPU8V6HV9lChoBkdApVqDB9Cu2mgHTegDaAhHQKzBwo1k1/F1fZQoaAZHQKZuXCqp97ZoB03oA2gIR0CswhCEQGwBdX2UKGgGR0Cmlu0j9n9OaAdN6ANoCEdArMWdjbzshXV9lChoBkdApnp/RgJC0GgHTegDaAhHQKzFpoL5RCR1fZQoaAZHQKR8O5J9RaZoB03oA2gIR0CszVechC+ldX2UKGgGR0CltMF41P30aAdN6ANoCEdArM2rY/Vy3nV9lChoBkdApc4PphWo32gHTegDaAhHQKzRWe/5+H91fZQoaAZHQKZrGVUMoc9oB03oA2gIR0Cs0WMKTjebdX2UKGgGR0CmgCTvy9VWaAdN6ANoCEdArNj4dGRV63V9lChoBkdApfwql3yI6GgHTegDaAhHQKzZS31BdD91fZQoaAZHQKZoYRq46OpoB03oA2gIR0Cs3Nb8m8dxdX2UKGgGR0ClFq4sVclgaAdN6ANoCEdArNzf2PDHfnV9lChoBkdApdattVJcxGgHTegDaAhHQKzkmxptaZB1fZQoaAZHQKWUFjVhCt1oB03oA2gIR0Cs5O8LjPv8dX2UKGgGR0CmL6/N7jT8aAdN6ANoCEdArOhYIKMNt3V9lChoBkdApdR+nQ6ZIGgHTegDaAhHQKzoYBRQ7911fZQoaAZHQKYjF24d6s1oB03oA2gIR0Cs8AV6/qPfdX2UKGgGR0ClmIZ9Vmz0aAdN6ANoCEdArPBaKpDNQnV9lChoBkdApk9kpiI+GGgHTegDaAhHQKz0EkN4JNV1fZQoaAZHQKVyiL2HtWxoB03oA2gIR0Cs9Bv5pJwsdX2UKGgGR0CmUOSncclxaAdN6ANoCEdArPvNdVvMr3V9lChoBkdAprML/VAiV2gHTegDaAhHQKz8HtdiUgV1fZQoaAZHQKWHJzcRDkVoB03oA2gIR0Cs/4O6NEPUdX2UKGgGR0Clvp+QlruZaAdN6ANoCEdArP+Lvd/KAHV9lChoBkdAptM4bfgrH2gHTegDaAhHQK0HKjbBXS11fZQoaAZHQKTNTr/sE7poB03oA2gIR0CtB3/xtpEhdX2UKGgGR0ClpbtFjNILaAdN6ANoCEdArQsPj+717XV9lChoBkdApN2DeZXuE2gHTegDaAhHQK0LGJl8PWh1fZQoaAZHQKOCkuKXOW1oB03oA2gIR0CtEulT3qRmdX2UKGgGR0Cl+HYb83uNaAdN6ANoCEdArRM6aPS2IHV9lChoBkdApPy5dKNADGgHTegDaAhHQK0Wr5WzWwx1fZQoaAZHQKZWr4dIXj5oB03oA2gIR0CtFrfCAMDwdX2UKGgGR0CmKUQW3
z+WaAdN6ANoCEdArR6OpQ1rI3V9lChoBkdAphbtDtw71mgHTegDaAhHQK0e3Ktga3t1fZQoaAZHQKW7YPTXrdFoB03oA2gIR0CtIkclolD4dX2UKGgGR0Cl5MDps41haAdN6ANoCEdArSJQ/RmbsnV9lChoBkdApiBvW4EwFmgHTegDaAhHQK0p+HEdeY51fZQoaAZHQKQNrjc2zfJoB03oA2gIR0CtKkj28IzFdX2UKGgGR0Cckdajvd/KaAdNggNoCEdArSyY1pCa7XV9lChoBkdApomsPvrnkmgHTegDaAhHQK0tzR4yGi51fZQoaAZHQKZk+Fyq+8JoB03oA2gIR0CtNXu1OTJRdX2UKGgGR0CllYyxqwhXaAdN6ANoCEdArTXIMx46fnV9lChoBkdAphfyzRhMJ2gHTegDaAhHQK04BurIYFd1fZQoaAZHQKSM9KV6eGxoB03oA2gIR0CtOUYlpoK2dX2UKGgGR0CmiWAJTl1baAdN6ANoCEdArUEHCj1wpHV9lChoBkdApgHe0gKWs2gHTegDaAhHQK1BYPxQSBd1fZQoaAZHQKcAthnanJloB03oA2gIR0CtQ791U2k0dX2UKGgGR0CiDy96Tnq3aAdN6ANoCEdArUTnRPXTVnV9lChoBkdApFqvcSGrS2gHTegDaAhHQK1MZh/Aj6h1fZQoaAZHQKXAlP/rB0poB03oA2gIR0CtTLSQHRkVdX2UKGgGR0CjmJprDZUUaAdN6ANoCEdArU7xddE9dXV9lChoBkdAptUnsZ5zHWgHTegDaAhHQK1QI8VYZEV1fZQoaAZHQKLkzDlYEGJoB03oA2gIR0CtWCQob4rSdX2UKGgGR0CeM9MS9M9KaAdN6ANoCEdArVh42GZeA3V9lChoBkdApeqzwMH8j2gHTegDaAhHQK1a0TSsr/d1fZQoaAZHQKYbd2wmmchoB03oA2gIR0CtXAZeJHiFdX2UKGgGR0Cl3GUtRNypaAdN6ANoCEdArWO18qnWKHV9lChoBkdApWpvz+WGAWgHTegDaAhHQK1kBFn7Hhl1fZQoaAZHQKZSKHBUJfJoB03oA2gIR0CtZlCuMdcTdX2UKGgGR0CmpdCiAUcoaAdN6ANoCEdArWeCZnctXnV9lChoBkdAppM5Q53kgmgHTegDaAhHQK1vG8h9srN1fZQoaAZHQKYExGDtgKFoB03oA2gIR0Ctb22CVbA2dX2UKGgGR0CmtFdFWn0kaAdN6ANoCEdArXG/VurIYHV9lChoBkdApfNvitJWemgHTegDaAhHQK1zFY5DJEJ1fZQoaAZHQKXdBBSk0rNoB03oA2gIR0CteqxdY4hmdX2UKGgGR0CmhG8kdFOPaAdN6ANoCEdArXr6QvHtGHV9lChoBkdApU/SYPXkHWgHTegDaAhHQK19RnVXmvJ1fZQoaAZHQKWUeXk5p8FoB03oA2gIR0Ctfnhun/DMdX2UKGgGR0CkjzIt16mgaAdN6ANoCEdArYYPdCVrynV9lChoBkdApPtMZ75VO2gHTegDaAhHQK2GYNAkcCJ1fZQoaAZHQKZMGafjCHhoB03oA2gIR0CtiLaOYIBzdX2UKGgGR0Cl0eULMLWqaAdN6ANoCEdArYnyg7HQyHV9lChoBkdApd5MyN4qw2gHTegDaAhHQK2Rs0iQkop1fZQoaAZHQKVcAo5xR2toB03oA2gIR0CtkgYjB2wFdX2UKGgGR0ClzucPFvQ4aAdN6ANoCEdArZRV0NjLCHV9lChoBkdApVq6v7m+02gHTegDaAhHQK2VmQjD8+B1fZQoaAZHQKSTCldC3PRoB03oA2gIR0CtnXYwqRU4dX2UKGgGR0Cl5iNfw7T2aAdN6ANoCEdArZ3F+iJwbXV9lChoBkdApdJXDYRNAWgHTegDaAhHQK2gGQlKK511fZQoaAZHQKZjJvMKTjhoB03oA2gIR0CtoWfvWpZPdX2UKGgGR0Cllc9fsu3+aAdN6ANoCEdAr
alP+bVjJHV9lChoBkdApTLXMSsbN2gHTegDaAhHQK2pqDU3GXJ1fZQoaAZHQKYQnPVNHpdoB03oA2gIR0Ctq/a9kBjndX2UKGgGR0Cmo4jRlYlqaAdN6ANoCEdAra09Tgl4T3V9lChoBkdApNBz6JqIrWgHTegDaAhHQK2051jAi3Z1fZQoaAZHQKXI+y5Zr59oB03oA2gIR0CttTxWkrPMdX2UKGgGR0CharHoPkJbaAdNVgNoCEdArbXjyUcGT3VlLg=="}, "ep_success_buffer": {":type:": "<class 'collections.deque'>", ":serialized:": "gAWVIAAAAAAAAACMC2NvbGxlY3Rpb25zlIwFZGVxdWWUk5QpS2SGlFKULg=="}, "_n_updates": 62500, "n_steps": 8, "gamma": 0.99, "gae_lambda": 0.9, "ent_coef": 0.0, "vf_coef": 0.4, "max_grad_norm": 0.5, "normalize_advantage": false, "system_info": {"OS": "Linux-5.10.147+-x86_64-with-glibc2.29 # 1 SMP Sat Dec 10 16:00:40 UTC 2022", "Python": "3.8.10", "Stable-Baselines3": "1.7.0", "PyTorch": "1.13.1+cu116", "GPU Enabled": "True", "Numpy": "1.21.6", "Gym": "0.21.0"}}
|
|
|
|
dqn/dqn.py
DELETED
@@ -1,182 +0,0 @@
|
|
1 |
-
import copy
|
2 |
-
import numpy as np
|
3 |
-
import random
|
4 |
-
import torch
|
5 |
-
import torch.nn as nn
|
6 |
-
import torch.nn.functional as F
|
7 |
-
|
8 |
-
from collections import deque
|
9 |
-
from torch.optim import Adam
|
10 |
-
from torch.utils.tensorboard.writer import SummaryWriter
|
11 |
-
from typing import List, NamedTuple, Optional, TypeVar
|
12 |
-
|
13 |
-
from dqn.policy import DQNPolicy
|
14 |
-
from shared.algorithm import Algorithm
|
15 |
-
from shared.callbacks.callback import Callback
|
16 |
-
from shared.schedule import linear_schedule
|
17 |
-
from wrappers.vectorable_wrapper import VecEnv, VecEnvObs
|
18 |
-
|
19 |
-
|
20 |
-
class Transition(NamedTuple):
    """One environment step for a single environment, as stored in the replay buffer."""

    # Observation the action was chosen from.
    obs: np.ndarray
    # Action that was taken.
    action: np.ndarray
    # Reward received for the step.
    reward: float
    # Whether the step ended the episode.
    done: bool
    # Observation returned by the environment after the step.
    next_obs: np.ndarray
|
26 |
-
|
27 |
-
|
28 |
-
class Batch(NamedTuple):
    """A stack of sampled transitions, one array per transition field.

    Field order is significant: consumers in this file unpack a ``Batch``
    positionally as ``(obs, actions, rewards, dones, next_obs)``.
    """

    # Stacked observations.
    obs: np.ndarray
    # Stacked actions.
    actions: np.ndarray
    # Stacked rewards.
    rewards: np.ndarray
    # Stacked episode-termination flags.
    dones: np.ndarray
    # Stacked successor observations.
    next_obs: np.ndarray
|
34 |
-
|
35 |
-
|
36 |
-
class ReplayBuffer:
    """Bounded FIFO store of ``Transition``s with uniform random sampling."""

    def __init__(self, num_envs: int, maxlen: int) -> None:
        """Create an empty buffer.

        Args:
            num_envs: Number of leading-axis entries split out of every
                ``add`` call (one ``Transition`` per entry).
            maxlen: Maximum number of transitions retained; once full, the
                oldest transitions are discarded as new ones are appended.
        """
        self.num_envs = num_envs
        self.buffer = deque(maxlen=maxlen)

    def add(
        self,
        obs: VecEnvObs,
        action: np.ndarray,
        reward: np.ndarray,
        done: np.ndarray,
        next_obs: VecEnvObs,
    ) -> None:
        """Append one transition per environment from a batched step.

        Each argument is indexed along its first axis with ``0..num_envs-1``.
        Only ``np.ndarray`` observations are supported; the asserts below
        reject any other ``VecEnvObs`` variant.
        """
        assert isinstance(obs, np.ndarray)
        assert isinstance(next_obs, np.ndarray)
        for i in range(self.num_envs):
            self.buffer.append(
                Transition(obs[i], action[i], reward[i], done[i], next_obs[i])
            )

    def sample(self, batch_size: int) -> Batch:
        """Draw ``batch_size`` distinct transitions and stack them field-wise.

        Raises:
            ValueError: From ``random.sample`` when ``batch_size`` exceeds the
                number of stored transitions.
        """
        # NOTE(review): random.sample indexes into the deque, and deque
        # indexing is O(n) away from the ends, so sampling cost grows with
        # buffer size. Consider an array-backed ring buffer if this shows up
        # in profiles.
        ts = random.sample(self.buffer, batch_size)
        return Batch(
            obs=np.array([t.obs for t in ts]),
            actions=np.array([t.action for t in ts]),
            rewards=np.array([t.reward for t in ts]),
            dones=np.array([t.done for t in ts]),
            next_obs=np.array([t.next_obs for t in ts]),
        )

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.buffer)
|
68 |
-
|
69 |
-
|
70 |
-
DQNSelf = TypeVar("DQNSelf", bound="DQN")


class DQN(Algorithm):
    """Deep Q-learning with a replay buffer and a separate target network."""

    def __init__(
        self,
        policy: DQNPolicy,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        learning_rate: float = 1e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 50_000,
        batch_size: int = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: int = 4,
        gradient_steps: int = 1,
        target_update_interval: int = 10_000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10.0,
    ) -> None:
        """Set up the optimizer, target network, replay buffer and schedules.

        Args:
            policy: Policy whose ``q_net`` is trained.
            env: Vectorized environment to collect experience from.
            device: Device the target network and training tensors live on.
            tb_writer: Passed through to the ``Algorithm`` base class.
            learning_rate: Adam learning rate for ``policy.q_net``.
            buffer_size: Maximum number of transitions in the replay buffer.
            learning_starts: Timesteps collected with ``eps=1`` before any
                gradient step is taken.
            batch_size: Transitions sampled per gradient step.
            tau: Interpolation weight for target updates
                (``target <- tau * online + (1 - tau) * target``); ``1.0``
                copies the online weights.
            gamma: Discount factor used in the TD target.
            train_freq: Timesteps requested per rollout between training
                phases.
            gradient_steps: Gradient steps per training phase; a value
                ``<= 0`` means ``train_freq`` steps.
            target_update_interval: Rollout timesteps between target-network
                updates.
            exploration_fraction: Forwarded to ``linear_schedule`` as
                ``end_fraction``.
            exploration_initial_eps: Starting epsilon of the schedule.
            exploration_final_eps: Final epsilon of the schedule.
            max_grad_norm: Gradient-norm clip threshold; a falsy value
                disables clipping.
        """
        super().__init__(policy, env, device, tb_writer)
        self.policy = policy

        self.optimizer = Adam(self.policy.q_net.parameters(), lr=learning_rate)

        # The target network starts as a copy of the online network and is
        # kept in eval mode; it is only changed by _update_target.
        self.target_q_net = copy.deepcopy(self.policy.q_net).to(self.device)
        self.target_q_net.train(False)
        self.tau = tau
        self.target_update_interval = target_update_interval

        self.replay_buffer = ReplayBuffer(self.env.num_envs, buffer_size)
        self.batch_size = batch_size

        self.learning_starts = learning_starts
        self.train_freq = train_freq
        self.gradient_steps = gradient_steps

        self.gamma = gamma
        self.exploration_eps_schedule = linear_schedule(
            exploration_initial_eps,
            exploration_final_eps,
            end_fraction=exploration_fraction,
        )

        self.max_grad_norm = max_grad_norm

    def learn(
        self: DQNSelf, total_timesteps: int, callback: Optional[Callback] = None
    ) -> DQNSelf:
        """Run warm-up collection followed by the collect/train loop.

        Args:
            total_timesteps: Overall timestep budget, including the
                ``learning_starts`` warm-up steps.
            callback: Optional callback; ``on_step`` is invoked after every
                rollout/train iteration with the rollout's step count.

        Returns:
            ``self``, to allow call chaining.
        """
        self.policy.train(True)
        obs = self.env.reset()
        # Warm-up: fill the buffer with eps=1 before any learning happens.
        obs = self._collect_rollout(self.learning_starts, obs, 1)
        learning_steps = total_timesteps - self.learning_starts
        timesteps_elapsed = 0
        steps_since_target_update = 0
        while timesteps_elapsed < learning_steps:
            progress = timesteps_elapsed / learning_steps
            eps = self.exploration_eps_schedule(progress)
            obs = self._collect_rollout(self.train_freq, obs, eps)
            # NOTE(review): this is the nominal step count. _collect_rollout
            # adds num_envs transitions per env.step, so the number actually
            # collected is train_freq rounded up to a multiple of num_envs.
            rollout_steps = self.train_freq
            timesteps_elapsed += rollout_steps
            for _ in range(
                self.gradient_steps if self.gradient_steps > 0 else self.train_freq
            ):
                self.train()
            steps_since_target_update += rollout_steps
            if steps_since_target_update >= self.target_update_interval:
                self._update_target()
                steps_since_target_update = 0
            if callback:
                callback.on_step(timesteps_elapsed=rollout_steps)
        return self

    def train(self) -> None:
        """Take one gradient step on a sampled batch.

        Does nothing while the buffer holds fewer than ``batch_size``
        transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return
        o, a, r, d, next_o = self.replay_buffer.sample(self.batch_size)
        o = torch.as_tensor(o, device=self.device)
        # Add a trailing axis so actions can serve as a gather index below.
        a = torch.as_tensor(a, device=self.device).unsqueeze(1)
        r = torch.as_tensor(r, dtype=torch.float32, device=self.device)
        d = torch.as_tensor(d, dtype=torch.long, device=self.device)
        next_o = torch.as_tensor(next_o, device=self.device)

        # TD target from the target network; (1 - d) removes the bootstrap
        # term for transitions that ended an episode.
        with torch.no_grad():
            target = r + (1 - d) * self.gamma * self.target_q_net(next_o).max(1).values
        # Q-value of the action actually taken in each sampled transition.
        current = self.policy.q_net(o).gather(dim=1, index=a).squeeze(1)
        loss = F.smooth_l1_loss(current, target)

        self.optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm:
            nn.utils.clip_grad_norm_(self.policy.q_net.parameters(), self.max_grad_norm)
        self.optimizer.step()

    def _collect_rollout(self, timesteps: int, obs: VecEnvObs, eps: float) -> VecEnvObs:
        """Step the environment and store the transitions in the replay buffer.

        Args:
            timesteps: Requested number of timesteps; the loop advances by
                ``env.num_envs`` per ``env.step`` call.
            obs: Observation to start acting from.
            eps: Exploration epsilon passed to ``policy.act``.

        Returns:
            The observation following the last step (``obs`` unchanged when
            ``timesteps`` is 0).
        """
        for _ in range(0, timesteps, self.env.num_envs):
            action = self.policy.act(obs, eps, deterministic=False)
            next_obs, reward, done, _ = self.env.step(action)
            self.replay_buffer.add(obs, action, reward, done, next_obs)
            obs = next_obs
        return obs

    def _update_target(self) -> None:
        """Blend the online Q-network parameters into the target network.

        Each target parameter becomes ``tau * online + (1 - tau) * target``.
        Only ``parameters()`` are updated; module buffers are not touched.
        """
        # In-place update under no_grad instead of mutating `.data`, so
        # autograd's bookkeeping stays consistent.
        with torch.no_grad():
            for target_param, param in zip(
                self.target_q_net.parameters(), self.policy.q_net.parameters()
            ):
                target_param.copy_(
                    self.tau * param + (1 - self.tau) * target_param
                )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dqn/policy.py
DELETED
@@ -1,52 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import os
|
3 |
-
import torch
|
4 |
-
|
5 |
-
from typing import Optional, Sequence, TypeVar
|
6 |
-
|
7 |
-
from dqn.q_net import QNetwork
|
8 |
-
from shared.policy.policy import Policy
|
9 |
-
from wrappers.vectorable_wrapper import (
|
10 |
-
VecEnv,
|
11 |
-
VecEnvObs,
|
12 |
-
single_observation_space,
|
13 |
-
single_action_space,
|
14 |
-
)
|
15 |
-
|
16 |
-
DQNPolicySelf = TypeVar("DQNPolicySelf", bound="DQNPolicy")


class DQNPolicy(Policy):
    """Epsilon-greedy policy backed by a ``QNetwork``."""

    def __init__(
        self,
        env: VecEnv,
        hidden_sizes: Sequence[int] = (),
        cnn_feature_dim: int = 512,
        cnn_style: str = "nature",
        cnn_layers_init_orthogonal: Optional[bool] = None,
        impala_channels: Sequence[int] = (16, 32, 32),
        **kwargs,
    ) -> None:
        """Build the Q-network for ``env``'s single-environment spaces.

        Args:
            env: Vectorized environment the policy acts in.
            hidden_sizes: Hidden layer widths forwarded to ``QNetwork``.
                The default is an immutable empty tuple rather than a shared
                mutable list.
            cnn_feature_dim: Forwarded to ``QNetwork``.
            cnn_style: Forwarded to ``QNetwork``.
            cnn_layers_init_orthogonal: Forwarded to ``QNetwork``.
            impala_channels: Forwarded to ``QNetwork``.
            **kwargs: Forwarded to the ``Policy`` base class.
        """
        super().__init__(env, **kwargs)
        self.q_net = QNetwork(
            single_observation_space(env),
            single_action_space(env),
            hidden_sizes,
            cnn_feature_dim=cnn_feature_dim,
            cnn_style=cnn_style,
            cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
            impala_channels=impala_channels,
        )

    def act(
        self, obs: VecEnvObs, eps: float = 0, deterministic: bool = True
    ) -> np.ndarray:
        """Choose one action per environment.

        With ``deterministic=False``, a single random draw decides for the
        whole batch: with probability ``eps`` every environment gets a
        uniformly sampled action, otherwise every environment gets the
        greedy (argmax-Q) action.

        Args:
            obs: Batched observations.
            eps: Exploration probability; must be 0 when ``deterministic``.
            deterministic: When true, always act greedily.

        Returns:
            Array with one action per environment.
        """
        assert eps == 0 if deterministic else eps >= 0
        if not deterministic and np.random.random() < eps:
            # Sample from the same per-environment space the Q-network was
            # built for in __init__, rather than env.action_space directly.
            # NOTE(review): assumes single_action_space returns the
            # per-environment space - confirm against its definition.
            action_space = single_action_space(self.env)
            return np.array(
                [action_space.sample() for _ in range(self.env.num_envs)]
            )
        else:
            o = self._as_tensor(obs)
            with torch.no_grad():
                return self.q_net(o).argmax(axis=1).cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dqn/q_net.py
DELETED
@@ -1,41 +0,0 @@
|
|
1 |
-
import gym
|
2 |
-
import torch as th
|
3 |
-
import torch.nn as nn
|
4 |
-
|
5 |
-
from gym.spaces import Discrete
|
6 |
-
from typing import Optional, Sequence, Type
|
7 |
-
|
8 |
-
from shared.module.feature_extractor import FeatureExtractor
|
9 |
-
from shared.module.module import mlp
|
10 |
-
|
11 |
-
|
12 |
-
class QNetwork(nn.Module):
    """Feature extractor followed by an MLP head with one output per discrete action."""

    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        hidden_sizes: Sequence[int] = (),
        activation: Type[nn.Module] = nn.ReLU,  # Used by stable-baselines3
        cnn_feature_dim: int = 512,
        cnn_style: str = "nature",
        cnn_layers_init_orthogonal: Optional[bool] = None,
        impala_channels: Sequence[int] = (16, 32, 32),
    ) -> None:
        """Build the feature extractor and the fully connected head.

        Args:
            observation_space: Space passed to ``FeatureExtractor``.
            action_space: Must be ``Discrete``; its ``n`` is the output width.
            hidden_sizes: Widths of the hidden layers between the extracted
                features and the output layer. The default is an immutable
                empty tuple rather than a shared mutable list.
            activation: Activation class passed to both the feature extractor
                and the MLP head.
            cnn_feature_dim: Forwarded to ``FeatureExtractor``.
            cnn_style: Forwarded to ``FeatureExtractor``.
            cnn_layers_init_orthogonal: Forwarded to ``FeatureExtractor``.
            impala_channels: Forwarded to ``FeatureExtractor``.
        """
        super().__init__()
        # Q-learning head needs a finite action set; this also narrows the
        # type so `action_space.n` is valid below.
        assert isinstance(action_space, Discrete)
        self._feature_extractor = FeatureExtractor(
            observation_space,
            activation,
            cnn_feature_dim=cnn_feature_dim,
            cnn_style=cnn_style,
            cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
            impala_channels=impala_channels,
        )
        # Layer widths: extracted features -> hidden layers -> one unit per action.
        layer_sizes = (
            (self._feature_extractor.out_dim,) + tuple(hidden_sizes) + (action_space.n,)
        )
        self._fc = mlp(layer_sizes, activation)

    def forward(self, obs: th.Tensor) -> th.Tensor:
        """Return the head's output for ``obs`` (last layer is ``action_space.n`` wide)."""
        x = self._feature_extractor(obs)
        return self._fc(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyperparams/a2c.yml
DELETED
@@ -1,127 +0,0 @@
|
|
1 |
-
CartPole-v1: &cartpole-defaults
|
2 |
-
n_timesteps: !!float 5e5
|
3 |
-
env_hyperparams:
|
4 |
-
n_envs: 8
|
5 |
-
|
6 |
-
CartPole-v0:
|
7 |
-
<<: *cartpole-defaults
|
8 |
-
|
9 |
-
MountainCar-v0:
|
10 |
-
n_timesteps: !!float 1e6
|
11 |
-
env_hyperparams:
|
12 |
-
n_envs: 16
|
13 |
-
normalize: true
|
14 |
-
|
15 |
-
MountainCarContinuous-v0:
|
16 |
-
n_timesteps: !!float 1e5
|
17 |
-
env_hyperparams:
|
18 |
-
n_envs: 4
|
19 |
-
normalize: true
|
20 |
-
# policy_hyperparams:
|
21 |
-
# use_sde: true
|
22 |
-
# log_std_init: 0.0
|
23 |
-
# init_layers_orthogonal: false
|
24 |
-
algo_hyperparams:
|
25 |
-
n_steps: 100
|
26 |
-
sde_sample_freq: 16
|
27 |
-
|
28 |
-
Acrobot-v1:
|
29 |
-
n_timesteps: !!float 5e5
|
30 |
-
env_hyperparams:
|
31 |
-
normalize: true
|
32 |
-
n_envs: 16
|
33 |
-
|
34 |
-
LunarLander-v2:
|
35 |
-
n_timesteps: !!float 1e6
|
36 |
-
env_hyperparams:
|
37 |
-
n_envs: 8
|
38 |
-
normalize: true
|
39 |
-
algo_hyperparams:
|
40 |
-
n_steps: 5
|
41 |
-
gamma: 0.995
|
42 |
-
learning_rate: !!float 8.3e-4
|
43 |
-
learning_rate_decay: linear
|
44 |
-
ent_coef: !!float 1e-5
|
45 |
-
|
46 |
-
BipedalWalker-v3:
|
47 |
-
n_timesteps: !!float 5e6
|
48 |
-
env_hyperparams:
|
49 |
-
n_envs: 16
|
50 |
-
normalize: true
|
51 |
-
policy_hyperparams:
|
52 |
-
use_sde: true
|
53 |
-
log_std_init: -2
|
54 |
-
init_layers_orthogonal: false
|
55 |
-
algo_hyperparams:
|
56 |
-
ent_coef: 0
|
57 |
-
max_grad_norm: 0.5
|
58 |
-
n_steps: 8
|
59 |
-
gae_lambda: 0.9
|
60 |
-
vf_coef: 0.4
|
61 |
-
gamma: 0.99
|
62 |
-
learning_rate: !!float 9.6e-4
|
63 |
-
learning_rate_decay: linear
|
64 |
-
|
65 |
-
HalfCheetahBulletEnv-v0: &pybullet-defaults
|
66 |
-
n_timesteps: !!float 2e6
|
67 |
-
env_hyperparams:
|
68 |
-
n_envs: 4
|
69 |
-
normalize: true
|
70 |
-
policy_hyperparams:
|
71 |
-
use_sde: true
|
72 |
-
log_std_init: -2
|
73 |
-
init_layers_orthogonal: false
|
74 |
-
algo_hyperparams: &pybullet-algo-defaults
|
75 |
-
n_steps: 8
|
76 |
-
ent_coef: 0
|
77 |
-
max_grad_norm: 0.5
|
78 |
-
gae_lambda: 0.9
|
79 |
-
gamma: 0.99
|
80 |
-
vf_coef: 0.4
|
81 |
-
learning_rate: !!float 9.6e-4
|
82 |
-
learning_rate_decay: linear
|
83 |
-
|
84 |
-
AntBulletEnv-v0:
|
85 |
-
<<: *pybullet-defaults
|
86 |
-
|
87 |
-
Walker2DBulletEnv-v0:
|
88 |
-
<<: *pybullet-defaults
|
89 |
-
|
90 |
-
HopperBulletEnv-v0:
|
91 |
-
<<: *pybullet-defaults
|
92 |
-
|
93 |
-
CarRacing-v0:
|
94 |
-
n_timesteps: !!float 4e6
|
95 |
-
env_hyperparams:
|
96 |
-
n_envs: 8
|
97 |
-
frame_stack: 4
|
98 |
-
policy_hyperparams:
|
99 |
-
use_sde: true
|
100 |
-
log_std_init: -2
|
101 |
-
init_layers_orthogonal: false
|
102 |
-
activation_fn: relu
|
103 |
-
share_features_extractor: false
|
104 |
-
cnn_feature_dim: 256
|
105 |
-
hidden_sizes: [256]
|
106 |
-
algo_hyperparams:
|
107 |
-
n_steps: 8
|
108 |
-
learning_rate: !!float 5e-5
|
109 |
-
learning_rate_decay: linear
|
110 |
-
gamma: 0.99
|
111 |
-
gae_lambda: 0.95
|
112 |
-
ent_coef: 0
|
113 |
-
sde_sample_freq: 4
|
114 |
-
|
115 |
-
_atari: &atari-defaults
|
116 |
-
n_timesteps: !!float 1e7
|
117 |
-
env_hyperparams: &atari-env-defaults
|
118 |
-
n_envs: 16
|
119 |
-
frame_stack: 4
|
120 |
-
no_reward_timeout_steps: 1000
|
121 |
-
no_reward_fire_steps: 500
|
122 |
-
vec_env_class: async
|
123 |
-
policy_hyperparams: &atari-policy-defaults
|
124 |
-
activation_fn: relu
|
125 |
-
algo_hyperparams:
|
126 |
-
ent_coef: 0.01
|
127 |
-
vf_coef: 0.25
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyperparams/dqn.yml
DELETED
@@ -1,130 +0,0 @@
|
|
1 |
-
CartPole-v1: &cartpole-defaults
|
2 |
-
n_timesteps: !!float 5e4
|
3 |
-
env_hyperparams:
|
4 |
-
rolling_length: 50
|
5 |
-
policy_hyperparams:
|
6 |
-
hidden_sizes: [256, 256]
|
7 |
-
algo_hyperparams:
|
8 |
-
learning_rate: !!float 2.3e-3
|
9 |
-
batch_size: 64
|
10 |
-
buffer_size: 100000
|
11 |
-
learning_starts: 1000
|
12 |
-
gamma: 0.99
|
13 |
-
target_update_interval: 10
|
14 |
-
train_freq: 256
|
15 |
-
gradient_steps: 128
|
16 |
-
exploration_fraction: 0.16
|
17 |
-
exploration_final_eps: 0.04
|
18 |
-
eval_params:
|
19 |
-
step_freq: !!float 1e4
|
20 |
-
|
21 |
-
CartPole-v0:
|
22 |
-
<<: *cartpole-defaults
|
23 |
-
n_timesteps: !!float 4e4
|
24 |
-
|
25 |
-
MountainCar-v0:
|
26 |
-
n_timesteps: !!float 1.2e5
|
27 |
-
env_hyperparams:
|
28 |
-
rolling_length: 50
|
29 |
-
policy_hyperparams:
|
30 |
-
hidden_sizes: [256, 256]
|
31 |
-
algo_hyperparams:
|
32 |
-
learning_rate: !!float 4e-3
|
33 |
-
batch_size: 128
|
34 |
-
buffer_size: 10000
|
35 |
-
learning_starts: 1000
|
36 |
-
gamma: 0.98
|
37 |
-
target_update_interval: 600
|
38 |
-
train_freq: 16
|
39 |
-
gradient_steps: 8
|
40 |
-
exploration_fraction: 0.2
|
41 |
-
exploration_final_eps: 0.07
|
42 |
-
|
43 |
-
Acrobot-v1:
|
44 |
-
n_timesteps: !!float 1e5
|
45 |
-
env_hyperparams:
|
46 |
-
rolling_length: 50
|
47 |
-
policy_hyperparams:
|
48 |
-
hidden_sizes: [256, 256]
|
49 |
-
algo_hyperparams:
|
50 |
-
learning_rate: !!float 6.3e-4
|
51 |
-
batch_size: 128
|
52 |
-
buffer_size: 50000
|
53 |
-
learning_starts: 0
|
54 |
-
gamma: 0.99
|
55 |
-
target_update_interval: 250
|
56 |
-
train_freq: 4
|
57 |
-
gradient_steps: -1
|
58 |
-
exploration_fraction: 0.12
|
59 |
-
exploration_final_eps: 0.1
|
60 |
-
|
61 |
-
LunarLander-v2:
|
62 |
-
n_timesteps: !!float 5e5
|
63 |
-
env_hyperparams:
|
64 |
-
rolling_length: 50
|
65 |
-
policy_hyperparams:
|
66 |
-
hidden_sizes: [256, 256]
|
67 |
-
algo_hyperparams:
|
68 |
-
learning_rate: !!float 1e-4
|
69 |
-
batch_size: 256
|
70 |
-
buffer_size: 100000
|
71 |
-
learning_starts: 10000
|
72 |
-
gamma: 0.99
|
73 |
-
target_update_interval: 250
|
74 |
-
train_freq: 8
|
75 |
-
gradient_steps: -1
|
76 |
-
exploration_fraction: 0.12
|
77 |
-
exploration_final_eps: 0.1
|
78 |
-
max_grad_norm: 0.5
|
79 |
-
eval_params:
|
80 |
-
step_freq: 25_000
|
81 |
-
|
82 |
-
_atari: &atari-defaults
|
83 |
-
n_timesteps: !!float 1e7
|
84 |
-
env_hyperparams:
|
85 |
-
frame_stack: 4
|
86 |
-
no_reward_timeout_steps: 1_000
|
87 |
-
no_reward_fire_steps: 500
|
88 |
-
n_envs: 8
|
89 |
-
vec_env_class: async
|
90 |
-
algo_hyperparams:
|
91 |
-
buffer_size: 100000
|
92 |
-
learning_rate: !!float 1e-4
|
93 |
-
batch_size: 32
|
94 |
-
learning_starts: 100000
|
95 |
-
target_update_interval: 1000
|
96 |
-
train_freq: 8
|
97 |
-
gradient_steps: 2
|
98 |
-
exploration_fraction: 0.1
|
99 |
-
exploration_final_eps: 0.01
|
100 |
-
eval_params:
|
101 |
-
deterministic: false
|
102 |
-
|
103 |
-
PongNoFrameskip-v4:
|
104 |
-
<<: *atari-defaults
|
105 |
-
n_timesteps: !!float 2.5e6
|
106 |
-
|
107 |
-
_impala-atari: &impala-atari-defaults
|
108 |
-
<<: *atari-defaults
|
109 |
-
policy_hyperparams:
|
110 |
-
cnn_style: impala
|
111 |
-
cnn_feature_dim: 256
|
112 |
-
init_layers_orthogonal: true
|
113 |
-
cnn_layers_init_orthogonal: false
|
114 |
-
|
115 |
-
impala-PongNoFrameskip-v4:
|
116 |
-
<<: *impala-atari-defaults
|
117 |
-
env_id: PongNoFrameskip-v4
|
118 |
-
n_timesteps: !!float 2.5e6
|
119 |
-
|
120 |
-
impala-BreakoutNoFrameskip-v4:
|
121 |
-
<<: *impala-atari-defaults
|
122 |
-
env_id: BreakoutNoFrameskip-v4
|
123 |
-
|
124 |
-
impala-SpaceInvadersNoFrameskip-v4:
|
125 |
-
<<: *impala-atari-defaults
|
126 |
-
env_id: SpaceInvadersNoFrameskip-v4
|
127 |
-
|
128 |
-
impala-QbertNoFrameskip-v4:
|
129 |
-
<<: *impala-atari-defaults
|
130 |
-
env_id: QbertNoFrameskip-v4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyperparams/ppo.yml
DELETED
@@ -1,383 +0,0 @@
|
|
1 |
-
CartPole-v1: &cartpole-defaults
|
2 |
-
n_timesteps: !!float 1e5
|
3 |
-
env_hyperparams:
|
4 |
-
n_envs: 8
|
5 |
-
algo_hyperparams:
|
6 |
-
n_steps: 32
|
7 |
-
batch_size: 256
|
8 |
-
n_epochs: 20
|
9 |
-
gae_lambda: 0.8
|
10 |
-
gamma: 0.98
|
11 |
-
ent_coef: 0.0
|
12 |
-
learning_rate: 0.001
|
13 |
-
learning_rate_decay: linear
|
14 |
-
clip_range: 0.2
|
15 |
-
clip_range_decay: linear
|
16 |
-
eval_params:
|
17 |
-
step_freq: !!float 2.5e4
|
18 |
-
|
19 |
-
CartPole-v0:
|
20 |
-
<<: *cartpole-defaults
|
21 |
-
n_timesteps: !!float 5e4
|
22 |
-
|
23 |
-
MountainCar-v0:
|
24 |
-
n_timesteps: !!float 1e6
|
25 |
-
env_hyperparams:
|
26 |
-
normalize: true
|
27 |
-
n_envs: 16
|
28 |
-
algo_hyperparams:
|
29 |
-
n_steps: 16
|
30 |
-
n_epochs: 4
|
31 |
-
gae_lambda: 0.98
|
32 |
-
gamma: 0.99
|
33 |
-
ent_coef: 0.0
|
34 |
-
|
35 |
-
MountainCarContinuous-v0:
|
36 |
-
n_timesteps: !!float 1e5
|
37 |
-
env_hyperparams:
|
38 |
-
normalize: true
|
39 |
-
n_envs: 4
|
40 |
-
# policy_hyperparams:
|
41 |
-
# init_layers_orthogonal: false
|
42 |
-
# log_std_init: -3.29
|
43 |
-
# use_sde: true
|
44 |
-
algo_hyperparams:
|
45 |
-
n_steps: 512
|
46 |
-
batch_size: 256
|
47 |
-
n_epochs: 10
|
48 |
-
learning_rate: !!float 7.77e-5
|
49 |
-
ent_coef: 0.01 # 0.00429
|
50 |
-
ent_coef_decay: linear
|
51 |
-
clip_range: 0.1
|
52 |
-
gae_lambda: 0.9
|
53 |
-
max_grad_norm: 5
|
54 |
-
vf_coef: 0.19
|
55 |
-
eval_params:
|
56 |
-
step_freq: 5000
|
57 |
-
|
58 |
-
Acrobot-v1:
|
59 |
-
n_timesteps: !!float 1e6
|
60 |
-
env_hyperparams:
|
61 |
-
n_envs: 16
|
62 |
-
normalize: true
|
63 |
-
algo_hyperparams:
|
64 |
-
n_steps: 256
|
65 |
-
n_epochs: 4
|
66 |
-
gae_lambda: 0.94
|
67 |
-
gamma: 0.99
|
68 |
-
ent_coef: 0.0
|
69 |
-
|
70 |
-
LunarLander-v2:
|
71 |
-
n_timesteps: !!float 4e6
|
72 |
-
env_hyperparams:
|
73 |
-
n_envs: 16
|
74 |
-
algo_hyperparams:
|
75 |
-
n_steps: 1024
|
76 |
-
batch_size: 64
|
77 |
-
n_epochs: 4
|
78 |
-
gae_lambda: 0.98
|
79 |
-
gamma: 0.999
|
80 |
-
learning_rate: !!float 5e-4
|
81 |
-
learning_rate_decay: linear
|
82 |
-
clip_range: 0.2
|
83 |
-
clip_range_decay: linear
|
84 |
-
ent_coef: 0.01
|
85 |
-
normalize_advantage: false
|
86 |
-
|
87 |
-
BipedalWalker-v3:
|
88 |
-
n_timesteps: !!float 10e6
|
89 |
-
env_hyperparams:
|
90 |
-
n_envs: 16
|
91 |
-
normalize: true
|
92 |
-
algo_hyperparams:
|
93 |
-
n_steps: 2048
|
94 |
-
batch_size: 64
|
95 |
-
gae_lambda: 0.95
|
96 |
-
gamma: 0.99
|
97 |
-
n_epochs: 10
|
98 |
-
ent_coef: 0.001
|
99 |
-
learning_rate: !!float 2.5e-4
|
100 |
-
learning_rate_decay: linear
|
101 |
-
clip_range: 0.2
|
102 |
-
clip_range_decay: linear
|
103 |
-
|
104 |
-
CarRacing-v0: &carracing-defaults
|
105 |
-
n_timesteps: !!float 4e6
|
106 |
-
env_hyperparams:
|
107 |
-
n_envs: 8
|
108 |
-
frame_stack: 4
|
109 |
-
policy_hyperparams: &carracing-policy-defaults
|
110 |
-
use_sde: true
|
111 |
-
log_std_init: -2
|
112 |
-
init_layers_orthogonal: false
|
113 |
-
activation_fn: relu
|
114 |
-
share_features_extractor: false
|
115 |
-
cnn_feature_dim: 256
|
116 |
-
hidden_sizes: [256]
|
117 |
-
algo_hyperparams:
|
118 |
-
n_steps: 512
|
119 |
-
batch_size: 128
|
120 |
-
n_epochs: 10
|
121 |
-
learning_rate: !!float 1e-4
|
122 |
-
learning_rate_decay: linear
|
123 |
-
gamma: 0.99
|
124 |
-
gae_lambda: 0.95
|
125 |
-
ent_coef: 0.0
|
126 |
-
sde_sample_freq: 4
|
127 |
-
max_grad_norm: 0.5
|
128 |
-
vf_coef: 0.5
|
129 |
-
clip_range: 0.2
|
130 |
-
|
131 |
-
impala-CarRacing-v0:
|
132 |
-
<<: *carracing-defaults
|
133 |
-
env_id: CarRacing-v0
|
134 |
-
policy_hyperparams:
|
135 |
-
<<: *carracing-policy-defaults
|
136 |
-
cnn_style: impala
|
137 |
-
init_layers_orthogonal: true
|
138 |
-
cnn_layers_init_orthogonal: false
|
139 |
-
hidden_sizes: []
|
140 |
-
|
141 |
-
# BreakoutNoFrameskip-v4
|
142 |
-
# PongNoFrameskip-v4
|
143 |
-
# SpaceInvadersNoFrameskip-v4
|
144 |
-
# QbertNoFrameskip-v4
|
145 |
-
_atari: &atari-defaults
|
146 |
-
n_timesteps: !!float 1e7
|
147 |
-
env_hyperparams: &atari-env-defaults
|
148 |
-
n_envs: 8
|
149 |
-
frame_stack: 4
|
150 |
-
no_reward_timeout_steps: 1000
|
151 |
-
no_reward_fire_steps: 500
|
152 |
-
vec_env_class: async
|
153 |
-
policy_hyperparams: &atari-policy-defaults
|
154 |
-
activation_fn: relu
|
155 |
-
algo_hyperparams:
|
156 |
-
n_steps: 128
|
157 |
-
batch_size: 256
|
158 |
-
n_epochs: 4
|
159 |
-
learning_rate: !!float 2.5e-4
|
160 |
-
learning_rate_decay: linear
|
161 |
-
clip_range: 0.1
|
162 |
-
clip_range_decay: linear
|
163 |
-
vf_coef: 0.5
|
164 |
-
ent_coef: 0.01
|
165 |
-
eval_params:
|
166 |
-
deterministic: false
|
167 |
-
|
168 |
-
_norm-rewards-atari: &norm-rewards-atari-default
|
169 |
-
<<: *atari-defaults
|
170 |
-
env_hyperparams:
|
171 |
-
<<: *atari-env-defaults
|
172 |
-
clip_atari_rewards: false
|
173 |
-
normalize: true
|
174 |
-
normalize_kwargs:
|
175 |
-
norm_obs: false
|
176 |
-
norm_reward: true
|
177 |
-
|
178 |
-
norm-rewards-BreakoutNoFrameskip-v4:
|
179 |
-
<<: *norm-rewards-atari-default
|
180 |
-
env_id: BreakoutNoFrameskip-v4
|
181 |
-
|
182 |
-
debug-PongNoFrameskip-v4:
|
183 |
-
<<: *atari-defaults
|
184 |
-
device: cpu
|
185 |
-
env_id: PongNoFrameskip-v4
|
186 |
-
env_hyperparams:
|
187 |
-
<<: *atari-env-defaults
|
188 |
-
vec_env_class: sync
|
189 |
-
|
190 |
-
_impala-atari: &impala-atari-defaults
|
191 |
-
<<: *atari-defaults
|
192 |
-
policy_hyperparams:
|
193 |
-
<<: *atari-policy-defaults
|
194 |
-
cnn_style: impala
|
195 |
-
cnn_feature_dim: 256
|
196 |
-
init_layers_orthogonal: true
|
197 |
-
cnn_layers_init_orthogonal: false
|
198 |
-
|
199 |
-
impala-PongNoFrameskip-v4:
|
200 |
-
<<: *impala-atari-defaults
|
201 |
-
env_id: PongNoFrameskip-v4
|
202 |
-
|
203 |
-
impala-BreakoutNoFrameskip-v4:
|
204 |
-
<<: *impala-atari-defaults
|
205 |
-
env_id: BreakoutNoFrameskip-v4
|
206 |
-
|
207 |
-
impala-SpaceInvadersNoFrameskip-v4:
|
208 |
-
<<: *impala-atari-defaults
|
209 |
-
env_id: SpaceInvadersNoFrameskip-v4
|
210 |
-
|
211 |
-
impala-QbertNoFrameskip-v4:
|
212 |
-
<<: *impala-atari-defaults
|
213 |
-
env_id: QbertNoFrameskip-v4
|
214 |
-
|
215 |
-
HalfCheetahBulletEnv-v0: &pybullet-defaults
|
216 |
-
n_timesteps: !!float 2e6
|
217 |
-
env_hyperparams: &pybullet-env-defaults
|
218 |
-
n_envs: 16
|
219 |
-
normalize: true
|
220 |
-
policy_hyperparams: &pybullet-policy-defaults
|
221 |
-
pi_hidden_sizes: [256, 256]
|
222 |
-
v_hidden_sizes: [256, 256]
|
223 |
-
activation_fn: relu
|
224 |
-
algo_hyperparams: &pybullet-algo-defaults
|
225 |
-
n_steps: 512
|
226 |
-
batch_size: 128
|
227 |
-
n_epochs: 20
|
228 |
-
gamma: 0.99
|
229 |
-
gae_lambda: 0.9
|
230 |
-
ent_coef: 0.0
|
231 |
-
max_grad_norm: 0.5
|
232 |
-
vf_coef: 0.5
|
233 |
-
learning_rate: !!float 3e-5
|
234 |
-
clip_range: 0.4
|
235 |
-
|
236 |
-
AntBulletEnv-v0:
|
237 |
-
<<: *pybullet-defaults
|
238 |
-
policy_hyperparams:
|
239 |
-
<<: *pybullet-policy-defaults
|
240 |
-
algo_hyperparams:
|
241 |
-
<<: *pybullet-algo-defaults
|
242 |
-
|
243 |
-
Walker2DBulletEnv-v0:
|
244 |
-
<<: *pybullet-defaults
|
245 |
-
algo_hyperparams:
|
246 |
-
<<: *pybullet-algo-defaults
|
247 |
-
clip_range_decay: linear
|
248 |
-
|
249 |
-
HopperBulletEnv-v0:
|
250 |
-
<<: *pybullet-defaults
|
251 |
-
algo_hyperparams:
|
252 |
-
<<: *pybullet-algo-defaults
|
253 |
-
clip_range_decay: linear
|
254 |
-
|
255 |
-
HumanoidBulletEnv-v0:
|
256 |
-
<<: *pybullet-defaults
|
257 |
-
n_timesteps: !!float 1e7
|
258 |
-
env_hyperparams:
|
259 |
-
<<: *pybullet-env-defaults
|
260 |
-
n_envs: 8
|
261 |
-
policy_hyperparams:
|
262 |
-
<<: *pybullet-policy-defaults
|
263 |
-
# log_std_init: -1
|
264 |
-
algo_hyperparams:
|
265 |
-
<<: *pybullet-algo-defaults
|
266 |
-
n_steps: 2048
|
267 |
-
batch_size: 64
|
268 |
-
n_epochs: 10
|
269 |
-
gae_lambda: 0.95
|
270 |
-
learning_rate: !!float 2.5e-4
|
271 |
-
clip_range: 0.2
|
272 |
-
|
273 |
-
_procgen: &procgen-defaults
|
274 |
-
env_hyperparams: &procgen-env-defaults
|
275 |
-
env_type: procgen
|
276 |
-
n_envs: 64
|
277 |
-
# grayscale: false
|
278 |
-
# frame_stack: 4
|
279 |
-
normalize: true # procgen only normalizes reward
|
280 |
-
make_kwargs: &procgen-make-kwargs-defaults
|
281 |
-
num_threads: 8
|
282 |
-
policy_hyperparams: &procgen-policy-defaults
|
283 |
-
activation_fn: relu
|
284 |
-
cnn_style: impala
|
285 |
-
cnn_feature_dim: 256
|
286 |
-
init_layers_orthogonal: true
|
287 |
-
cnn_layers_init_orthogonal: false
|
288 |
-
algo_hyperparams: &procgen-algo-defaults
|
289 |
-
gamma: 0.999
|
290 |
-
gae_lambda: 0.95
|
291 |
-
n_steps: 256
|
292 |
-
batch_size: 2048
|
293 |
-
n_epochs: 3
|
294 |
-
ent_coef: 0.01
|
295 |
-
clip_range: 0.2
|
296 |
-
# clip_range_decay: linear
|
297 |
-
clip_range_vf: 0.2
|
298 |
-
learning_rate: !!float 5e-4
|
299 |
-
# learning_rate_decay: linear
|
300 |
-
vf_coef: 0.5
|
301 |
-
eval_params: &procgen-eval-defaults
|
302 |
-
ignore_first_episode: true
|
303 |
-
# deterministic: false
|
304 |
-
step_freq: !!float 1e5
|
305 |
-
|
306 |
-
_procgen-easy: &procgen-easy-defaults
|
307 |
-
<<: *procgen-defaults
|
308 |
-
n_timesteps: !!float 25e6
|
309 |
-
env_hyperparams: &procgen-easy-env-defaults
|
310 |
-
<<: *procgen-env-defaults
|
311 |
-
make_kwargs:
|
312 |
-
<<: *procgen-make-kwargs-defaults
|
313 |
-
distribution_mode: easy
|
314 |
-
|
315 |
-
procgen-coinrun-easy: &coinrun-easy-defaults
|
316 |
-
<<: *procgen-easy-defaults
|
317 |
-
env_id: coinrun
|
318 |
-
|
319 |
-
debug-procgen-coinrun:
|
320 |
-
<<: *coinrun-easy-defaults
|
321 |
-
device: cpu
|
322 |
-
|
323 |
-
procgen-starpilot-easy:
|
324 |
-
<<: *procgen-easy-defaults
|
325 |
-
env_id: starpilot
|
326 |
-
|
327 |
-
procgen-bossfight-easy:
|
328 |
-
<<: *procgen-easy-defaults
|
329 |
-
env_id: bossfight
|
330 |
-
|
331 |
-
procgen-bigfish-easy:
|
332 |
-
<<: *procgen-easy-defaults
|
333 |
-
env_id: bigfish
|
334 |
-
|
335 |
-
_procgen-hard: &procgen-hard-defaults
|
336 |
-
<<: *procgen-defaults
|
337 |
-
n_timesteps: !!float 200e6
|
338 |
-
env_hyperparams: &procgen-hard-env-defaults
|
339 |
-
<<: *procgen-env-defaults
|
340 |
-
n_envs: 256
|
341 |
-
make_kwargs:
|
342 |
-
<<: *procgen-make-kwargs-defaults
|
343 |
-
distribution_mode: hard
|
344 |
-
algo_hyperparams: &procgen-hard-algo-defaults
|
345 |
-
<<: *procgen-algo-defaults
|
346 |
-
batch_size: 8192
|
347 |
-
clip_range_decay: linear
|
348 |
-
learning_rate_decay: linear
|
349 |
-
eval_params:
|
350 |
-
<<: *procgen-eval-defaults
|
351 |
-
step_freq: !!float 5e5
|
352 |
-
|
353 |
-
procgen-starpilot-hard: &procgen-starpilot-hard-defaults
|
354 |
-
<<: *procgen-hard-defaults
|
355 |
-
env_id: starpilot
|
356 |
-
|
357 |
-
procgen-starpilot-hard-2xIMPALA:
|
358 |
-
<<: *procgen-starpilot-hard-defaults
|
359 |
-
policy_hyperparams:
|
360 |
-
<<: *procgen-policy-defaults
|
361 |
-
impala_channels: [32, 64, 64]
|
362 |
-
algo_hyperparams:
|
363 |
-
<<: *procgen-hard-algo-defaults
|
364 |
-
learning_rate: !!float 3.3e-4
|
365 |
-
|
366 |
-
procgen-starpilot-hard-2xIMPALA-fat:
|
367 |
-
<<: *procgen-starpilot-hard-defaults
|
368 |
-
policy_hyperparams:
|
369 |
-
<<: *procgen-policy-defaults
|
370 |
-
impala_channels: [32, 64, 64]
|
371 |
-
cnn_feature_dim: 512
|
372 |
-
algo_hyperparams:
|
373 |
-
<<: *procgen-hard-algo-defaults
|
374 |
-
learning_rate: !!float 2.5e-4
|
375 |
-
|
376 |
-
procgen-starpilot-hard-4xIMPALA:
|
377 |
-
<<: *procgen-starpilot-hard-defaults
|
378 |
-
policy_hyperparams:
|
379 |
-
<<: *procgen-policy-defaults
|
380 |
-
impala_channels: [64, 128, 128]
|
381 |
-
algo_hyperparams:
|
382 |
-
<<: *procgen-hard-algo-defaults
|
383 |
-
learning_rate: !!float 2.1e-4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hyperparams/vpg.yml
DELETED
@@ -1,197 +0,0 @@
|
|
1 |
-
CartPole-v1: &cartpole-defaults
|
2 |
-
n_timesteps: !!float 4e5
|
3 |
-
algo_hyperparams:
|
4 |
-
n_steps: 4096
|
5 |
-
pi_lr: 0.01
|
6 |
-
gamma: 0.99
|
7 |
-
gae_lambda: 1
|
8 |
-
val_lr: 0.01
|
9 |
-
train_v_iters: 80
|
10 |
-
eval_params:
|
11 |
-
step_freq: !!float 2.5e4
|
12 |
-
|
13 |
-
CartPole-v0:
|
14 |
-
<<: *cartpole-defaults
|
15 |
-
n_timesteps: !!float 1e5
|
16 |
-
algo_hyperparams:
|
17 |
-
n_steps: 1024
|
18 |
-
pi_lr: 0.01
|
19 |
-
gamma: 0.99
|
20 |
-
gae_lambda: 1
|
21 |
-
val_lr: 0.01
|
22 |
-
train_v_iters: 80
|
23 |
-
|
24 |
-
MountainCar-v0:
|
25 |
-
n_timesteps: !!float 1e6
|
26 |
-
env_hyperparams:
|
27 |
-
normalize: true
|
28 |
-
n_envs: 16
|
29 |
-
algo_hyperparams:
|
30 |
-
n_steps: 200
|
31 |
-
pi_lr: 0.005
|
32 |
-
gamma: 0.99
|
33 |
-
gae_lambda: 0.97
|
34 |
-
val_lr: 0.01
|
35 |
-
train_v_iters: 80
|
36 |
-
max_grad_norm: 0.5
|
37 |
-
|
38 |
-
MountainCarContinuous-v0:
|
39 |
-
n_timesteps: !!float 3e5
|
40 |
-
env_hyperparams:
|
41 |
-
normalize: true
|
42 |
-
n_envs: 4
|
43 |
-
# policy_hyperparams:
|
44 |
-
# init_layers_orthogonal: false
|
45 |
-
# log_std_init: -3.29
|
46 |
-
# use_sde: true
|
47 |
-
algo_hyperparams:
|
48 |
-
n_steps: 1000
|
49 |
-
pi_lr: !!float 5e-4
|
50 |
-
gamma: 0.99
|
51 |
-
gae_lambda: 0.9
|
52 |
-
val_lr: !!float 1e-3
|
53 |
-
train_v_iters: 80
|
54 |
-
max_grad_norm: 5
|
55 |
-
eval_params:
|
56 |
-
step_freq: 5000
|
57 |
-
|
58 |
-
Acrobot-v1:
|
59 |
-
n_timesteps: !!float 2e5
|
60 |
-
algo_hyperparams:
|
61 |
-
n_steps: 2048
|
62 |
-
pi_lr: 0.005
|
63 |
-
gamma: 0.99
|
64 |
-
gae_lambda: 0.97
|
65 |
-
val_lr: 0.01
|
66 |
-
train_v_iters: 80
|
67 |
-
max_grad_norm: 0.5
|
68 |
-
|
69 |
-
LunarLander-v2:
|
70 |
-
n_timesteps: !!float 4e6
|
71 |
-
policy_hyperparams:
|
72 |
-
hidden_sizes: [256, 256]
|
73 |
-
algo_hyperparams:
|
74 |
-
n_steps: 2048
|
75 |
-
pi_lr: 0.0001
|
76 |
-
gamma: 0.999
|
77 |
-
gae_lambda: 0.97
|
78 |
-
val_lr: 0.0001
|
79 |
-
train_v_iters: 80
|
80 |
-
max_grad_norm: 0.5
|
81 |
-
eval_params:
|
82 |
-
deterministic: false
|
83 |
-
|
84 |
-
BipedalWalker-v3:
|
85 |
-
n_timesteps: !!float 10e6
|
86 |
-
env_hyperparams:
|
87 |
-
n_envs: 16
|
88 |
-
normalize: true
|
89 |
-
policy_hyperparams:
|
90 |
-
hidden_sizes: [256, 256]
|
91 |
-
algo_hyperparams:
|
92 |
-
n_steps: 1600
|
93 |
-
gae_lambda: 0.95
|
94 |
-
gamma: 0.99
|
95 |
-
pi_lr: !!float 1e-4
|
96 |
-
val_lr: !!float 1e-4
|
97 |
-
train_v_iters: 80
|
98 |
-
max_grad_norm: 0.5
|
99 |
-
eval_params:
|
100 |
-
deterministic: false
|
101 |
-
|
102 |
-
CarRacing-v0:
|
103 |
-
n_timesteps: !!float 4e6
|
104 |
-
env_hyperparams:
|
105 |
-
frame_stack: 4
|
106 |
-
n_envs: 4
|
107 |
-
vec_env_class: sync
|
108 |
-
policy_hyperparams:
|
109 |
-
use_sde: true
|
110 |
-
log_std_init: -2
|
111 |
-
init_layers_orthogonal: false
|
112 |
-
activation_fn: relu
|
113 |
-
cnn_feature_dim: 256
|
114 |
-
hidden_sizes: [256]
|
115 |
-
algo_hyperparams:
|
116 |
-
n_steps: 1000
|
117 |
-
pi_lr: !!float 5e-5
|
118 |
-
gamma: 0.99
|
119 |
-
gae_lambda: 0.95
|
120 |
-
val_lr: !!float 1e-4
|
121 |
-
train_v_iters: 40
|
122 |
-
max_grad_norm: 0.5
|
123 |
-
sde_sample_freq: 4
|
124 |
-
|
125 |
-
HalfCheetahBulletEnv-v0: &pybullet-defaults
|
126 |
-
n_timesteps: !!float 2e6
|
127 |
-
env_hyperparams: &pybullet-env-defaults
|
128 |
-
normalize: true
|
129 |
-
policy_hyperparams: &pybullet-policy-defaults
|
130 |
-
hidden_sizes: [256, 256]
|
131 |
-
algo_hyperparams: &pybullet-algo-defaults
|
132 |
-
n_steps: 4000
|
133 |
-
pi_lr: !!float 3e-4
|
134 |
-
gamma: 0.99
|
135 |
-
gae_lambda: 0.97
|
136 |
-
val_lr: !!float 1e-3
|
137 |
-
train_v_iters: 80
|
138 |
-
max_grad_norm: 0.5
|
139 |
-
|
140 |
-
AntBulletEnv-v0:
|
141 |
-
<<: *pybullet-defaults
|
142 |
-
policy_hyperparams:
|
143 |
-
<<: *pybullet-policy-defaults
|
144 |
-
hidden_sizes: [400, 300]
|
145 |
-
algo_hyperparams:
|
146 |
-
<<: *pybullet-algo-defaults
|
147 |
-
pi_lr: !!float 7e-4
|
148 |
-
val_lr: !!float 7e-3
|
149 |
-
|
150 |
-
HopperBulletEnv-v0:
|
151 |
-
<<: *pybullet-defaults
|
152 |
-
|
153 |
-
Walker2DBulletEnv-v0:
|
154 |
-
<<: *pybullet-defaults
|
155 |
-
|
156 |
-
FrozenLake-v1:
  n_timesteps: !!float 8e5
  # Was `env_params`; every other entry in this file configures the
  # environment under `env_hyperparams`, so the key is renamed to match.
  env_hyperparams:
    make_kwargs:
      map_name: 8x8
      is_slippery: true
  policy_hyperparams:
    hidden_sizes: [64]
  algo_hyperparams:
    n_steps: 2048
    pi_lr: 0.01
    gamma: 0.99
    gae_lambda: 0.98
    val_lr: 0.01
    train_v_iters: 80
    max_grad_norm: 0.5
  eval_params:
    step_freq: !!float 5e4
    n_episodes: 10
    save_best: true
|
176 |
-
|
177 |
-
_atari: &atari-defaults
|
178 |
-
n_timesteps: !!float 25e6
|
179 |
-
env_hyperparams:
|
180 |
-
n_envs: 4
|
181 |
-
frame_stack: 4
|
182 |
-
no_reward_timeout_steps: 1000
|
183 |
-
no_reward_fire_steps: 500
|
184 |
-
vec_env_class: async
|
185 |
-
policy_hyperparams:
|
186 |
-
activation_fn: relu
|
187 |
-
algo_hyperparams:
|
188 |
-
n_steps: 2048
|
189 |
-
pi_lr: !!float 5e-5
|
190 |
-
gamma: 0.99
|
191 |
-
gae_lambda: 0.95
|
192 |
-
val_lr: !!float 1e-4
|
193 |
-
train_v_iters: 80
|
194 |
-
max_grad_norm: 0.5
|
195 |
-
ent_coef: 0.01
|
196 |
-
eval_params:
|
197 |
-
deterministic: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lambda_labs/benchmark.sh
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
source benchmarks/train_loop.sh
|
2 |
-
|
3 |
-
# export WANDB_PROJECT_NAME="rl-algo-impls"
|
4 |
-
|
5 |
-
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-6}"
|
6 |
-
|
7 |
-
ALGOS=(
|
8 |
-
# "vpg"
|
9 |
-
# "dqn"
|
10 |
-
# "ppo"
|
11 |
-
"a2c"
|
12 |
-
)
|
13 |
-
ENVS=(
|
14 |
-
# Basic
|
15 |
-
"CartPole-v1"
|
16 |
-
"MountainCar-v0"
|
17 |
-
"MountainCarContinuous-v0"
|
18 |
-
"Acrobot-v1"
|
19 |
-
"LunarLander-v2"
|
20 |
-
"BipedalWalker-v3"
|
21 |
-
# PyBullet
|
22 |
-
"HalfCheetahBulletEnv-v0"
|
23 |
-
"AntBulletEnv-v0"
|
24 |
-
"HopperBulletEnv-v0"
|
25 |
-
"Walker2DBulletEnv-v0"
|
26 |
-
# CarRacing
|
27 |
-
"CarRacing-v0"
|
28 |
-
# Atari
|
29 |
-
"PongNoFrameskip-v4"
|
30 |
-
"BreakoutNoFrameskip-v4"
|
31 |
-
"SpaceInvadersNoFrameskip-v4"
|
32 |
-
"QbertNoFrameskip-v4"
|
33 |
-
)
|
34 |
-
train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lambda_labs/impala_atari_benchmark.sh
DELETED
@@ -1,19 +0,0 @@
|
|
1 |
-
source benchmarks/train_loop.sh
|
2 |
-
|
3 |
-
# export WANDB_PROJECT_NAME="rl-algo-impls"
|
4 |
-
|
5 |
-
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-5}"
|
6 |
-
|
7 |
-
ALGOS=(
|
8 |
-
# "vpg"
|
9 |
-
# "dqn"
|
10 |
-
"ppo"
|
11 |
-
)
|
12 |
-
ENVS=(
|
13 |
-
"impala-PongNoFrameskip-v4"
|
14 |
-
"impala-BreakoutNoFrameskip-v4"
|
15 |
-
"impala-SpaceInvadersNoFrameskip-v4"
|
16 |
-
"impala-QbertNoFrameskip-v4"
|
17 |
-
"impala-CarRacing-v0"
|
18 |
-
)
|
19 |
-
train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lambda_labs/lambda_requirements.txt
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
scipy >= 1.10.0, < 1.11
|
2 |
-
tensorboard >= 2.11.0, < 2.12
|
3 |
-
AutoROM.accept-rom-license >= 0.4.2, < 0.5
|
4 |
-
stable-baselines3[extra] >= 1.7.0, < 1.8
|
5 |
-
gym[box2d] >= 0.21.0, < 0.22
|
6 |
-
pyglet == 1.5.27
|
7 |
-
wandb >= 0.13.10, < 0.14
|
8 |
-
pyvirtualdisplay == 3.0
|
9 |
-
pybullet >= 3.2.5, < 3.3
|
10 |
-
tabulate >= 0.9.0, < 0.10
|
11 |
-
huggingface-hub >= 0.12.0, < 0.13
|
12 |
-
numexpr >= 2.8.4, < 2.9
|
13 |
-
gym3 >= 0.3.3, < 0.4
|
14 |
-
glfw >= 1.12.0, < 1.13
|
15 |
-
procgen >= 0.10.7, < 0.11
|
16 |
-
ipython >= 8.10.0, < 8.11
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lambda_labs/procgen_benchmark.sh
DELETED
@@ -1,18 +0,0 @@
|
|
1 |
-
source benchmarks/train_loop.sh
|
2 |
-
|
3 |
-
# export WANDB_PROJECT_NAME="rl-algo-impls"
|
4 |
-
|
5 |
-
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
|
6 |
-
|
7 |
-
ALGOS=(
|
8 |
-
# "vpg"
|
9 |
-
# "dqn"
|
10 |
-
"ppo"
|
11 |
-
)
|
12 |
-
ENVS=(
|
13 |
-
"procgen-coinrun-easy"
|
14 |
-
"procgen-starpilot-easy"
|
15 |
-
"procgen-bossfight-easy"
|
16 |
-
"procgen-bigfish-easy"
|
17 |
-
)
|
18 |
-
train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lambda_labs/setup.sh
DELETED
@@ -1,10 +0,0 @@
|
|
1 |
-
sudo apt update
|
2 |
-
sudo apt install -y python-opengl
|
3 |
-
sudo apt install -y ffmpeg
|
4 |
-
sudo apt install -y xvfb
|
5 |
-
sudo apt install -y swig
|
6 |
-
|
7 |
-
python3 -m pip install --upgrade pip
|
8 |
-
pip install --upgrade torch torchvision torchaudio
|
9 |
-
|
10 |
-
pip install --upgrade -r ~/rl-algo-impls/lambda_labs/lambda_requirements.txt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lambda_labs/starpilot_hard_benchmark.sh
DELETED
@@ -1,16 +0,0 @@
|
|
1 |
-
source benchmarks/train_loop.sh
|
2 |
-
|
3 |
-
# export WANDB_PROJECT_NAME="rl-algo-impls"
|
4 |
-
|
5 |
-
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-1}"
|
6 |
-
|
7 |
-
ALGOS=(
|
8 |
-
"ppo"
|
9 |
-
)
|
10 |
-
ENVS=(
|
11 |
-
"procgen-starpilot-hard"
|
12 |
-
"procgen-starpilot-hard-2xIMPALA"
|
13 |
-
"procgen-starpilot-hard-2xIMPALA-fat"
|
14 |
-
"procgen-starpilot-hard-4xIMPALA"
|
15 |
-
)
|
16 |
-
train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
poetry.lock
DELETED
The diff for this file is too large to render.
See raw diff
|
|
ppo/ppo.py
DELETED
@@ -1,349 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import torch
|
3 |
-
import torch.nn as nn
|
4 |
-
import torch.nn.functional as F
|
5 |
-
|
6 |
-
from dataclasses import asdict, dataclass, field
|
7 |
-
from time import perf_counter
|
8 |
-
from torch.optim import Adam
|
9 |
-
from torch.utils.tensorboard.writer import SummaryWriter
|
10 |
-
from typing import List, Optional, NamedTuple, TypeVar
|
11 |
-
|
12 |
-
from shared.algorithm import Algorithm
|
13 |
-
from shared.callbacks.callback import Callback
|
14 |
-
from shared.gae import compute_advantage, compute_rtg_and_advantage
|
15 |
-
from shared.policy.on_policy import ActorCritic
|
16 |
-
from shared.schedule import constant_schedule, linear_schedule, update_learning_rate
|
17 |
-
from shared.trajectory import Trajectory, TrajectoryAccumulator
|
18 |
-
from wrappers.vectorable_wrapper import VecEnv, VecEnvObs
|
19 |
-
|
20 |
-
|
21 |
-
@dataclass
class PPOTrajectory(Trajectory):
    """Trajectory that also records the log-probability of each action taken."""

    # One entry is appended per `add` call, alongside the base-class fields.
    logp_a: List[float] = field(default_factory=list)

    def add(
        self,
        obs: np.ndarray,
        act: np.ndarray,
        next_obs: np.ndarray,
        rew: float,
        terminated: bool,
        v: float,
        logp_a: float,
    ):
        """Record one transition in the base trajectory, then store ``logp_a``."""
        super().add(obs, act, next_obs, rew, terminated, v)
        self.logp_a.append(logp_a)
|
37 |
-
|
38 |
-
|
39 |
-
class PPOTrajectoryAccumulator(TrajectoryAccumulator):
    """Trajectory accumulator configured to build ``PPOTrajectory`` objects."""

    def __init__(self, num_envs: int) -> None:
        """Create an accumulator for ``num_envs`` environments."""
        super().__init__(num_envs, PPOTrajectory)

    def step(
        self,
        obs: VecEnvObs,
        action: np.ndarray,
        next_obs: VecEnvObs,
        reward: np.ndarray,
        done: np.ndarray,
        val: np.ndarray,
        logp_a: np.ndarray,
    ) -> None:
        """Forward one vectorized step, including ``logp_a``, to the base class."""
        super().step(obs, action, next_obs, reward, done, val, logp_a)
|
54 |
-
|
55 |
-
|
56 |
-
class TrainStepStats(NamedTuple):
    """Scalar statistics produced by a single minibatch update."""

    loss: float
    pi_loss: float
    v_loss: float
    entropy_loss: float
    approx_kl: float
    clipped_frac: float
    val_clipped_frac: float
|
64 |
-
|
65 |
-
|
66 |
-
@dataclass
class TrainStats:
    """Per-update training statistics: the mean of each step statistic plus
    the explained variance supplied by the caller.
    """

    loss: float
    pi_loss: float
    v_loss: float
    entropy_loss: float
    approx_kl: float
    clipped_frac: float
    val_clipped_frac: float
    explained_var: float

    def __init__(self, step_stats: List[TrainStepStats], explained_var: float) -> None:
        """Average every field of ``step_stats`` and store ``explained_var`` as given."""
        self.loss = np.mean([s.loss for s in step_stats]).item()
        self.pi_loss = np.mean([s.pi_loss for s in step_stats]).item()
        self.v_loss = np.mean([s.v_loss for s in step_stats]).item()
        self.entropy_loss = np.mean([s.entropy_loss for s in step_stats]).item()
        self.approx_kl = np.mean([s.approx_kl for s in step_stats]).item()
        self.clipped_frac = np.mean([s.clipped_frac for s in step_stats]).item()
        self.val_clipped_frac = np.mean([s.val_clipped_frac for s in step_stats]).item()
        self.explained_var = explained_var

    def write_to_tensorboard(self, tb_writer: SummaryWriter, global_step: int) -> None:
        """Log every field as a ``losses/<name>`` scalar at ``global_step``."""
        for name, value in asdict(self).items():
            tb_writer.add_scalar(f"losses/{name}", value, global_step=global_step)

    def __repr__(self) -> str:
        # Console summary; explained_var is not part of it.
        return " | ".join(
            [
                f"Loss: {round(self.loss, 2)}",
                f"Pi L: {round(self.pi_loss, 2)}",
                f"V L: {round(self.v_loss, 2)}",
                f"E L: {round(self.entropy_loss, 2)}",
                f"Apx KL Div: {round(self.approx_kl, 2)}",
                f"Clip Frac: {round(self.clipped_frac, 2)}",
                f"Val Clip Frac: {round(self.val_clipped_frac, 2)}",
            ]
        )
|
103 |
-
|
104 |
-
|
105 |
-
PPOSelf = TypeVar("PPOSelf", bound="PPO")
|
106 |
-
|
107 |
-
|
108 |
-
class PPO(Algorithm):
    """PPO trainer: alternates rollout collection with minibatch updates of an
    ``ActorCritic`` policy using a clipped policy-ratio objective.
    """

    def __init__(
        self,
        policy: ActorCritic,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        learning_rate: float = 3e-4,
        learning_rate_decay: str = "none",
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: float = 0.2,
        clip_range_decay: str = "none",
        clip_range_vf: Optional[float] = None,
        clip_range_vf_decay: str = "none",
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        ent_coef_decay: str = "none",
        vf_coef: float = 0.5,
        ppo2_vf_coef_halving: bool = False,
        max_grad_norm: float = 0.5,
        update_rtg_between_epochs: bool = False,
        sde_sample_freq: int = -1,
    ) -> None:
        """Store hyperparameters and build the optimizer and schedules.

        Each ``*_decay`` argument selects ``linear_schedule(value, 0)`` when it
        equals ``"linear"`` and ``constant_schedule(value)`` otherwise. Value
        clipping is only enabled when ``clip_range_vf`` is truthy.
        """
        super().__init__(policy, env, device, tb_writer)
        self.policy = policy

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
        self.lr_schedule = (
            linear_schedule(learning_rate, 0)
            if learning_rate_decay == "linear"
            else constant_schedule(learning_rate)
        )
        self.max_grad_norm = max_grad_norm
        self.clip_range_schedule = (
            linear_schedule(clip_range, 0)
            if clip_range_decay == "linear"
            else constant_schedule(clip_range)
        )
        self.clip_range_vf_schedule = None
        if clip_range_vf:
            self.clip_range_vf_schedule = (
                linear_schedule(clip_range_vf, 0)
                if clip_range_vf_decay == "linear"
                else constant_schedule(clip_range_vf)
            )
        self.normalize_advantage = normalize_advantage
        self.ent_coef_schedule = (
            linear_schedule(ent_coef, 0)
            if ent_coef_decay == "linear"
            else constant_schedule(ent_coef)
        )
        self.vf_coef = vf_coef
        self.ppo2_vf_coef_halving = ppo2_vf_coef_halving

        self.n_steps = n_steps
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.sde_sample_freq = sde_sample_freq

        self.update_rtg_between_epochs = update_rtg_between_epochs

    def learn(
        self: PPOSelf,
        total_timesteps: int,
        callback: Optional[Callback] = None,
    ) -> PPOSelf:
        """Run collect/train iterations until ``total_timesteps`` steps have elapsed.

        Each iteration accounts for ``n_steps * env.num_envs`` timesteps, logs
        training stats and throughput, and notifies ``callback`` if given.
        Returns ``self``.
        """
        # The environment is reset only once. Every later rollout must resume
        # from the observation the previous rollout ended on, which
        # `_collect_trajectories` leaves in `self._next_obs`.
        self._next_obs = self.env.reset()
        ts_elapsed = 0
        while ts_elapsed < total_timesteps:
            start_time = perf_counter()
            accumulator = self._collect_trajectories(self._next_obs)
            rollout_steps = self.n_steps * self.env.num_envs
            ts_elapsed += rollout_steps
            progress = ts_elapsed / total_timesteps
            train_stats = self.train(accumulator.all_trajectories, progress, ts_elapsed)
            train_stats.write_to_tensorboard(self.tb_writer, ts_elapsed)
            end_time = perf_counter()
            self.tb_writer.add_scalar(
                "train/steps_per_second",
                rollout_steps / (end_time - start_time),
                ts_elapsed,
            )
            if callback:
                callback.on_step(timesteps_elapsed=rollout_steps)

        return self

    def _collect_trajectories(self, obs: VecEnvObs) -> PPOTrajectoryAccumulator:
        """Step the environment ``n_steps`` times starting from ``obs``.

        Returns the accumulator holding the collected transitions. The last
        observation returned by the environment is kept in ``self._next_obs``
        so the following rollout can continue from it.
        """
        self.policy.eval()
        accumulator = PPOTrajectoryAccumulator(self.env.num_envs)
        self.policy.reset_noise()
        for i in range(self.n_steps):
            # Periodically resample noise when a positive frequency is set.
            if self.sde_sample_freq > 0 and i > 0 and i % self.sde_sample_freq == 0:
                self.policy.reset_noise()
            action, value, logp_a, clamped_action = self.policy.step(obs)
            # The environment receives the clamped action; the raw action is
            # what gets stored for training.
            next_obs, reward, done, _ = self.env.step(clamped_action)
            accumulator.step(obs, action, next_obs, reward, done, value, logp_a)
            obs = next_obs
        self._next_obs = obs
        return accumulator

    def train(
        self, trajectories: List[PPOTrajectory], progress: float, timesteps_elapsed: int
    ) -> TrainStats:
        """Run ``n_epochs`` of shuffled minibatch updates over ``trajectories``.

        ``progress`` is passed to the schedules to obtain the current learning
        rate, clip ranges and entropy coefficient, which are logged at
        ``timesteps_elapsed``. Returns stats for the final epoch only.
        """
        self.policy.train()
        learning_rate = self.lr_schedule(progress)
        update_learning_rate(self.optimizer, learning_rate)
        self.tb_writer.add_scalar(
            "charts/learning_rate",
            self.optimizer.param_groups[0]["lr"],
            timesteps_elapsed,
        )

        pi_clip = self.clip_range_schedule(progress)
        self.tb_writer.add_scalar("charts/pi_clip", pi_clip, timesteps_elapsed)
        if self.clip_range_vf_schedule:
            v_clip = self.clip_range_vf_schedule(progress)
            self.tb_writer.add_scalar("charts/v_clip", v_clip, timesteps_elapsed)
        else:
            v_clip = None
        ent_coef = self.ent_coef_schedule(progress)
        self.tb_writer.add_scalar("charts/ent_coef", ent_coef, timesteps_elapsed)

        obs = torch.as_tensor(
            np.concatenate([np.array(t.obs) for t in trajectories]), device=self.device
        )
        act = torch.as_tensor(
            np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
        )
        rtg, adv = compute_rtg_and_advantage(
            trajectories, self.policy, self.gamma, self.gae_lambda, self.device
        )
        orig_v = torch.as_tensor(
            np.concatenate([np.array(t.v) for t in trajectories]), device=self.device
        )
        orig_logp_a = torch.as_tensor(
            np.concatenate([np.array(t.logp_a) for t in trajectories]),
            device=self.device,
        )

        step_stats = []
        for _ in range(self.n_epochs):
            # Cleared every epoch, so only the last epoch's stats are reported.
            step_stats.clear()
            if self.update_rtg_between_epochs:
                rtg, adv = compute_rtg_and_advantage(
                    trajectories, self.policy, self.gamma, self.gae_lambda, self.device
                )
            else:
                adv = compute_advantage(
                    trajectories, self.policy, self.gamma, self.gae_lambda, self.device
                )
            idxs = torch.randperm(len(obs))
            for i in range(0, len(obs), self.batch_size):
                mb_idxs = idxs[i : i + self.batch_size]
                mb_adv = adv[mb_idxs]
                # A trailing minibatch may hold a single sample, whose standard
                # deviation is NaN and would poison the update; leave it
                # unnormalized. (Assumes adv is indexed by sample on dim 0.)
                if self.normalize_advantage and len(mb_adv) > 1:
                    mb_adv = (mb_adv - mb_adv.mean(-1)) / (mb_adv.std(-1) + 1e-8)
                step_stats.append(
                    self._train_step(
                        pi_clip,
                        v_clip,
                        ent_coef,
                        obs[mb_idxs],
                        act[mb_idxs],
                        rtg[mb_idxs],
                        mb_adv,
                        orig_v[mb_idxs],
                        orig_logp_a[mb_idxs],
                    )
                )

        # Explained variance of the rollout-time value estimates against rtg.
        y_pred, y_true = orig_v.cpu().numpy(), rtg.cpu().numpy()
        var_y = np.var(y_true).item()
        explained_var = (
            np.nan if var_y == 0 else 1 - np.var(y_true - y_pred).item() / var_y
        )

        return TrainStats(step_stats, explained_var)

    def _train_step(
        self,
        pi_clip: float,
        v_clip: Optional[float],
        ent_coef: float,
        obs: torch.Tensor,
        act: torch.Tensor,
        rtg: torch.Tensor,
        adv: torch.Tensor,
        orig_v: torch.Tensor,
        orig_logp_a: torch.Tensor,
    ) -> TrainStepStats:
        """Apply one gradient step on a minibatch and return its statistics."""
        logp_a, entropy, v = self.policy(obs, act)
        logratio = logp_a - orig_logp_a
        ratio = torch.exp(logratio)
        clip_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
        # Elementwise maximum of the negated terms, i.e. the larger loss of the
        # unclipped and clipped ratio.
        pi_loss = torch.maximum(-ratio * adv, -clip_ratio * adv).mean()

        v_loss_unclipped = (v - rtg) ** 2
        if v_clip:
            # Limit how far the new value may move from the rollout-time value.
            v_loss_clipped = (
                orig_v + torch.clamp(v - orig_v, -v_clip, v_clip) - rtg
            ) ** 2
            v_loss = torch.max(v_loss_unclipped, v_loss_clipped).mean()
        else:
            v_loss = v_loss_unclipped.mean()
        if self.ppo2_vf_coef_halving:
            v_loss *= 0.5

        entropy_loss = -entropy.mean()

        loss = pi_loss + ent_coef * entropy_loss + self.vf_coef * v_loss

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.optimizer.step()

        # Diagnostics only; no gradients needed.
        with torch.no_grad():
            approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
            clipped_frac = (
                ((ratio - 1).abs() > pi_clip).float().mean().cpu().numpy().item()
            )
            val_clipped_frac = (
                (((v - orig_v).abs() > v_clip).float().mean().cpu().numpy().item())
                if v_clip
                else 0
            )

        return TrainStepStats(
            loss.item(),
            pi_loss.item(),
            v_loss.item(),
            entropy_loss.item(),
            approx_kl,
            clipped_frac,
            val_clipped_frac,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
publish/markdown_format.py
DELETED
@@ -1,210 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import pandas as pd
|
3 |
-
import wandb.apis.public
|
4 |
-
import yaml
|
5 |
-
|
6 |
-
from collections import defaultdict
|
7 |
-
from dataclasses import dataclass, asdict
|
8 |
-
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, TypeVar
|
9 |
-
from urllib.parse import urlparse
|
10 |
-
|
11 |
-
from runner.evaluate import Evaluation
|
12 |
-
|
13 |
-
EvaluationRowSelf = TypeVar("EvaluationRowSelf", bound="EvaluationRow")
|
14 |
-
|
15 |
-
|
16 |
-
@dataclass
class EvaluationRow:
    """One row of the evaluation results table."""

    algo: str
    env: str
    seed: Optional[int]
    reward_mean: float
    reward_std: float
    eval_episodes: int
    best: str
    wandb_url: str

    @staticmethod
    def data_frame(rows: List[EvaluationRowSelf]) -> pd.DataFrame:
        """Return a DataFrame with one column per field and one row per entry."""
        results = defaultdict(list)
        for r in rows:
            for k, v in asdict(r).items():
                results[k].append(v)
        return pd.DataFrame(results)
|
34 |
-
|
35 |
-
|
36 |
-
class EvalTableData(NamedTuple):
    """Pairs a WandB run with the evaluation produced for it."""

    run: wandb.apis.public.Run
    evaluation: Evaluation
|
39 |
-
|
40 |
-
|
41 |
-
def evaluation_table(table_data: Iterable[EvalTableData]) -> str:
    """Render ``table_data`` as a Markdown table, one row per run.

    Rows are ordered by seed (a missing seed sorts as 0). The row whose stats
    compare equal to the highest-scoring stats is marked with ``*``.
    """
    # Materialize once: the data is traversed twice below, and a one-shot
    # iterable (e.g. a generator) would be empty on the second pass.
    entries = list(table_data)
    best_stats = sorted(
        [d.evaluation.stats for d in entries], key=lambda r: r.score, reverse=True
    )[0]
    entries.sort(key=lambda d: d.evaluation.config.seed() or 0)
    rows = [
        EvaluationRow(
            config.algo,
            config.env_id,
            config.seed(),
            stats.score.mean,
            stats.score.std,
            len(stats),
            "*" if stats == best_stats else "",
            f"[wandb]({r.url})",
        )
        for (r, (_, stats, config)) in entries
    ]
    df = EvaluationRow.data_frame(rows)
    return df.to_markdown(index=False)
|
61 |
-
|
62 |
-
|
63 |
-
def github_project_link(github_url: str) -> str:
    """Return a Markdown link to ``github_url`` labelled with the URL's path."""
    return f"[{urlparse(github_url).path}]({github_url})"
|
65 |
-
|
66 |
-
|
67 |
-
def header_section(algo: str, env: str, github_url: str, wandb_report_url: str) -> str:
    """Return the model card header: title, repo attribution and report link.

    The three paragraphs are separated by blank lines; ``algo`` is upper-cased.
    """
    algo_caps = algo.upper()
    lines = [
        f"# **{algo_caps}** Agent playing **{env}**",
        f"This is a trained model of a **{algo_caps}** agent playing **{env}** using "
        f"the {github_project_link(github_url)} repo.",
        f"All models trained at this commit can be found at {wandb_report_url}.",
    ]
    return "\n\n".join(lines)
|
76 |
-
|
77 |
-
|
78 |
-
def github_tree_link(github_url: str, commit_hash: Optional[str]) -> str:
    """Return a Markdown link to the repo tree at ``commit_hash``.

    The link text is the first 7 characters of the hash. Without a hash this
    falls back to a plain project link.
    """
    if not commit_hash:
        return github_project_link(github_url)
    return f"[{commit_hash[:7]}]({github_url}/tree/{commit_hash})"
|
82 |
-
|
83 |
-
|
84 |
-
def results_section(
    table_data: List[EvalTableData], algo: str, github_url: str, commit_hash: str
) -> str:
    """Return the "Training Results" section: a summary paragraph followed by
    the evaluation table, separated by blank lines.
    """
    lines = [
        "## Training Results",
        f"This model was trained from {len(table_data)} trainings of **{algo.upper()}** "
        + "agents using different initial seeds. "
        + "These agents were trained by checking out "
        + f"{github_tree_link(github_url, commit_hash)}. "
        + "The best and last models were kept from each training. "
        + "This submission has loaded the best models from each training, reevaluates "
        + "them, and selects the best model from these latest evaluations (mean - std).",
    ]
    lines.append(evaluation_table(table_data))
    return "\n\n".join(lines)
|
100 |
-
|
101 |
-
|
102 |
-
def prerequisites_section() -> str:
    """Return the static "Prerequisites" Markdown section about Weights & Biases.

    The text starts and ends with a newline because it is a triple-quoted
    literal whose content begins on the line after the opening quotes.
    """
    return """
### Prerequisites: Weights & Biases (WandB)
Training and benchmarking assumes you have a Weights & Biases project to upload runs to.
By default training goes to a rl-algo-impls project while benchmarks go to
rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
models and the model weights are uploaded to WandB.

Before doing anything below, you'll need to create a wandb account and run `wandb
login`.
"""
|
113 |
-
|
114 |
-
|
115 |
-
def usage_section(github_url: str, run_path: str, commit_hash: str) -> str:
    """Build the "Usage" Markdown section.

    Args:
        github_url: Repository URL; shown both as its path and in full.
        run_path: Value substituted into the ``--wandb-run-path`` argument of
            the example ``enjoy.py`` command.
        commit_hash: Commit the agent was trained on, linked for readers who
            need to check out that exact revision.

    Returns:
        The section text, beginning and ending with a newline.
    """
    project_path = urlparse(github_url).path
    commit_link = github_tree_link(github_url, commit_hash)
    # Fixed the "hyperaparameters" typo that was emitted into the published text.
    return f"""
## Usage
{project_path}: {github_url}

Note: While the model state dictionary and hyperparameters are saved, the latest
implementation could be sufficiently different to not be able to reproduce similar
results. You might need to checkout the commit the agent was trained on:
{commit_link}.
```
# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
python enjoy.py --wandb-run-path={run_path}
```

Setup hasn't been completely worked out yet, so you might be best served by using Google
Colab starting from the
[colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
notebook.
"""
|
134 |
-
|
135 |
-
|
136 |
-
def training_setion(
    github_url: str, commit_hash: str, algo: str, env: str, seed: Optional[int]
) -> str:
    """Build the "Training" Markdown section with an example ``train.py`` command.

    The function name is misspelled ("setion"); it is kept as-is because
    callers reference it by this name.

    Args:
        github_url: Repository URL used for the commit link.
        commit_hash: Commit the agent was trained on.
        algo: Value for ``--algo`` in the example command.
        env: Value for ``--env`` in the example command.
        seed: Value for ``--seed``; the flag is omitted when ``None``.

    Returns:
        The section text, beginning and ending with a newline.
    """
    commit_link = github_tree_link(github_url, commit_hash)
    # When no seed is given the command line keeps the single space that
    # precedes the (empty) seed argument, matching the previously emitted text.
    seed_arg = "--seed " + str(seed) if seed is not None else ""
    return f"""
## Training
If you want the highest chance to reproduce these results, you'll want to checkout the
commit the agent was trained on: {commit_link}. While
training is deterministic, different hardware will give different results.

```
python train.py --algo {algo} --env {env} {seed_arg}
```

Setup hasn't been completely worked out yet, so you might be best served by using Google
Colab starting from the
[colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
notebook.
"""
|
154 |
-
|
155 |
-
|
156 |
-
def benchmarking_section(report_url: str) -> str:
    """Build the "Benchmarking" Markdown section.

    Args:
        report_url: URL inserted verbatim as the source of "this and other
            models".

    Returns:
        The section text, beginning and ending with a newline.
    """
    return f"""
## Benchmarking (with Lambda Labs instance)
This and other models from {report_url} were generated by running a script on a Lambda
Labs instance. In a Lambda Labs instance terminal:
```
git clone git@github.com:sgoodfriend/rl-algo-impls.git
cd rl-algo-impls
bash ./lambda_labs/setup.sh
wandb login
bash ./lambda_labs/benchmark.sh
```

### Alternative: Google Colab Pro+
As an alternative,
[colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
can be used. However, this requires a Google Colab Pro+ subscription and running across
4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
"""
|
175 |
-
|
176 |
-
|
177 |
-
def hyperparams_section(run_config: Dict[str, Any]) -> str:
    """Build the "Hyperparameters" Markdown section.

    Args:
        run_config: Run configuration mapping. Must contain an ``"algo"`` key
            (used to name the hyperparameter file); the whole mapping is
            dumped as YAML inside a fenced block.

    Returns:
        The section text, beginning and ending with a newline.

    Raises:
        KeyError: If ``run_config`` has no ``"algo"`` entry.
    """
    # Computed up front instead of inside a multi-line f-string expression.
    hyperparams_path = os.path.join("hyperparams", run_config["algo"] + ".yml")
    config_yaml = yaml.dump(run_config)
    return f"""
## Hyperparameters
This isn't exactly the format of hyperparams in {hyperparams_path}, but instead the Wandb Run Config. However, it's very
close and has some additional data:
```
{config_yaml}
```
"""
|
187 |
-
|
188 |
-
|
189 |
-
def model_card_text(
    algo: str,
    env: str,
    github_url: str,
    commit_hash: str,
    wandb_report_url: str,
    table_data: List[EvalTableData],
    best_eval: EvalTableData,
) -> str:
    """Assemble the complete Markdown document from its sections.

    Args:
        algo: Algorithm name.
        env: Environment name.
        github_url: Repository URL.
        commit_hash: Commit the agents were trained on.
        wandb_report_url: URL used in the header and benchmarking sections.
        table_data: All evaluated runs, rendered in the results table.
        best_eval: The selected run. It is unpacked as
            ``(run, (_, _, config))``; ``run.path`` is joined with ``/`` to
            form the run path, ``run.config`` feeds the hyperparameters
            section and ``config.seed()`` feeds the training command.

    Returns:
        All sections joined by blank lines, in order: header, results,
        prerequisites, usage, training, benchmarking, hyperparameters.
    """
    run, (_, _, config) = best_eval
    run_path = "/".join(run.path)
    sections = [
        header_section(algo, env, github_url, wandb_report_url),
        results_section(table_data, algo, github_url, commit_hash),
        prerequisites_section(),
        usage_section(github_url, run_path, commit_hash),
        training_setion(github_url, commit_hash, algo, env, config.seed()),
        benchmarking_section(wandb_report_url),
        hyperparams_section(run.config),
    ]
    return "\n\n".join(sections)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
replay.meta.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 
5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "320x240", "-pix_fmt", "rgb24", "-framerate", "60", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "60", "/tmp/
|
|
|
1 |
+
{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 
5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "320x240", "-pix_fmt", "rgb24", "-framerate", "60", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "60", "/tmp/tmppc8ru6js/a2c-AntBulletEnv-v0/replay.mp4"]}, "episode": {"r": 3004.433349609375, "l": 1000, "t": 28.542343}}
|
results.json
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
{"mean_reward": 2992.133482391329, "std_reward": 55.53945472831181, "is_deterministic": true, "n_eval_episodes": 10, "eval_datetime": "2023-01-24T20:22:24.991332"}
|
|
|
|
rl_algo_impls/benchmark_publish.py
CHANGED
@@ -54,8 +54,8 @@ def benchmark_publish() -> None:
|
|
54 |
"--virtual-display", action="store_true", help="Use headless virtual display"
|
55 |
)
|
56 |
# parser.set_defaults(
|
57 |
-
# wandb_tags=["
|
58 |
-
# wandb_report_url="https://api.wandb.ai/links/sgoodfriend/
|
59 |
# envs=[],
|
60 |
# exclude_envs=[],
|
61 |
# )
|
|
|
54 |
"--virtual-display", action="store_true", help="Use headless virtual display"
|
55 |
)
|
56 |
# parser.set_defaults(
|
57 |
+
# wandb_tags=["benchmark_2067e21", "host_155-248-199-228"],
|
58 |
+
# wandb_report_url="https://api.wandb.ai/links/sgoodfriend/09frjfcs",
|
59 |
# envs=[],
|
60 |
# exclude_envs=[],
|
61 |
# )
|
rl_algo_impls/huggingface_publish.py
CHANGED
@@ -162,6 +162,7 @@ def publish(
|
|
162 |
path_in_repo="",
|
163 |
commit_message=f"{algo.upper()} playing {env_id} from {github_url}/tree/{commit_hash}",
|
164 |
token=huggingface_token,
|
|
|
165 |
)
|
166 |
print(f"Pushed model to the hub: {repo_url}")
|
167 |
|
|
|
162 |
path_in_repo="",
|
163 |
commit_message=f"{algo.upper()} playing {env_id} from {github_url}/tree/{commit_hash}",
|
164 |
token=huggingface_token,
|
165 |
+
delete_patterns="*",
|
166 |
)
|
167 |
print(f"Pushed model to the hub: {repo_url}")
|
168 |
|
runner/config.py
DELETED
@@ -1,155 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
|
3 |
-
from datetime import datetime
|
4 |
-
from dataclasses import dataclass
|
5 |
-
from typing import Any, Dict, NamedTuple, Optional, TypedDict, Union
|
6 |
-
|
7 |
-
|
8 |
-
@dataclass
class RunArgs:
    """Arguments identifying a single run: algorithm, environment and seeding."""

    # Algorithm name.
    algo: str
    # Environment name as passed on the command line / run arguments.
    env: str
    # Optional base random seed; None means no explicit seed.
    seed: Optional[int] = None
    # Flag forwarded to the seeding helper; defaults to enabled.
    use_deterministic_algorithms: bool = True
|
14 |
-
|
15 |
-
|
16 |
-
class EnvHyperparams(NamedTuple):
    """Environment-construction settings.

    NOTE: field ORDER is part of the contract. The environment factory
    functions unpack an instance positionally into 14 local names, so fields
    must not be reordered, and any added field requires updating those
    unpacking sites.
    """

    env_type: str = "gymvec"
    n_envs: int = 1
    frame_stack: int = 1
    make_kwargs: Optional[Dict[str, Any]] = None
    no_reward_timeout_steps: Optional[int] = None
    no_reward_fire_steps: Optional[int] = None
    vec_env_class: str = "sync"
    normalize: bool = False
    normalize_kwargs: Optional[Dict[str, Any]] = None
    rolling_length: int = 100
    train_record_video: bool = False
    video_step_interval: Union[int, float] = 1_000_000
    initial_steps_to_truncate: Optional[int] = None
    clip_atari_rewards: bool = True
|
31 |
-
|
32 |
-
|
33 |
-
class Hyperparams(TypedDict, total=False):
    """Shape of a hyperparameter mapping; every key is optional (``total=False``)."""

    device: str
    n_timesteps: Union[int, float]
    # Optional override of the environment id; Config.env_id and
    # Config.model_name read this key, so it is declared here as well.
    env_id: str
    env_hyperparams: Dict[str, Any]
    policy_hyperparams: Dict[str, Any]
    algo_hyperparams: Dict[str, Any]
    eval_params: Dict[str, Any]
|
40 |
-
|
41 |
-
|
42 |
-
@dataclass
class Config:
    """Run arguments plus hyperparameters, with derived names and paths.

    All directory/file properties are built with ``os.path.join`` relative to
    ``root_dir``.
    """

    args: RunArgs
    hyperparams: Hyperparams
    root_dir: str
    # NOTE: this default is evaluated once, when the class body executes, so
    # every Config created without an explicit run_id in the same process
    # shares the same timestamp.
    run_id: str = datetime.now().isoformat()

    def seed(self, training: bool = True) -> Optional[int]:
        """Return the seed to use.

        For training (or when no seed is set) this is ``args.seed`` unchanged.
        Otherwise the seed is offset by ``n_envs`` from ``env_hyperparams``
        (default 1).
        """
        seed = self.args.seed
        if training or seed is None:
            return seed
        return seed + self.env_hyperparams.get("n_envs", 1)

    @property
    def device(self) -> str:
        """Configured device string, ``"auto"`` when unset."""
        return self.hyperparams.get("device", "auto")

    @property
    def n_timesteps(self) -> int:
        """Configured number of timesteps as an int (default 100_000)."""
        return int(self.hyperparams.get("n_timesteps", 100_000))

    @property
    def env_hyperparams(self) -> Dict[str, Any]:
        """The ``env_hyperparams`` mapping, or an empty dict when absent."""
        return self.hyperparams.get("env_hyperparams", {})

    @property
    def policy_hyperparams(self) -> Dict[str, Any]:
        """The ``policy_hyperparams`` mapping, or an empty dict when absent."""
        return self.hyperparams.get("policy_hyperparams", {})

    @property
    def algo_hyperparams(self) -> Dict[str, Any]:
        """The ``algo_hyperparams`` mapping, or an empty dict when absent."""
        return self.hyperparams.get("algo_hyperparams", {})

    @property
    def eval_params(self) -> Dict[str, Any]:
        """The ``eval_params`` mapping, or an empty dict when absent."""
        return self.hyperparams.get("eval_params", {})

    @property
    def algo(self) -> str:
        """Algorithm name from the run arguments."""
        return self.args.algo

    @property
    def env_id(self) -> str:
        """Environment id: the ``env_id`` hyperparameter if truthy, else ``args.env``."""
        return self.hyperparams.get("env_id") or self.args.env

    def model_name(self, include_seed: bool = True) -> str:
        """Build a dash-separated model name.

        The name starts with the algorithm and the *argument* environment
        name, optionally followed by ``S<seed>``. When no ``env_id`` override
        is configured, each ``make_kwargs`` value contributes a part: the key
        for a true bool, ``<key><value>`` for a non-zero int, and
        ``str(value)`` for anything else (including ``False`` and ``0``).
        """
        # Use arg env name instead of environment name
        parts = [self.algo, self.args.env]
        if include_seed and self.args.seed is not None:
            parts.append(f"S{self.args.seed}")

        # Assume that the custom arg name already has the necessary information
        if not self.hyperparams.get("env_id"):
            make_kwargs = self.env_hyperparams.get("make_kwargs", {})
            if make_kwargs:
                for k, v in make_kwargs.items():
                    if isinstance(v, bool) and v:
                        parts.append(k)
                    # bool is a subclass of int, so exclude it explicitly.
                    elif isinstance(v, int) and not isinstance(v, bool) and v:
                        parts.append(f"{k}{v}")
                    else:
                        parts.append(str(v))

        return "-".join(parts)

    @property
    def run_name(self) -> str:
        """``model_name()`` and ``run_id`` joined with a dash."""
        parts = [self.model_name(), self.run_id]
        return "-".join(parts)

    @property
    def saved_models_dir(self) -> str:
        """``<root_dir>/saved_models``."""
        return os.path.join(self.root_dir, "saved_models")

    @property
    def downloaded_models_dir(self) -> str:
        """``<root_dir>/downloaded_models``."""
        return os.path.join(self.root_dir, "downloaded_models")

    def model_dir_name(
        self,
        best: bool = False,
        extension: str = "",
    ) -> str:
        """Model directory (or archive) name: model name, optional ``-best``, then ``extension``."""
        return self.model_name() + ("-best" if best else "") + extension

    def model_dir_path(self, best: bool = False, downloaded: bool = False) -> str:
        """Path of the model directory under the saved or downloaded models dir."""
        return os.path.join(
            self.saved_models_dir if not downloaded else self.downloaded_models_dir,
            self.model_dir_name(best=best),
        )

    @property
    def runs_dir(self) -> str:
        """``<root_dir>/runs``."""
        return os.path.join(self.root_dir, "runs")

    @property
    def tensorboard_summary_path(self) -> str:
        """``<runs_dir>/<run_name>``."""
        return os.path.join(self.runs_dir, self.run_name)

    @property
    def logs_path(self) -> str:
        """``<runs_dir>/log.yml``."""
        return os.path.join(self.runs_dir, "log.yml")

    @property
    def videos_dir(self) -> str:
        """``<root_dir>/videos``."""
        return os.path.join(self.root_dir, "videos")

    @property
    def video_prefix(self) -> str:
        """``<videos_dir>/<model_name>``."""
        return os.path.join(self.videos_dir, self.model_name())

    @property
    def best_videos_dir(self) -> str:
        """``<videos_dir>/<model_name>-best``."""
        return os.path.join(self.videos_dir, f"{self.model_name()}-best")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
runner/env.py
DELETED
@@ -1,284 +0,0 @@
|
|
1 |
-
import gym
|
2 |
-
import numpy as np
|
3 |
-
import os
|
4 |
-
|
5 |
-
from gym.vector.async_vector_env import AsyncVectorEnv
|
6 |
-
from gym.vector.sync_vector_env import SyncVectorEnv
|
7 |
-
from gym.wrappers.resize_observation import ResizeObservation
|
8 |
-
from gym.wrappers.gray_scale_observation import GrayScaleObservation
|
9 |
-
from gym.wrappers.frame_stack import FrameStack
|
10 |
-
from stable_baselines3.common.atari_wrappers import (
|
11 |
-
MaxAndSkipEnv,
|
12 |
-
NoopResetEnv,
|
13 |
-
)
|
14 |
-
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
|
15 |
-
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
|
16 |
-
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
|
17 |
-
from torch.utils.tensorboard.writer import SummaryWriter
|
18 |
-
from typing import Callable, Optional
|
19 |
-
|
20 |
-
from runner.config import Config, EnvHyperparams
|
21 |
-
from shared.policy.policy import VEC_NORMALIZE_FILENAME
|
22 |
-
from wrappers.atari_wrappers import EpisodicLifeEnv, FireOnLifeStarttEnv, ClipRewardEnv
|
23 |
-
from wrappers.episode_record_video import EpisodeRecordVideo
|
24 |
-
from wrappers.episode_stats_writer import EpisodeStatsWriter
|
25 |
-
from wrappers.initial_step_truncate_wrapper import InitialStepTruncateWrapper
|
26 |
-
from wrappers.is_vector_env import IsVectorEnv
|
27 |
-
from wrappers.noop_env_seed import NoopEnvSeed
|
28 |
-
from wrappers.normalize import NormalizeObservation, NormalizeReward
|
29 |
-
from wrappers.sync_vector_env_render_compat import SyncVectorEnvRenderCompat
|
30 |
-
from wrappers.transpose_image_observation import TransposeImageObservation
|
31 |
-
from wrappers.vectorable_wrapper import VecEnv
|
32 |
-
from wrappers.video_compat_wrapper import VideoCompatWrapper
|
33 |
-
|
34 |
-
|
35 |
-
def make_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Create a vectorized environment according to ``hparams.env_type``.

    ``"procgen"`` dispatches to ``_make_procgen_env``; ``"sb3vec"`` and
    ``"gymvec"`` dispatch to ``_make_vec_env``. All remaining arguments are
    forwarded unchanged.

    Raises:
        ValueError: If ``hparams.env_type`` is not one of the supported values.
    """
    if hparams.env_type == "procgen":
        factory = _make_procgen_env
    elif hparams.env_type in {"sb3vec", "gymvec"}:
        factory = _make_vec_env
    else:
        raise ValueError(f"env_type {hparams.env_type} not supported")
    return factory(
        config,
        hparams,
        training=training,
        render=render,
        normalize_load_path=normalize_load_path,
        tb_writer=tb_writer,
    )
|
63 |
-
|
64 |
-
|
65 |
-
def make_eval_env(
    config: Config,
    hparams: EnvHyperparams,
    override_n_envs: Optional[int] = None,
    **kwargs,
) -> VecEnv:
    """Create an environment for evaluation via ``make_env``.

    ``training`` is always forced to ``False``, overriding any value passed in
    ``kwargs``.

    Args:
        config: Run configuration forwarded to ``make_env``.
        hparams: Environment settings; not modified (a new tuple is built when
            overrides apply).
        override_n_envs: When given, replaces ``hparams.n_envs``. A value of 1
            additionally forces ``vec_env_class`` to ``"sync"``.
        **kwargs: Forwarded to ``make_env``.
    """
    if override_n_envs is not None:
        overrides = {"n_envs": override_n_envs}
        if override_n_envs == 1:
            overrides["vec_env_class"] = "sync"
        # NamedTuple._replace returns a new instance with the given fields swapped.
        hparams = hparams._replace(**overrides)
    # **kwargs is already a fresh dict per call, so it can be updated directly.
    kwargs["training"] = False
    return make_env(config, hparams, **kwargs)
|
80 |
-
|
81 |
-
|
82 |
-
def _make_vec_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build a vectorized environment for the ``"sb3vec"`` / ``"gymvec"`` env types.

    Each of the ``n_envs`` sub-environments is created with ``gym.make`` and
    wrapped in a fixed order (episode statistics, video compatibility,
    optional recording/truncation, then environment-family specific
    observation wrappers). The vector env is then optionally wrapped for
    stats writing and normalization.

    Args:
        config: Supplies ``env_id``, the seed and the video prefix.
        hparams: Environment settings, unpacked positionally below.
        training: Enables training-only wrappers; also passed to wrappers that
            accept a ``training`` flag.
        render: Requests rendering where the environment family supports it.
        normalize_load_path: Directory from which saved ``VecNormalize``
            statistics are loaded (only used for ``"sb3vec"``).
        tb_writer: Summary writer; required (asserted) when ``training``.

    Raises:
        ValueError: If ``env_type`` is neither ``"sb3vec"`` nor ``"gymvec"``.
        KeyError: If ``vec_env_class`` is neither ``"sync"`` nor ``"async"``.
    """
    # Positional unpacking: must stay in sync with EnvHyperparams field order.
    (
        env_type,
        n_envs,
        frame_stack,
        make_kwargs,
        no_reward_timeout_steps,
        no_reward_fire_steps,
        vec_env_class,
        normalize,
        normalize_kwargs,
        rolling_length,
        train_record_video,
        video_step_interval,
        initial_steps_to_truncate,
        clip_atari_rewards,
    ) = hparams

    if "BulletEnv" in config.env_id:
        # Imported for its side effects only; presumably this registers the
        # Bullet environments with gym — TODO confirm.
        import pybullet_envs

    spec = gym.spec(config.env_id)
    seed = config.seed(training=training)

    # Work on a copy so the caller's make_kwargs dict is never mutated.
    make_kwargs = make_kwargs.copy() if make_kwargs is not None else {}
    if "BulletEnv" in config.env_id and render:
        make_kwargs["render"] = True
    if "CarRacing" in config.env_id:
        make_kwargs["verbose"] = 0
    if "procgen" in config.env_id:
        if not render:
            make_kwargs["render_mode"] = "rgb_array"

    def make(idx: int) -> Callable[[], gym.Env]:
        """Return a thunk that builds the ``idx``-th wrapped sub-environment."""

        def _make() -> gym.Env:
            env = gym.make(config.env_id, **make_kwargs)
            env = gym.wrappers.RecordEpisodeStatistics(env)
            env = VideoCompatWrapper(env)
            # Only the first sub-environment records training videos.
            if training and train_record_video and idx == 0:
                env = EpisodeRecordVideo(
                    env,
                    config.video_prefix,
                    step_increment=n_envs,
                    video_step_interval=int(video_step_interval),
                )
            if training and initial_steps_to_truncate:
                # Each sub-environment gets a different truncation step,
                # spread evenly over [0, initial_steps_to_truncate).
                env = InitialStepTruncateWrapper(
                    env, idx * initial_steps_to_truncate // n_envs
                )
            if "AtariEnv" in spec.entry_point:  # type: ignore
                env = NoopResetEnv(env, noop_max=30)
                env = MaxAndSkipEnv(env, skip=4)
                env = EpisodicLifeEnv(env, training=training)
                action_meanings = env.unwrapped.get_action_meanings()
                if "FIRE" in action_meanings:  # type: ignore
                    env = FireOnLifeStarttEnv(env, action_meanings.index("FIRE"))
                if clip_atari_rewards:
                    env = ClipRewardEnv(env, training=training)
                env = ResizeObservation(env, (84, 84))
                env = GrayScaleObservation(env, keep_dim=False)
                env = FrameStack(env, frame_stack)
            elif "CarRacing" in config.env_id:
                env = ResizeObservation(env, (64, 64))
                env = GrayScaleObservation(env, keep_dim=False)
                env = FrameStack(env, frame_stack)
            elif "procgen" in config.env_id:
                # env = GrayScaleObservation(env, keep_dim=False)
                env = NoopEnvSeed(env)
                env = TransposeImageObservation(env)
                if frame_stack > 1:
                    env = FrameStack(env, frame_stack)

            if no_reward_timeout_steps:
                from wrappers.no_reward_timeout import NoRewardTimeout

                env = NoRewardTimeout(
                    env, no_reward_timeout_steps, n_fire_steps=no_reward_fire_steps
                )

            if seed is not None:
                # Offset by idx so each sub-environment is seeded differently.
                env.seed(seed + idx)
                env.action_space.seed(seed + idx)
                env.observation_space.seed(seed + idx)

            return env

        return _make

    if env_type == "sb3vec":
        VecEnvClass = {"sync": DummyVecEnv, "async": SubprocVecEnv}[vec_env_class]
    elif env_type == "gymvec":
        VecEnvClass = {"sync": SyncVectorEnv, "async": AsyncVectorEnv}[vec_env_class]
    else:
        raise ValueError(f"env_type {env_type} unsupported")
    envs = VecEnvClass([make(i) for i in range(n_envs)])
    if env_type == "gymvec" and vec_env_class == "sync":
        envs = SyncVectorEnvRenderCompat(envs)
    if training:
        assert tb_writer
        envs = EpisodeStatsWriter(
            envs, tb_writer, training=training, rolling_length=rolling_length
        )
    if normalize:
        normalize_kwargs = normalize_kwargs or {}
        if env_type == "sb3vec":
            if normalize_load_path:
                envs = VecNormalize.load(
                    os.path.join(normalize_load_path, VEC_NORMALIZE_FILENAME),
                    envs,  # type: ignore
                )
            else:
                envs = VecNormalize(
                    envs,  # type: ignore
                    training=training,
                    **normalize_kwargs,
                )
            if not training:
                envs.norm_reward = False
        else:
            if normalize_kwargs.get("norm_obs", True):
                envs = NormalizeObservation(
                    envs, training=training, clip=normalize_kwargs.get("clip_obs", 10.0)
                )
            # Reward normalization is applied during training only.
            if training and normalize_kwargs.get("norm_reward", True):
                envs = NormalizeReward(
                    envs,
                    training=training,
                    clip=normalize_kwargs.get("clip_reward", 10.0),
                )
    return envs
|
219 |
-
|
220 |
-
|
221 |
-
def _make_procgen_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build a vectorized environment for the ``"procgen"`` env type.

    Args:
        config: Supplies ``env_id`` and the seed.
        hparams: Environment settings, unpacked positionally below; several
            fields are not used by this factory.
        training: Enables the stats writer and (with ``normalize``) reward
            normalization.
        render: When true, the env is wrapped in gym3's ``ViewerWrapper``.
        normalize_load_path: Accepted for signature parity with the other
            factory; not used here.
        tb_writer: Summary writer; required (asserted) when ``training``.
    """
    from gym3 import ViewerWrapper, ExtractDictObWrapper
    from procgen.env import ProcgenGym3Env, ToBaselinesVecEnv

    # Positional unpacking: must stay in sync with EnvHyperparams field order.
    (
        _,  # env_type
        n_envs,
        _,  # frame_stack (not handled yet; see TODO below)
        make_kwargs,
        _,  # no_reward_timeout_steps
        _,  # no_reward_fire_steps
        _,  # vec_env_class
        normalize,
        normalize_kwargs,
        rolling_length,
        _,  # train_record_video
        _,  # video_step_interval
        _,  # initial_steps_to_truncate
        _,  # clip_atari_rewards
    ) = hparams

    seed = config.seed(training=training)

    # Copy before adding keys so the dict held by hparams (and whatever shares
    # it) is not mutated with render_mode / rand_seed.
    make_kwargs = dict(make_kwargs) if make_kwargs else {}
    make_kwargs["render_mode"] = "rgb_array"
    if seed is not None:
        make_kwargs["rand_seed"] = seed

    envs = ProcgenGym3Env(n_envs, config.env_id, **make_kwargs)
    envs = ExtractDictObWrapper(envs, key="rgb")
    if render:
        envs = ViewerWrapper(envs, info_key="rgb")
    envs = ToBaselinesVecEnv(envs)
    envs = IsVectorEnv(envs)
    # TODO: Handle Grayscale and/or FrameStack
    envs = TransposeImageObservation(envs)

    envs = gym.wrappers.RecordEpisodeStatistics(envs)

    if seed is not None:
        envs.action_space.seed(seed)
        envs.observation_space.seed(seed)

    if training:
        assert tb_writer
        envs = EpisodeStatsWriter(
            envs, tb_writer, training=training, rolling_length=rolling_length
        )
    if normalize and training:
        normalize_kwargs = normalize_kwargs or {}
        envs = gym.wrappers.NormalizeReward(envs)
        # Symmetric bound applied to the normalized reward.
        clip_reward = normalize_kwargs.get("clip_reward", 10.0)
        envs = gym.wrappers.TransformReward(
            envs, lambda r: np.clip(r, -clip_reward, clip_reward)
        )

    return envs  # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
runner/evaluate.py
DELETED
@@ -1,103 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import shutil
|
3 |
-
|
4 |
-
from dataclasses import dataclass
|
5 |
-
from typing import NamedTuple, Optional
|
6 |
-
|
7 |
-
from runner.env import make_eval_env
|
8 |
-
from runner.config import Config, EnvHyperparams, RunArgs
|
9 |
-
from runner.running_utils import (
|
10 |
-
load_hyperparams,
|
11 |
-
set_seeds,
|
12 |
-
get_device,
|
13 |
-
make_policy,
|
14 |
-
)
|
15 |
-
from shared.callbacks.eval_callback import evaluate
|
16 |
-
from shared.policy.policy import Policy
|
17 |
-
from shared.stats import EpisodesStats
|
18 |
-
|
19 |
-
|
20 |
-
@dataclass
class EvalArgs(RunArgs):
    """Run arguments extended with evaluation options."""

    # Whether to render while evaluating.
    render: bool = True
    # Whether to load the "best" model variant.
    best: bool = True
    # Number of environments to evaluate with; passed as override_n_envs.
    n_envs: Optional[int] = 1
    # Number of episodes to evaluate.
    n_episodes: int = 3
    # Explicit deterministic flag; None defers to the configured eval params.
    deterministic_eval: Optional[bool] = None
    # When set, returns are not printed during evaluation.
    no_print_returns: bool = False
    # Optional wandb run path to download the model and config from.
    wandb_run_path: Optional[str] = None
|
29 |
-
|
30 |
-
|
31 |
-
class Evaluation(NamedTuple):
    """Result of evaluating a model: the policy, its episode stats and the config used."""

    policy: Policy
    stats: EpisodesStats
    config: Config
|
35 |
-
|
36 |
-
|
37 |
-
def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
    """Load a saved policy and evaluate it for ``args.n_episodes`` episodes.

    When ``args.wandb_run_path`` is set, the run's config is used as the
    hyperparameters (also overwriting ``args.algo``, ``args.env``,
    ``args.seed`` and ``args.use_deterministic_algorithms``) and the model
    archive is downloaded from that run and unpacked into the model
    directory. Otherwise hyperparameters come from the local hyperparameter
    files under ``root_dir``.

    Args:
        args: Evaluation options; mutated in the W&B branch as described.
        root_dir: Root directory handed to ``Config`` and ``load_hyperparams``.

    Returns:
        An ``Evaluation`` of (policy, episode statistics, config).
    """
    if args.wandb_run_path:
        import wandb

        api = wandb.Api()
        run = api.run(args.wandb_run_path)
        hyperparams = run.config

        args.algo = hyperparams["algo"]
        args.env = hyperparams["env"]
        args.seed = hyperparams.get("seed", None)
        args.use_deterministic_algorithms = hyperparams.get(
            "use_deterministic_algorithms", True
        )

        config = Config(args, hyperparams, root_dir)
        model_path = config.model_dir_path(best=args.best, downloaded=True)

        model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
        # replace=True: an archive left behind by an interrupted earlier run
        # must not make this download fail.
        run.file(model_archive_name).download(replace=True)
        try:
            if os.path.isdir(model_path):
                shutil.rmtree(model_path)
            shutil.unpack_archive(model_archive_name, model_path)
        finally:
            # Always drop the downloaded archive, even if unpacking failed.
            os.remove(model_archive_name)
    else:
        hyperparams = load_hyperparams(args.algo, args.env, root_dir)

        config = Config(args, hyperparams, root_dir)
        model_path = config.model_dir_path(best=args.best)

    print(args)

    set_seeds(args.seed, args.use_deterministic_algorithms)

    env = make_eval_env(
        config,
        EnvHyperparams(**config.env_hyperparams),
        override_n_envs=args.n_envs,
        render=args.render,
        normalize_load_path=model_path,
    )
    device = get_device(config.device, env)
    policy = make_policy(
        args.algo,
        env,
        device,
        load_path=model_path,
        **config.policy_hyperparams,
    ).eval()

    # An explicit command-line choice wins; otherwise use the configured default.
    deterministic = (
        args.deterministic_eval
        if args.deterministic_eval is not None
        else config.eval_params.get("deterministic", True)
    )
    return Evaluation(
        policy,
        evaluate(
            env,
            policy,
            args.n_episodes,
            render=args.render,
            deterministic=deterministic,
            print_returns=not args.no_print_returns,
        ),
        config,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
runner/running_utils.py
DELETED
@@ -1,195 +0,0 @@
|
|
1 |
-
import argparse
|
2 |
-
import gym
|
3 |
-
import json
|
4 |
-
import matplotlib.pyplot as plt
|
5 |
-
import numpy as np
|
6 |
-
import os
|
7 |
-
import random
|
8 |
-
import torch
|
9 |
-
import torch.backends.cudnn
|
10 |
-
import yaml
|
11 |
-
|
12 |
-
from gym.spaces import Box, Discrete
|
13 |
-
from torch.utils.tensorboard.writer import SummaryWriter
|
14 |
-
from typing import Dict, Optional, Type, Union
|
15 |
-
|
16 |
-
from runner.config import Hyperparams
|
17 |
-
from shared.algorithm import Algorithm
|
18 |
-
from shared.callbacks.eval_callback import EvalCallback
|
19 |
-
from shared.policy.on_policy import ActorCritic
|
20 |
-
from shared.policy.policy import Policy
|
21 |
-
|
22 |
-
from a2c.a2c import A2C
|
23 |
-
from dqn.dqn import DQN
|
24 |
-
from dqn.policy import DQNPolicy
|
25 |
-
from ppo.ppo import PPO
|
26 |
-
from vpg.vpg import VanillaPolicyGradient
|
27 |
-
from vpg.policy import VPGActorCritic
|
28 |
-
from wrappers.vectorable_wrapper import VecEnv, single_observation_space
|
29 |
-
|
30 |
-
# Algorithm abbreviation -> training algorithm class.
ALGOS: Dict[str, Type[Algorithm]] = {
    "dqn": DQN,
    "vpg": VanillaPolicyGradient,
    "ppo": PPO,
    "a2c": A2C,
}

# Algorithm abbreviation -> policy class used with that algorithm.
POLICIES: Dict[str, Type[Policy]] = {
    "dqn": DQNPolicy,
    "vpg": VPGActorCritic,
    "ppo": ActorCritic,
    "a2c": ActorCritic,
}

# Directory (relative to the root path) holding the per-algorithm YAML files.
HYPERPARAMS_PATH = "hyperparams"
|
44 |
-
|
45 |
-
|
46 |
-
def _parse_bool(value: str) -> bool:
    """Convert a command-line string to a bool.

    ``type=bool`` treats every non-empty string (including "False") as True,
    so the flag value is parsed explicitly instead. The empty string still
    maps to False, as it did under ``bool("")``.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "on", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def base_parser(multiple: bool = True) -> argparse.ArgumentParser:
    """Build the argument parser shared by the runner entry points.

    Args:
        multiple: When True, ``--algo``/``--env`` accept one or more values and
            ``--seed`` zero or more; when False, ``--algo``/``--env`` take
            exactly one value (argparse still stores it in a list because
            ``nargs=1``) and ``--seed`` is an optional single value.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--algo",
        default=["dqn"],
        type=str,
        choices=list(ALGOS.keys()),
        nargs="+" if multiple else 1,
        help="Abbreviation(s) of algorithm(s)",
    )
    parser.add_argument(
        "--env",
        default=["CartPole-v1"],
        type=str,
        nargs="+" if multiple else 1,
        help="Name of environment(s) in gym",
    )
    parser.add_argument(
        "--seed",
        default=[1],
        type=int,
        nargs="*" if multiple else "?",
        help="Seeds to run experiment. Unset will do one run with no set seed",
    )
    parser.add_argument(
        "--use-deterministic-algorithms",
        default=True,
        # Explicit converter so that "False"/"0"/"no" actually disable the flag.
        type=_parse_bool,
        help="If seed set, set torch.use_deterministic_algorithms",
    )
    return parser
|
77 |
-
|
78 |
-
|
79 |
-
def load_hyperparams(algo: str, env_id: str, root_path: str) -> Hyperparams:
    """Return the hyperparameters for ``env_id`` from ``<algo>.yml``.

    The file is read from ``root_path/HYPERPARAMS_PATH``. An entry keyed by
    the exact environment id wins; otherwise, if the environment's gym spec
    has an entry point mentioning "AtariEnv" and the file has an "_atari"
    section, that section is returned.

    Raises:
        ValueError: If neither an exact entry nor the Atari fallback applies.
    """
    yml_path = os.path.join(root_path, HYPERPARAMS_PATH, f"{algo}.yml")
    with open(yml_path, "r") as stream:
        by_env = yaml.safe_load(stream)

    if env_id in by_env:
        return by_env[env_id]

    if "BulletEnv" in env_id:
        # Imported only for its side effects before the spec lookup
        # (presumably it registers the Bullet environments with gym).
        import pybullet_envs
    entry_point = str(gym.spec(env_id).entry_point)
    if "AtariEnv" not in entry_point or "_atari" not in by_env:
        raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
    return by_env["_atari"]
|
94 |
-
|
95 |
-
|
96 |
-
def get_device(device: str, env: VecEnv) -> torch.device:
    """Resolve the requested device name to an available ``torch.device``.

    "auto" means CUDA. CUDA falls back to Apple MPS when unavailable, and MPS
    falls back to CPU when unavailable. Even when MPS is available, CPU is
    chosen for Discrete observation spaces and for 1-D Box spaces. The chosen
    device is printed before it is returned.
    """
    chosen = "cuda" if device == "auto" else device
    if chosen == "cuda" and not torch.cuda.is_available():
        # Apple MPS is the second choice.
        chosen = "mps"
    if chosen == "mps":
        if not torch.backends.mps.is_available():
            chosen = "cpu"
        else:
            # Simple observation spaces are served by the CPU instead.
            space = single_observation_space(env)
            if isinstance(space, Discrete) or (
                isinstance(space, Box) and len(space.shape) == 1
            ):
                chosen = "cpu"
    print(f"Device: {chosen}")
    return torch.device(chosen)
|
116 |
-
|
117 |
-
|
118 |
-
def set_seeds(seed: Optional[int], use_deterministic_algorithms: bool) -> None:
    """Seed Python, NumPy and PyTorch RNGs and configure determinism.

    Does nothing when ``seed`` is None. Otherwise seeds the three RNGs,
    disables cuDNN benchmarking, applies ``use_deterministic_algorithms`` to
    torch, and sets the ``CUBLAS_WORKSPACE_CONFIG`` and
    ``TF_ENABLE_ONEDNN_OPTS`` environment variables.
    """
    if seed is not None:
        for seeder in (random.seed, np.random.seed, torch.manual_seed):
            seeder(seed)
        torch.backends.cudnn.benchmark = False
        torch.use_deterministic_algorithms(use_deterministic_algorithms)
        os.environ.update(
            {
                "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
                # Silences a warning; would also add stochasticity under TF.
                "TF_ENABLE_ONEDNN_OPTS": "0",
            }
        )
|
129 |
-
|
130 |
-
|
131 |
-
def make_policy(
    algo: str,
    env: VecEnv,
    device: torch.device,
    load_path: Optional[str] = None,
    **kwargs,
) -> Policy:
    """Instantiate the policy class registered for ``algo`` on ``device``.

    Extra keyword arguments are forwarded to the policy constructor. When
    ``load_path`` is given (and non-empty) the policy's ``load`` method is
    called with it before returning.
    """
    policy_cls = POLICIES[algo]
    policy = policy_cls(env, **kwargs).to(device)
    if not load_path:
        return policy
    policy.load(load_path)
    return policy
|
142 |
-
|
143 |
-
|
144 |
-
def plot_eval_callback(callback: EvalCallback, tb_writer: SummaryWriter, run_name: str):
    """Plot the callback's evaluation history and add it to TensorBoard.

    Draws the mean score, the mean-minus-std score and the min/max range
    against cumulative steps (evaluation index + 1, times ``step_freq``), and
    writes the figure under the "eval" tag.
    """
    figure = plt.figure()
    history = callback.stats
    steps = [callback.step_freq * n for n in range(1, len(history) + 1)]

    means = [s.score.mean for s in history]
    plt.plot(steps, means, "b-", label="mean")

    lower = [s.score.mean - s.score.std for s in history]
    plt.plot(steps, lower, "g--", label="mean-std")

    plt.fill_between(
        steps,
        [s.score.min for s in history],  # type: ignore
        [s.score.max for s in history],  # type: ignore
        facecolor="cyan",
        label="range",
    )
    plt.xlabel("Steps")
    plt.ylabel("Score")
    plt.legend()
    plt.title(f"Eval {run_name}")
    tb_writer.add_figure("eval", figure)
|
173 |
-
|
174 |
-
|
175 |
-
Scalar = Union[bool, str, float, int, None]
|
176 |
-
|
177 |
-
|
178 |
-
def hparam_dict(
    hyperparams: Hyperparams, args: Dict[str, Union[Scalar, list]]
) -> Dict[str, Scalar]:
    """Flatten run arguments and hyperparameters into one scalar-valued dict.

    List-valued arguments are JSON-encoded. One level of nested hyperparameter
    dicts is flattened into "outer/inner" keys, with dict or list leaves turned
    into their ``str`` form. Hyperparameter entries overwrite argument entries
    that share a key.
    """
    flat = {
        name: json.dumps(value) if isinstance(value, list) else value
        for name, value in args.items()
    }
    for name, value in hyperparams.items():
        if not isinstance(value, dict):
            flat[name] = value  # type: ignore
            continue
        for sub_name, sub_value in value.items():
            nested = isinstance(sub_value, (dict, list))
            flat[f"{name}/{sub_name}"] = str(sub_value) if nested else sub_value
    return flat  # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
runner/train.py
DELETED
@@ -1,141 +0,0 @@
|
|
1 |
-
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
|
2 |
-
import os
|
3 |
-
|
4 |
-
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
5 |
-
|
6 |
-
import dataclasses
|
7 |
-
import shutil
|
8 |
-
import wandb
|
9 |
-
import yaml
|
10 |
-
|
11 |
-
from dataclasses import dataclass
|
12 |
-
from torch.utils.tensorboard.writer import SummaryWriter
|
13 |
-
from typing import Any, Dict, Optional, Sequence
|
14 |
-
|
15 |
-
from shared.callbacks.eval_callback import EvalCallback
|
16 |
-
from runner.config import Config, EnvHyperparams, RunArgs
|
17 |
-
from runner.env import make_env, make_eval_env
|
18 |
-
from runner.running_utils import (
|
19 |
-
ALGOS,
|
20 |
-
load_hyperparams,
|
21 |
-
set_seeds,
|
22 |
-
get_device,
|
23 |
-
make_policy,
|
24 |
-
plot_eval_callback,
|
25 |
-
hparam_dict,
|
26 |
-
)
|
27 |
-
from shared.stats import EpisodesStats
|
28 |
-
|
29 |
-
|
30 |
-
@dataclass
class TrainArgs(RunArgs):
    """Training options, extending the shared ``RunArgs`` with W&B settings."""

    # Weights & Biases logging is enabled in ``train`` only when this is set.
    wandb_project_name: Optional[str] = None
    # Passed to ``wandb.init`` as the entity.
    wandb_entity: Optional[str] = None
    # Passed to ``wandb.init`` as the run's tags.
    wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)
|
35 |
-
|
36 |
-
|
37 |
-
def train(args: TrainArgs):
    """Run one full training session described by ``args``.

    Loads hyperparameters relative to the current working directory, sets up
    optional Weights & Biases logging and a TensorBoard writer, builds the
    environment, policy and algorithm, trains with a periodic evaluation
    callback, saves the final policy, runs a final 10-episode evaluation, then
    appends the results to the YAML log and writes hparams to TensorBoard.
    With W&B enabled, the model directories are also zipped into the run
    directory before the run is finished.
    """
    print(args)
    hyperparams = load_hyperparams(args.algo, args.env, os.getcwd())
    print(hyperparams)
    config = Config(args, hyperparams, os.getcwd())

    # --- logging -------------------------------------------------------------
    use_wandb = bool(args.wandb_project_name)
    if use_wandb:
        wandb.tensorboard.patch(
            root_logdir=config.tensorboard_summary_path, pytorch=True
        )
        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            config=hyperparams,  # type: ignore
            name=config.run_name,
            monitor_gym=True,
            save_code=True,
            tags=args.wandb_tags,
        )
        wandb.config.update(args)

    tb_writer = SummaryWriter(config.tensorboard_summary_path)

    set_seeds(args.seed, args.use_deterministic_algorithms)

    # --- environment, policy, algorithm ----------------------------------------
    env = make_env(
        config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
    )
    device = get_device(config.device, env)
    policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
    algo_cls = ALGOS[args.algo]
    algo = algo_cls(policy, env, device, tb_writer, **config.algo_hyperparams)

    num_parameters = policy.num_parameters()
    num_trainable_parameters = policy.num_trainable_parameters()
    if not use_wandb:
        print(
            f"num_parameters = {num_parameters} ; "
            f"num_trainable_parameters = {num_trainable_parameters}"
        )
    else:
        wandb.run.summary["num_parameters"] = num_parameters
        wandb.run.summary["num_trainable_parameters"] = num_trainable_parameters

    # --- evaluation callback ---------------------------------------------------
    eval_env = make_eval_env(config, EnvHyperparams(**config.env_hyperparams))
    record_best_videos = config.eval_params.get("record_best_videos", True)
    best_model_path = config.model_dir_path(best=True)
    # A separate single-env evaluation environment is only built for videos.
    video_env = None
    if record_best_videos:
        video_env = make_eval_env(
            config, EnvHyperparams(**config.env_hyperparams), override_n_envs=1
        )
    callback = EvalCallback(
        policy,
        eval_env,
        tb_writer,
        best_model_path=best_model_path,
        **config.eval_params,
        video_env=video_env,
        best_video_dir=config.best_videos_dir,
    )

    # --- train, save, final evaluation -----------------------------------------
    algo.learn(config.n_timesteps, callback=callback)

    policy.save(config.model_dir_path(best=False))

    eval_stats = callback.evaluate(n_episodes=10, print_returns=True)

    plot_eval_callback(callback, tb_writer, config.run_name)

    # --- result logging ----------------------------------------------------------
    log_dict: Dict[str, Any] = {"eval": eval_stats._asdict()}
    if callback.best:
        log_dict["best_eval"] = callback.best._asdict()
    log_dict.update(hyperparams)
    log_dict.update(vars(args))
    with open(config.logs_path, "a") as f:
        yaml.dump({config.run_name: log_dict}, f)

    best_eval_stats: EpisodesStats = callback.best  # type: ignore
    flat_hparams = hparam_dict(hyperparams, vars(args))
    # "result" metrics are the mean score minus one standard deviation.
    metrics = {
        "hparam/best_mean": best_eval_stats.score.mean,
        "hparam/best_result": best_eval_stats.score.mean
        - best_eval_stats.score.std,
        "hparam/last_mean": eval_stats.score.mean,
        "hparam/last_result": eval_stats.score.mean - eval_stats.score.std,
    }
    tb_writer.add_hparams(flat_hparams, metrics, None, config.run_name)

    tb_writer.close()

    if use_wandb:
        # Archive the default model directory first, then the best one.
        for selector in ({}, {"best": True}):
            shutil.make_archive(
                os.path.join(wandb.run.dir, config.model_dir_name(**selector)),
                "zip",
                config.model_dir_path(**selector),
            )
        wandb.finish()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
shared/algorithm.py
DELETED
@@ -1,35 +0,0 @@
|
|
1 |
-
import gym
|
2 |
-
import torch
|
3 |
-
|
4 |
-
from abc import ABC, abstractmethod
|
5 |
-
from torch.utils.tensorboard.writer import SummaryWriter
|
6 |
-
from typing import List, Optional, TypeVar
|
7 |
-
|
8 |
-
from shared.callbacks.callback import Callback
|
9 |
-
from shared.policy.policy import Policy
|
10 |
-
from wrappers.vectorable_wrapper import VecEnv
|
11 |
-
|
12 |
-
AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")
|
13 |
-
|
14 |
-
|
15 |
-
class Algorithm(ABC):
    """Abstract base for training algorithms.

    Subclasses must implement both ``__init__`` and ``learn``.
    """

    @abstractmethod
    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        **kwargs,
    ) -> None:
        """Store the policy, environment, device and TensorBoard writer.

        Extra keyword arguments are accepted but not used here.
        """
        super().__init__()
        self.policy, self.env = policy, env
        self.device, self.tb_writer = device, tb_writer

    @abstractmethod
    def learn(
        self: AlgorithmSelf, total_timesteps: int, callback: Optional[Callback] = None
    ) -> AlgorithmSelf:
        """Train for ``total_timesteps``; declared to return the algorithm's own type."""
        ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
shared/callbacks/callback.py
DELETED
@@ -1,12 +0,0 @@
|
|
1 |
-
from abc import ABC, abstractmethod
|
2 |
-
|
3 |
-
|
4 |
-
class Callback(ABC):
    """Base training callback that counts elapsed timesteps."""

    def __init__(self) -> None:
        """Start with zero elapsed timesteps."""
        super().__init__()
        self.timesteps_elapsed = 0

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        """Add ``timesteps_elapsed`` to the running total and return True."""
        self.timesteps_elapsed = self.timesteps_elapsed + timesteps_elapsed
        return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
shared/callbacks/eval_callback.py
DELETED
@@ -1,199 +0,0 @@
|
|
1 |
-
import itertools
|
2 |
-
import numpy as np
|
3 |
-
import os
|
4 |
-
|
5 |
-
from time import perf_counter
|
6 |
-
from torch.utils.tensorboard.writer import SummaryWriter
|
7 |
-
from typing import List, Optional, Union
|
8 |
-
|
9 |
-
from shared.callbacks.callback import Callback
|
10 |
-
from shared.policy.policy import Policy
|
11 |
-
from shared.stats import Episode, EpisodeAccumulator, EpisodesStats
|
12 |
-
from wrappers.vec_episode_recorder import VecEpisodeRecorder
|
13 |
-
from wrappers.vectorable_wrapper import VecEnv
|
14 |
-
|
15 |
-
|
16 |
-
class EvaluateAccumulator(EpisodeAccumulator):
    """Collects completed evaluation episodes per environment index.

    Each environment contributes at most ``ceil(goal_episodes / num_envs)``
    episodes. With ``ignore_first_episode`` the first finished episode of every
    environment is discarded.
    """

    def __init__(
        self,
        num_envs: int,
        goal_episodes: int,
        print_returns: bool = True,
        ignore_first_episode: bool = False,
    ):
        """
        Args:
            num_envs: Number of parallel environments being evaluated.
            goal_episodes: Total number of episodes wanted across environments.
            print_returns: Print score and length of every recorded episode.
            ignore_first_episode: Drop the first finished episode of each env.
        """
        super().__init__(num_envs)
        self.completed_episodes_by_env_idx = [[] for _ in range(num_envs)]
        self.goal_episodes_per_env = int(np.ceil(goal_episodes / num_envs))
        self.print_returns = print_returns
        if ignore_first_episode:
            first_done = set()

            def should_record_done(idx: int) -> bool:
                # False exactly once per env index: on its first finished episode.
                has_done_first_episode = idx in first_done
                first_done.add(idx)
                return has_done_first_episode

            self.should_record_done = should_record_done
        else:
            self.should_record_done = lambda idx: True

    def on_done(self, ep_idx: int, episode: Episode) -> None:
        """Record ``episode`` for environment ``ep_idx`` unless it must be skipped."""
        # Call should_record_done first and unconditionally: it has the side
        # effect of marking the env's first episode as seen. An episode that
        # should not be recorded is dropped (previously it fell through and
        # was appended, so ignore_first_episode never ignored anything).
        if not self.should_record_done(ep_idx):
            return
        if len(self.completed_episodes_by_env_idx[ep_idx]) >= self.goal_episodes_per_env:
            return
        self.completed_episodes_by_env_idx[ep_idx].append(episode)
        if self.print_returns:
            print(
                f"Episode {len(self)} | "
                f"Score {episode.score} | "
                f"Length {episode.length}"
            )

    def __len__(self) -> int:
        """Total number of recorded episodes across all environments."""
        return sum(len(ce) for ce in self.completed_episodes_by_env_idx)

    @property
    def episodes(self) -> List[Episode]:
        """All recorded episodes, grouped by environment index order."""
        return list(itertools.chain(*self.completed_episodes_by_env_idx))

    def is_done(self) -> bool:
        """True once every environment has reached its per-env episode goal."""
        # >= rather than ==: stay terminated even if a list ever exceeds the goal.
        return all(
            len(ce) >= self.goal_episodes_per_env
            for ce in self.completed_episodes_by_env_idx
        )
|
67 |
-
|
68 |
-
|
69 |
-
def evaluate(
    env: VecEnv,
    policy: Policy,
    n_episodes: int,
    render: bool = False,
    deterministic: bool = True,
    print_returns: bool = True,
    ignore_first_episode: bool = False,
) -> EpisodesStats:
    """Roll out ``policy`` in ``env`` until the episode goal is reached.

    The policy's normalization is synced with the environment and the policy
    is switched to eval mode first; it is left in eval mode on return.

    Returns:
        ``EpisodesStats`` over the recorded episodes (also printed when
        ``print_returns`` is set).
    """
    policy.sync_normalization(env)
    policy.eval()

    accumulator = EvaluateAccumulator(
        env.num_envs, n_episodes, print_returns, ignore_first_episode
    )

    observation = env.reset()
    while not accumulator.is_done():
        action = policy.act(observation, deterministic=deterministic)
        observation, reward, done, _ = env.step(action)
        accumulator.step(reward, done)
        if render:
            env.render()

    summary = EpisodesStats(accumulator.episodes)
    if print_returns:
        print(summary)
    return summary
|
96 |
-
|
97 |
-
|
98 |
-
class EvalCallback(Callback):
    """Periodically evaluates the policy during training.

    Tracks the evaluation history and the best result so far, optionally
    saving the best model and recording a video of it.
    """

    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        tb_writer: SummaryWriter,
        best_model_path: Optional[str] = None,
        step_freq: Union[int, float] = 50_000,
        n_episodes: int = 10,
        save_best: bool = True,
        deterministic: bool = True,
        record_best_videos: bool = True,
        video_env: Optional[VecEnv] = None,
        best_video_dir: Optional[str] = None,
        max_video_length: int = 3600,
        ignore_first_episode: bool = False,
    ) -> None:
        """
        Args:
            policy: Policy to evaluate (and save when it is the best so far).
            env: Environment used for the periodic evaluations.
            tb_writer: TensorBoard writer for evaluation metrics.
            best_model_path: Where the best model is saved; required when
                ``save_best`` is set and a best model is found.
            step_freq: Timesteps between evaluations (truncated to int).
            n_episodes: Default number of episodes per evaluation.
            save_best: Save the policy whenever it matches or beats the best.
            deterministic: Passed through to the evaluation rollouts.
            record_best_videos: Record a video on each strictly better result;
                requires ``video_env`` and ``best_video_dir``.
            video_env: Environment wrapped for video recording.
            best_video_dir: Directory for videos; created if missing.
            max_video_length: Passed to the video recorder.
            ignore_first_episode: Passed through to the evaluation rollouts.
        """
        super().__init__()
        self.policy = policy
        self.env = env
        self.tb_writer = tb_writer
        self.best_model_path = best_model_path
        self.step_freq = int(step_freq)
        self.n_episodes = n_episodes
        self.save_best = save_best
        self.deterministic = deterministic
        self.stats: List[EpisodesStats] = []
        self.best = None

        self.record_best_videos = record_best_videos
        assert video_env or not record_best_videos
        self.video_env = video_env
        assert best_video_dir or not record_best_videos
        self.best_video_dir = best_video_dir
        if best_video_dir:
            os.makedirs(best_video_dir, exist_ok=True)
        self.max_video_length = max_video_length
        self.best_video_base_path = None

        self.ignore_first_episode = ignore_first_episode

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        """Advance the step counter and evaluate when an evaluation is due."""
        super().on_step(timesteps_elapsed)
        # One evaluation per completed step_freq interval; since 0 >= 0 this
        # also triggers an evaluation on the very first step.
        if self.timesteps_elapsed // self.step_freq >= len(self.stats):
            self.evaluate()
        return True

    def evaluate(
        self, n_episodes: Optional[int] = None, print_returns: Optional[bool] = None
    ) -> EpisodesStats:
        """Run one evaluation, log it, and update the best model/video.

        Args:
            n_episodes: Overrides the default episode count when truthy.
            print_returns: Print per-episode returns; None is treated as False.

        Returns:
            The statistics of this evaluation (also appended to ``self.stats``).
        """
        start_time = perf_counter()
        eval_stat = evaluate(
            self.env,
            self.policy,
            n_episodes or self.n_episodes,
            deterministic=self.deterministic,
            print_returns=print_returns or False,
            ignore_first_episode=self.ignore_first_episode,
        )
        end_time = perf_counter()
        self.tb_writer.add_scalar(
            "eval/steps_per_second",
            eval_stat.length.sum() / (end_time - start_time),
            self.timesteps_elapsed,
        )
        # The module-level evaluate() left the policy in eval mode.
        self.policy.train(True)
        print(f"Eval Timesteps: {self.timesteps_elapsed} | {eval_stat}")

        self.stats.append(eval_stat)

        if not self.best or eval_stat >= self.best:
            strictly_better = not self.best or eval_stat > self.best
            self.best = eval_stat
            if self.save_best:
                assert self.best_model_path
                self.policy.save(self.best_model_path)
                print("Saved best model")
            self.best.write_to_tensorboard(
                self.tb_writer, "best_eval", self.timesteps_elapsed
            )
            if strictly_better and self.record_best_videos:
                assert self.video_env and self.best_video_dir
                self.best_video_base_path = os.path.join(
                    self.best_video_dir, str(self.timesteps_elapsed)
                )
                video_wrapped = VecEpisodeRecorder(
                    self.video_env,
                    self.best_video_base_path,
                    max_video_length=self.max_video_length,
                )
                video_stats = evaluate(
                    video_wrapped,
                    self.policy,
                    1,
                    deterministic=self.deterministic,
                    print_returns=False,
                )
                # The video rollout switched the policy to eval mode again;
                # restore train mode so training does not resume in eval mode.
                self.policy.train(True)
                print(f"Saved best video: {video_stats}")

        eval_stat.write_to_tensorboard(self.tb_writer, "eval", self.timesteps_elapsed)

        return eval_stat
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
shared/gae.py
DELETED
@@ -1,67 +0,0 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import torch
|
3 |
-
|
4 |
-
from typing import NamedTuple, Sequence
|
5 |
-
|
6 |
-
from shared.policy.on_policy import OnPolicy
|
7 |
-
from shared.trajectory import Trajectory
|
8 |
-
|
9 |
-
|
10 |
-
class RtgAdvantage(NamedTuple):
    """Pair of tensors returned by ``compute_rtg_and_advantage``."""

    # Value estimates plus advantages, concatenated over all trajectories.
    rewards_to_go: torch.Tensor
    # Advantages, concatenated over all trajectories.
    advantage: torch.Tensor
|
13 |
-
|
14 |
-
|
15 |
-
def discounted_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
    """Return the discounted cumulative sum of ``x`` from the right.

    Element ``i`` of the result is ``x[i] + gamma * result[i + 1]``; the last
    element is unchanged. The input is not modified.

    Args:
        x: 1-D sequence of values.
        gamma: Discount factor applied per step.

    Returns:
        A new array of the same length. Floating-point inputs keep their
        dtype; integer and boolean arrays are promoted to float64.
    """
    if isinstance(x, np.ndarray) and x.dtype.kind in "biu":
        # Integer/bool arrays cannot hold the fractional discounted sums, and
        # the in-place add below would fail on them; accumulate in float64.
        dc = x.astype(np.float64)
    else:
        dc = x.copy()
    for i in reversed(range(len(dc) - 1)):
        dc[i] += gamma * dc[i + 1]
    return dc
|
20 |
-
|
21 |
-
|
22 |
-
def compute_advantage(
    trajectories: Sequence[Trajectory],
    policy: OnPolicy,
    gamma: float,
    gae_lambda: float,
    device: torch.device,
) -> torch.Tensor:
    """Compute advantages for every trajectory as one float32 tensor.

    For each trajectory the one-step deltas ``r + gamma * v_next - v`` are
    accumulated with ``discounted_cumsum`` using ``gamma * gae_lambda``. The
    value after the last step is taken from ``policy.value(next_obs)`` when
    the trajectory is not terminated and has a next observation, else 0.
    """
    per_trajectory = []
    for trajectory in trajectories:
        if not trajectory.terminated and trajectory.next_obs is not None:
            bootstrap = policy.value(trajectory.next_obs)
        else:
            bootstrap = 0
        rewards = np.append(np.array(trajectory.rew), bootstrap)
        values = np.append(np.array(trajectory.v), bootstrap)
        td_errors = rewards[:-1] + gamma * values[1:] - values[:-1]
        per_trajectory.append(discounted_cumsum(td_errors, gamma * gae_lambda))
    stacked = np.concatenate(per_trajectory)
    return torch.as_tensor(stacked, dtype=torch.float32, device=device)
|
41 |
-
|
42 |
-
|
43 |
-
def compute_rtg_and_advantage(
    trajectories: Sequence[Trajectory],
    policy: OnPolicy,
    gamma: float,
    gae_lambda: float,
    device: torch.device,
) -> RtgAdvantage:
    """Compute rewards-to-go and advantages for every trajectory.

    Advantages are computed as in ``compute_advantage``; rewards-to-go are the
    per-step value estimates plus those advantages. Both are concatenated over
    the trajectories and returned as float32 tensors on ``device``.
    """
    rtg_chunks = []
    adv_chunks = []
    for trajectory in trajectories:
        if not trajectory.terminated and trajectory.next_obs is not None:
            bootstrap = policy.value(trajectory.next_obs)
        else:
            bootstrap = 0
        rewards = np.append(np.array(trajectory.rew), bootstrap)
        values = np.append(np.array(trajectory.v), bootstrap)
        td_errors = rewards[:-1] + gamma * values[1:] - values[:-1]
        advantage = discounted_cumsum(td_errors, gamma * gae_lambda)
        adv_chunks.append(advantage)
        rtg_chunks.append(values[:-1] + advantage)

    rewards_to_go = torch.as_tensor(
        np.concatenate(rtg_chunks), dtype=torch.float32, device=device
    )
    advantages = torch.as_tensor(
        np.concatenate(adv_chunks), dtype=torch.float32, device=device
    )
    return RtgAdvantage(rewards_to_go, advantages)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
shared/module/feature_extractor.py
DELETED
@@ -1,215 +0,0 @@
|
|
1 |
-
import gym
|
2 |
-
import torch
|
3 |
-
import torch.nn as nn
|
4 |
-
import torch.nn.functional as F
|
5 |
-
|
6 |
-
from abc import ABC, abstractmethod
|
7 |
-
from gym.spaces import Box, Discrete
|
8 |
-
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
|
9 |
-
from typing import Dict, Optional, Sequence, Type
|
10 |
-
|
11 |
-
from shared.module.module import layer_init
|
12 |
-
|
13 |
-
|
14 |
-
class CnnFeatureExtractor(nn.Module, ABC):
    """Abstract base for CNN feature extractors with a common constructor."""

    @abstractmethod
    def __init__(
        self,
        in_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: Optional[bool] = None,
        **kwargs,
    ) -> None:
        """Initialise the ``nn.Module`` base; the arguments are not used here."""
        super().__init__()
|
24 |
-
|
25 |
-
|
26 |
-
class NatureCnn(CnnFeatureExtractor):
    """
    CNN from DQN Nature paper: Mnih, Volodymyr, et al.
    "Human-level control through deep reinforcement learning."
    Nature 518.7540 (2015): 529-533.
    """

    def __init__(
        self,
        in_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: Optional[bool] = None,
        **kwargs,
    ) -> None:
        """Build three conv layers, each followed by ``activation``, then flatten.

        ``init_layers_orthogonal`` defaults to True when left as None.
        """
        if init_layers_orthogonal is None:
            init_layers_orthogonal = True
        super().__init__(in_channels, activation, init_layers_orthogonal)

        # (in_channels, out_channels, kernel_size, stride) for each conv layer.
        conv_specs = (
            (in_channels, 32, 8, 4),
            (32, 64, 4, 2),
            (64, 64, 3, 1),
        )
        layers = []
        for c_in, c_out, kernel, stride in conv_specs:
            conv = nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride)
            layers.append(layer_init(conv, init_layers_orthogonal))
            layers.append(activation())
        layers.append(nn.Flatten())
        self.cnn = nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Apply the conv stack to ``obs`` and return the flattened features."""
        return self.cnn(obs)
|
64 |
-
|
65 |
-
|
66 |
-
class ResidualBlock(nn.Module):
    """Two (activation, 3x3 conv) stages added back onto the input."""

    def __init__(
        self,
        channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: bool = False,
    ) -> None:
        """Build the residual branch; channel count is unchanged by each conv."""
        super().__init__()
        stages = []
        for _ in range(2):
            stages.append(activation())
            stages.append(
                layer_init(
                    nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
                )
            )
        self.residual = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the residual branch applied to ``x``."""
        branch = self.residual(x)
        return x + branch
|
87 |
-
|
88 |
-
|
89 |
-
class ConvSequence(nn.Module):
    """A 3x3 conv, a stride-2 max pool, then two residual blocks."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: bool = False,
    ) -> None:
        """Build the sequence mapping ``in_channels`` to ``out_channels``."""
        super().__init__()
        conv = layer_init(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            init_layers_orthogonal,
        )
        pool = nn.MaxPool2d(3, stride=2, padding=1)
        first_block = ResidualBlock(out_channels, activation, init_layers_orthogonal)
        second_block = ResidualBlock(out_channels, activation, init_layers_orthogonal)
        self.seq = nn.Sequential(conv, pool, first_block, second_block)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv, pooling and both residual blocks to ``x``."""
        return self.seq(x)
|
110 |
-
|
111 |
-
|
112 |
-
class ImpalaCnn(CnnFeatureExtractor):
    """
    IMPALA-style CNN architecture
    """

    def __init__(
        self,
        in_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: Optional[bool] = None,
        impala_channels: Sequence[int] = (16, 32, 32),
        **kwargs,
    ) -> None:
        """Chain one ``ConvSequence`` per entry of ``impala_channels``.

        The chain is followed by ``activation`` and a flatten.
        ``init_layers_orthogonal`` defaults to False when left as None.
        """
        if init_layers_orthogonal is None:
            init_layers_orthogonal = False
        super().__init__(in_channels, activation, init_layers_orthogonal)

        # Each sequence consumes the previous sequence's output channels.
        input_widths = [in_channels, *impala_channels[:-1]]
        stages = [
            ConvSequence(c_in, c_out, activation, init_layers_orthogonal)
            for c_in, c_out in zip(input_widths, impala_channels)
        ]
        stages += [activation(), nn.Flatten()]
        self.seq = nn.Sequential(*stages)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Apply the conv sequences to ``obs`` and return the flattened features."""
        return self.seq(obs)
|
146 |
-
|
147 |
-
|
148 |
-
# Registry used to look up the CNN extractor class for a given ``cnn_style`` name.
CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnFeatureExtractor]] = {"nature": NatureCnn, "impala": ImpalaCnn}
|
152 |
-
|
153 |
-
|
154 |
-
class FeatureExtractor(nn.Module):
    """Convert raw observations into a 2-D feature tensor.

    The constructor picks a preprocessing function and an extractor module
    from the observation space:

    * 3-D ``Box``: observations are cast to float and divided by 255
      (presumably 0-255 image data in (channels, height, width) layout; confirm
      against the environment wrappers), passed through the CNN selected by
      ``cnn_style``, then a linear layer of width ``cnn_feature_dim`` and an
      activation. ``out_dim`` is ``cnn_feature_dim``.
    * 1-D ``Box``: observations are cast to float and flattened.
      ``out_dim`` is ``get_flattened_obs_dim(obs_space)``.
    * ``Discrete``: observations are one-hot encoded with ``obs_space.n``
      classes. ``out_dim`` is ``obs_space.n``.

    In every branch a single unbatched observation gets a leading batch
    dimension of size 1 before it reaches the extractor.

    Args:
        obs_space: Observation space that determines the branch taken.
        activation: Module class instantiated for the non-linearities.
        init_layers_orthogonal: Passed to ``layer_init`` for the linear layer
            that follows the CNN.
        cnn_feature_dim: Output width of that linear layer.
        cnn_style: Key into ``CNN_EXTRACTORS_BY_STYLE``.
        cnn_layers_init_orthogonal: Forwarded to the CNN constructor as
            ``init_layers_orthogonal``.
        impala_channels: Forwarded to the CNN constructor.

    Raises:
        KeyError: If ``cnn_style`` is not in ``CNN_EXTRACTORS_BY_STYLE``.
        ValueError: If a ``Box`` space is neither 1-D nor 3-D.
        NotImplementedError: If the space is neither ``Box`` nor ``Discrete``.
    """

    def __init__(
        self,
        obs_space: gym.Space,
        activation: Type[nn.Module],
        init_layers_orthogonal: bool = False,
        cnn_feature_dim: int = 512,
        cnn_style: str = "nature",
        cnn_layers_init_orthogonal: Optional[bool] = None,
        impala_channels: Sequence[int] = (16, 32, 32),
    ) -> None:
        super().__init__()
        if isinstance(obs_space, Box):
            # Conv2D: (channels, height, width)
            if len(obs_space.shape) == 3:
                cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
                    obs_space.shape[0],
                    activation,
                    init_layers_orthogonal=cnn_layers_init_orthogonal,
                    impala_channels=impala_channels,
                )

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    # A single image has no batch axis yet; add one.
                    if len(obs.shape) == 3:
                        obs = obs.unsqueeze(0)
                    return obs.float() / 255.0

                # Probe the CNN with a sampled observation to learn the
                # flattened width needed for the linear layer below.
                with torch.no_grad():
                    cnn_out = cnn(preprocess(torch.as_tensor(obs_space.sample())))
                self.preprocess = preprocess
                self.feature_extractor = nn.Sequential(
                    cnn,
                    layer_init(
                        nn.Linear(cnn_out.shape[1], cnn_feature_dim),
                        init_layers_orthogonal,
                    ),
                    activation(),
                )
                self.out_dim = cnn_feature_dim
            elif len(obs_space.shape) == 1:

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    # A single vector has no batch axis yet; add one.
                    if len(obs.shape) == 1:
                        obs = obs.unsqueeze(0)
                    return obs.float()

                self.preprocess = preprocess
                self.feature_extractor = nn.Flatten()
                self.out_dim = get_flattened_obs_dim(obs_space)
            else:
                raise ValueError(f"Unsupported observation space: {obs_space}")
        elif isinstance(obs_space, Discrete):

            def preprocess(obs: torch.Tensor) -> torch.Tensor:
                # A single scalar observation would one-hot to a 1-D tensor,
                # which nn.Flatten (start_dim=1) cannot handle; batch it first,
                # as the Box branches do.
                if len(obs.shape) == 0:
                    obs = obs.unsqueeze(0)
                return F.one_hot(obs, obs_space.n).float()

            self.preprocess = preprocess
            self.feature_extractor = nn.Flatten()
            self.out_dim = obs_space.n
        else:
            raise NotImplementedError(
                f"Unsupported observation space: {obs_space}"
            )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Preprocess ``obs`` (when a preprocessor is set) and extract features."""
        if self.preprocess:
            obs = self.preprocess(obs)
        return self.feature_extractor(obs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|