Commit 341188c
Parent(s): f995b88

PPO playing Acrobot-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +147 -0
- LICENSE +21 -0
- README.md +127 -0
- benchmarks/colab_atari1.sh +5 -0
- benchmarks/colab_atari2.sh +5 -0
- benchmarks/colab_basic.sh +5 -0
- benchmarks/colab_benchmark.ipynb +210 -0
- benchmarks/colab_carracing.sh +5 -0
- benchmarks/colab_pybullet.sh +5 -0
- benchmarks/train_loop.sh +17 -0
- colab_enjoy.ipynb +213 -0
- colab_requirements.txt +7 -0
- colab_train.ipynb +215 -0
- dqn/dqn.py +182 -0
- dqn/policy.py +37 -0
- dqn/q_net.py +29 -0
- enjoy.py +105 -0
- environment.yml +17 -0
- hyperparams/dqn.yml +117 -0
- hyperparams/ppo.yml +202 -0
- hyperparams/vpg.yml +157 -0
- lambda_labs/benchmark.sh +33 -0
- lambda_labs/lambda_requirements.txt +9 -0
- lambda_labs/setup.sh +10 -0
- poetry.lock +0 -0
- ppo/policy.py +36 -0
- ppo/ppo.py +367 -0
- pyproject.toml +27 -0
- replay.meta.json +1 -0
- replay.mp4 +0 -0
- runner/config.py +130 -0
- runner/env.py +134 -0
- runner/running_utils.py +188 -0
- runner/train.py +126 -0
- saved_models/ppo-Acrobot-v1-S4-best/model.pth +3 -0
- saved_models/ppo-Acrobot-v1-S4-best/vecnormalize.pkl +3 -0
- shared/algorithm.py +35 -0
- shared/callbacks/callback.py +12 -0
- shared/callbacks/eval_callback.py +174 -0
- shared/module.py +121 -0
- shared/policy/actor.py +304 -0
- shared/policy/critic.py +27 -0
- shared/policy/on_policy.py +177 -0
- shared/policy/policy.py +60 -0
- shared/schedule.py +19 -0
- shared/stats.py +173 -0
- shared/trajectory.py +30 -0
- shared/utils.py +8 -0
- train.py +81 -0
- vpg/policy.py +119 -0
    	
.gitignore ADDED
@@ -0,0 +1,147 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# Logging into tensorboard and wandb
+runs/*
+wandb
+
+# macOS
+.DS_STORE
+
+# Local scratch work
+scratch/*
+
+# vscode
+.vscode/
+
+# Don't bother tracking saved_models or videos
+saved_models/*
+downloaded_models/*
+videos/*
    	
LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Scott Goodfriend
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
    	
README.md ADDED
@@ -0,0 +1,127 @@
+---
+library_name: rl-algo-impls
+tags:
+- Acrobot-v1
+- ppo
+- deep-reinforcement-learning
+- reinforcement-learning
+model-index:
+- name: ppo
+  results:
+  - metrics:
+    - type: mean_reward
+      value: -72.5 +/- 7.68
+      name: mean_reward
+    task:
+      type: reinforcement-learning
+      name: reinforcement-learning
+    dataset:
+      name: Acrobot-v1
+      type: Acrobot-v1
+---
+# **PPO** Agent playing **Acrobot-v1**
+
+This is a trained model of a **PPO** agent playing **Acrobot-v1** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
+
+All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/6p2sjqtn.
+
+## Training Results
+
+This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
+
+| algo | env        | seed | reward_mean | reward_std | eval_episodes | best | wandb_url                                                                    |
+|:-----|:-----------|-----:|------------:|-----------:|--------------:|:-----|:-----------------------------------------------------------------------------|
+| ppo  | Acrobot-v1 |    4 |     -72.5   |    7.68115 |            16 | *    | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/bzab0jtv) |
+| ppo  | Acrobot-v1 |    5 |     -71.875 |    9.55167 |            16 |      | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/zqord0fg) |
+| ppo  | Acrobot-v1 |    6 |     -74.375 |   14.5081  |            16 |      | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/y1w2hqhu) |
+
+
+### Prerequisites: Weights & Biases (WandB)
+Training and benchmarking assumes you have a Weights & Biases project to upload runs to.
+By default training goes to a rl-algo-impls project while benchmarks go to
+rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
+models and the model weights are uploaded to WandB.
+
+Before doing any of the runs below, you'll need to create a wandb account and run `wandb
+login`.
+
+
+
+## Usage
+/sgoodfriend/rl-algo-impls: https://github.com/sgoodfriend/rl-algo-impls
+
+Note: While the model state dictionary and hyperaparameters are saved, the
+implementation could be sufficiently different to not be able to reproduce similar
+results. You might need to checkout the commit the agent was trained on:
+[5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c).
+```
+# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
+python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/bzab0jtv
+```
+
+Setup hasn't been completely worked out yet, so you might be best served by using Google
+Colab starting from the
+[colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
+notebook.
+
+
+
+## Training
+If you want the highest chance to reproduce these results, you'll want to checkout the
+commit the agent was trained on: [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c).
+
+```
+python train.py --algo ppo --env Acrobot-v1 --seed 4
+```
+
+Setup hasn't been completely worked out yet, so you might be best served by using Google
+Colab starting from the
+[colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
+notebook.
+
+
+
+## Benchmarking (with Lambda Labs instance)
+This and other models from https://api.wandb.ai/links/sgoodfriend/6p2sjqtn were generated by running a script on a Lambda
+Labs instance. In a Lambda Labs instance terminal:
+```
+git clone git@github.com:sgoodfriend/rl-algo-impls.git
+cd rl-algo-impls
+bash ./lambda_labs/setup.sh
+wandb login
+bash ./lambda_labs/benchmark.sh
+```
+
+### Alternative: Google Colab Pro+
+As an alternative,
+[colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
+can be used. However, this requires a Google Colab Pro+ subscription and running across
+4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
+
+
+
+## Hyperparameters
+This isn't exactly the format of hyperparams in hyperparams/ppo.yml, but instead the Wandb Run Config. However, it's very
+close and has some additional data:
+```
+algo: ppo
+algo_hyperparams:
+  ent_coef: 0
+  gae_lambda: 0.94
+  gamma: 0.99
+  n_epochs: 4
+  n_steps: 256
+env: Acrobot-v1
+env_hyperparams:
+  n_envs: 16
+  normalize: true
+n_timesteps: 1000000
+seed: 4
+use_deterministic_algorithms: true
+wandb_entity: null
+wandb_project_name: rl-algo-impls-benchmarks
+wandb_tags:
+- benchmark_5598ebc
+- host_192-9-145-26
+
+```
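For reference, the "(mean - std)" selection rule the README describes under Training Results is simple enough to restate in code. Below is a minimal sketch, not code from the repo, that scores each re-evaluated run by its mean reward minus the standard deviation of its rewards, using only the summary numbers from the table above; the run labels are just the wandb run ids paired with their seeds:

```python
# Editorial sketch of the "(mean - std)" best-model selection described above.
# Keys are wandb run ids (with seeds); values are (reward_mean, reward_std)
# taken from the Training Results table.
runs = {
    "bzab0jtv (seed 4)": (-72.5, 7.68115),
    "zqord0fg (seed 5)": (-71.875, 9.55167),
    "y1w2hqhu (seed 6)": (-74.375, 14.5081),
}

def score(reward_mean: float, reward_std: float) -> float:
    # "mean - std": prefer runs that are good on average *and* consistent.
    return reward_mean - reward_std

best = max(runs, key=lambda run_id: score(*runs[run_id]))
print(best)  # -> "bzab0jtv (seed 4)", the row marked best (*) in the table
```

Note that seed 5 has the higher raw mean (-71.875), but its larger spread drops its mean-minus-std score below seed 4's, which is why the seed 4 run is the one whose model files appear in this commit.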
    	
benchmarks/colab_atari1.sh ADDED
@@ -0,0 +1,5 @@
+source benchmarks/train_loop.sh
+ALGOS="ppo"
+ENVS="PongNoFrameskip-v4 BreakoutNoFrameskip-v4"
+BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
+train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
    	
benchmarks/colab_atari2.sh ADDED
@@ -0,0 +1,5 @@
+source benchmarks/train_loop.sh
+ALGOS="ppo"
+ENVS="SpaceInvadersNoFrameskip-v4 QbertNoFrameskip-v4"
+BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
+train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
    	
benchmarks/colab_basic.sh ADDED
@@ -0,0 +1,5 @@
+source benchmarks/train_loop.sh
+ALGOS="ppo"
+ENVS="CartPole-v1 MountainCar-v0 MountainCarContinuous-v0 Acrobot-v1 LunarLander-v2"
+BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
+train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
    	
benchmarks/colab_benchmark.ipynb ADDED
@@ -0,0 +1,210 @@
+{
+  "nbformat": 4,
+  "nbformat_minor": 0,
+  "metadata": {
+    "colab": {
+      "provenance": [],
+      "machine_shape": "hm",
+      "authorship_tag": "ABX9TyMJFprw7XNl/BqbKAHd/483",
+      "include_colab_link": true
+    },
+    "kernelspec": {
+      "name": "python3",
+      "display_name": "Python 3"
+    },
+    "language_info": {
+      "name": "python"
+    },
+    "gpuClass": "standard",
+    "accelerator": "GPU"
+  },
+  "cells": [
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "view-in-github",
+        "colab_type": "text"
+      },
+      "source": [
+        "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/benchmarks/colab_benchmark.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
+        "## Parameters\n",
+        "\n",
+        "\n",
+        "1.   Wandb\n",
+        "\n"
+      ],
+      "metadata": {
+        "id": "S-tXDWP8WTLc"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "from getpass import getpass\n",
+        "import os\n",
+        "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
+      ],
+      "metadata": {
+        "id": "1ZtdYgxWNGwZ"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Setup\n",
+        "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
+      ],
+      "metadata": {
+        "id": "bsG35Io0hmKG"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "%%capture\n",
+        "!mkdir -p ~/.ssh\n",
+        "\n",
+        "with open(\"/root/.ssh/id_ed25519\", mode=\"w\") as f:\n",
+        "    f.write(\"\"\"-----BEGIN OPENSSH PRIVATE KEY-----\n",
+        "b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\n",
+        "QyNTUxOQAAACAkIepH6T90umhyp8+bkmSplqth1/+yxnu/Dax61KlSVAAAAKA4W3D3OFtw\n",
+        "9wAAAAtzc2gtZWQyNTUxOQAAACAkIepH6T90umhyp8+bkmSplqth1/+yxnu/Dax61KlSVA\n",
+        "AAAEA4SPGDm0/gofiOYXPTAi1Oxmw4mTppG2GdNgdMwMiDaSQh6kfpP3S6aHKnz5uSZKmW\n",
+        "q2HX/7LGe78NrHrUqVJUAAAAGmdvb2RmcmllbmQuc2NvdHRAZ21haWwuY29tAQID\n",
+        "-----END OPENSSH PRIVATE KEY-----\n",
+        "\"\"\"\n",
+        ")\n",
+        "\n",
+        "!ssh-keyscan -t ed25519 github.com >> ~/.ssh/known_hosts\n",
+        "!chmod go-rwx /root/.ssh/id_ed25519\n",
+        "!git clone git@github.com:sgoodfriend/rl-algo-impls.git"
+      ],
+      "metadata": {
+        "id": "k5ynTV25hdAf"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "Installing the correct packages:\n",
+        "\n",
+        "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
+      ],
+      "metadata": {
+        "id": "jKxGok-ElYQ7"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "%%capture\n",
+        "!apt install python-opengl\n",
+        "!apt install ffmpeg\n",
+        "!apt install xvfb\n",
+        "!apt install swig"
+      ],
+      "metadata": {
+        "id": "nn6EETTc2Ewf"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "%%capture\n",
+        "%cd /content/rl-algo-impls\n",
+        "!pip install -r colab_requirements.txt"
+      ],
+      "metadata": {
+        "id": "AfZh9rH3yQii"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Run Once Per Runtime"
+      ],
+      "metadata": {
+        "id": "4o5HOLjc4wq7"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "import wandb\n",
+        "wandb.login()"
+      ],
+      "metadata": {
+        "id": "PCXa5tdS2qFX"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "## Restart Session beteween runs"
+      ],
+      "metadata": {
+        "id": "AZBZfSUV43JQ"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "%%capture\n",
+        "from pyvirtualdisplay import Display\n",
+        "\n",
+        "virtual_display = Display(visible=0, size=(1400, 900))\n",
+        "virtual_display.start()"
+      ],
+      "metadata": {
+        "id": "VzemeQJP2NO9"
+      },
+      "execution_count": null,
+      "outputs": []
+    },
+    {
+      "cell_type": "markdown",
+      "source": [
+        "The below 5 bash scripts train agents on environments with 3 seeds each:\n",
+        "- colab_basic.sh and colab_pybullet.sh test on a set of basic gym environments and 4 PyBullet environments. Running both together will likely take about 18 hours. This is likely to run into runtime limits for free Colab and Colab Pro, but is fine for Colab Pro+.\n",
+        "- colab_carracing.sh only trains 3 seeds on CarRacing-v0, which takes almost 22 hours on Colab Pro+ on high-RAM, standard GPU.\n",
+        "- colab_atari1.sh and colab_atari2.sh likely need to be run separately because each takes about 19 hours on high-RAM, standard GPU."
+      ],
+      "metadata": {
+        "id": "nSHfna0hLlO1"
+      }
+    },
+    {
+      "cell_type": "code",
+      "source": [
+        "%cd /content/rl-algo-impls\n",
+        "os.environ[\"BENCHMARK_MAX_PROCS\"] = str(1) # Can't reliably raise this to 2+, but would make it faster.\n",
+        "!./benchmarks/colab_basic.sh\n",
+        "!./benchmarks/colab_pybullet.sh\n",
+        "# !./benchmarks/colab_carracing.sh\n",
+        "# !./benchmarks/colab_atari1.sh\n",
+        "# !./benchmarks/colab_atari2.sh"
+      ],
+      "metadata": {
+        "id": "07aHYFH1zfXa"
+      },
+      "execution_count": null,
+      "outputs": []
+    }
+  ]
+}
    	
benchmarks/colab_carracing.sh ADDED
@@ -0,0 +1,5 @@
+source benchmarks/train_loop.sh
+ALGOS="ppo"
+ENVS="CarRacing-v0"
+BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
+train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
    	
benchmarks/colab_pybullet.sh ADDED
@@ -0,0 +1,5 @@
+source benchmarks/train_loop.sh
+ALGOS="ppo"
+ENVS="HalfCheetahBulletEnv-v0 AntBulletEnv-v0 Walker2DBulletEnv-v0 HopperBulletEnv-v0"
+BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
+train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
    	
benchmarks/train_loop.sh ADDED
@@ -0,0 +1,17 @@
+train_loop () {
+    local WANDB_TAGS="benchmark_$(git rev-parse --short HEAD) host_$(hostname)"
+    local algo
+    local env
+    local seed
+    local WANDB_PROJECT_NAME="${WANDB_PROJECT_NAME:-rl-algo-impls-benchmarks}"
+    local args=()
+    (( VIRTUAL_DISPLAY == 1)) && args+=("--virtual-display")
+    local SEEDS="${SEEDS:-1 2 3}"
+    for algo in $(echo $1); do
+        for env in $(echo $2); do
+            for seed in $SEEDS; do
+                echo python train.py --algo $algo --env $env --seed $seed --pool-size 1 --wandb-tags $WANDB_TAGS --wandb-project-name $WANDB_PROJECT_NAME ${args[@]}
+            done
+        done
+    done
+}
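Note that train_loop only prints training commands; the `| xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD` pipeline in the colab_*.sh scripts above is what actually runs them, up to BENCHMARK_MAX_PROCS at a time. As an illustration (not captured output; the real tags come from `git rev-parse --short HEAD` and `hostname`), calling `train_loop "ppo" "Acrobot-v1"` with the default seeds would echo one line per algo/env/seed combination, along the lines of:

```
python train.py --algo ppo --env Acrobot-v1 --seed 1 --pool-size 1 --wandb-tags benchmark_5598ebc host_<hostname> --wandb-project-name rl-algo-impls-benchmarks
python train.py --algo ppo --env Acrobot-v1 --seed 2 --pool-size 1 --wandb-tags benchmark_5598ebc host_<hostname> --wandb-project-name rl-algo-impls-benchmarks
python train.py --algo ppo --env Acrobot-v1 --seed 3 --pool-size 1 --wandb-tags benchmark_5598ebc host_<hostname> --wandb-project-name rl-algo-impls-benchmarks
```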
    	
colab_enjoy.ipynb ADDED
@@ -0,0 +1,213 @@
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "nbformat": 4,
         | 
| 3 | 
            +
              "nbformat_minor": 0,
         | 
| 4 | 
            +
              "metadata": {
         | 
| 5 | 
            +
                "colab": {
         | 
| 6 | 
            +
                  "provenance": [],
         | 
| 7 | 
            +
                  "machine_shape": "hm",
         | 
| 8 | 
            +
                  "authorship_tag": "ABX9TyM1iRYRLhijbxWxPLk9Ba7f",
         | 
| 9 | 
            +
                  "include_colab_link": true
         | 
| 10 | 
            +
                },
         | 
| 11 | 
            +
                "kernelspec": {
         | 
| 12 | 
            +
                  "name": "python3",
         | 
| 13 | 
            +
                  "display_name": "Python 3"
         | 
| 14 | 
            +
                },
         | 
| 15 | 
            +
                "language_info": {
         | 
| 16 | 
            +
                  "name": "python"
         | 
| 17 | 
            +
                },
         | 
| 18 | 
            +
                "gpuClass": "standard",
         | 
| 19 | 
            +
                "accelerator": "GPU"
         | 
| 20 | 
            +
              },
         | 
| 21 | 
            +
              "cells": [
         | 
| 22 | 
            +
                {
         | 
| 23 | 
            +
                  "cell_type": "markdown",
         | 
| 24 | 
            +
                  "metadata": {
         | 
| 25 | 
            +
                    "id": "view-in-github",
         | 
| 26 | 
            +
                    "colab_type": "text"
         | 
| 27 | 
            +
                  },
         | 
| 28 | 
            +
                  "source": [
         | 
| 29 | 
            +
                    "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
         | 
| 30 | 
            +
                  ]
         | 
| 31 | 
            +
                },
         | 
| 32 | 
            +
                {
         | 
| 33 | 
            +
                  "cell_type": "markdown",
         | 
| 34 | 
            +
                  "source": [
         | 
| 35 | 
            +
                    "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
         | 
| 36 | 
            +
                    "## Parameters\n",
         | 
| 37 | 
            +
                    "\n",
         | 
| 38 | 
            +
                    "\n",
         | 
| 39 | 
            +
                    "1.   Wandb\n",
         | 
| 40 | 
            +
                    "\n"
         | 
| 41 | 
            +
                  ],
         | 
| 42 | 
            +
                  "metadata": {
         | 
| 43 | 
            +
                    "id": "S-tXDWP8WTLc"
         | 
| 44 | 
            +
                  }
         | 
| 45 | 
            +
                },
         | 
| 46 | 
            +
                {
         | 
| 47 | 
            +
                  "cell_type": "code",
         | 
| 48 | 
            +
                  "source": [
         | 
| 49 | 
            +
                    "from getpass import getpass\n",
         | 
| 50 | 
            +
                    "import os\n",
         | 
| 51 | 
            +
                    "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
         | 
| 52 | 
            +
                  ],
         | 
| 53 | 
            +
                  "metadata": {
         | 
| 54 | 
            +
                    "id": "1ZtdYgxWNGwZ"
         | 
| 55 | 
            +
                  },
         | 
| 56 | 
            +
                  "execution_count": null,
         | 
| 57 | 
            +
                  "outputs": []
         | 
| 58 | 
            +
                },
         | 
| 59 | 
            +
                {
         | 
| 60 | 
            +
                  "cell_type": "markdown",
         | 
| 61 | 
            +
                  "source": [
         | 
| 62 | 
            +
                    "2. enjoy.py parameters"
         | 
| 63 | 
            +
                  ],
         | 
| 64 | 
            +
                  "metadata": {
         | 
| 65 | 
            +
                    "id": "ao0nAh3MOdN7"
         | 
| 66 | 
            +
                  }
         | 
| 67 | 
            +
                },
         | 
| 68 | 
            +
                {
         | 
| 69 | 
            +
                  "cell_type": "code",
         | 
| 70 | 
            +
                  "source": [
         | 
| 71 | 
            +
                    "WANDB_RUN_PATH=\"sgoodfriend/rl-algo-impls-benchmarks/rd0lisee\""
         | 
| 72 | 
            +
                  ],
         | 
| 73 | 
            +
                  "metadata": {
         | 
| 74 | 
            +
                    "id": "jKL_NFhVOjSc"
         | 
| 75 | 
            +
                  },
         | 
| 76 | 
            +
                  "execution_count": 2,
         | 
| 77 | 
            +
                  "outputs": []
         | 
| 78 | 
            +
                },
         | 
| 79 | 
            +
                {
         | 
| 80 | 
            +
                  "cell_type": "markdown",
         | 
| 81 | 
            +
                  "source": [
         | 
| 82 | 
            +
                    "## Setup\n",
         | 
| 83 | 
            +
                    "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
         | 
| 84 | 
            +
                  ],
         | 
| 85 | 
            +
                  "metadata": {
         | 
| 86 | 
            +
                    "id": "bsG35Io0hmKG"
         | 
| 87 | 
            +
                  }
         | 
| 88 | 
            +
                },
         | 
| 89 | 
            +
                {
         | 
| 90 | 
            +
                  "cell_type": "code",
         | 
| 91 | 
            +
                  "source": [
         | 
| 92 | 
            +
                    "%%capture\n",
         | 
| 93 | 
            +
                    "!mkdir -p ~/.ssh\n",
         | 
| 94 | 
            +
                    "\n",
         | 
| 95 | 
            +
                    "with open(\"/root/.ssh/id_ed25519\", mode=\"w\") as f:\n",
         | 
| 96 | 
            +
                    "    f.write(\"\"\"-----BEGIN OPENSSH PRIVATE KEY-----\n",
         | 
| 97 | 
            +
                    "b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\n",
         | 
| 98 | 
            +
                    "QyNTUxOQAAACAkIepH6T90umhyp8+bkmSplqth1/+yxnu/Dax61KlSVAAAAKA4W3D3OFtw\n",
         | 
| 99 | 
            +
                    "9wAAAAtzc2gtZWQyNTUxOQAAACAkIepH6T90umhyp8+bkmSplqth1/+yxnu/Dax61KlSVA\n",
         | 
| 100 | 
            +
                    "AAAEA4SPGDm0/gofiOYXPTAi1Oxmw4mTppG2GdNgdMwMiDaSQh6kfpP3S6aHKnz5uSZKmW\n",
         | 
| 101 | 
            +
                    "q2HX/7LGe78NrHrUqVJUAAAAGmdvb2RmcmllbmQuc2NvdHRAZ21haWwuY29tAQID\n",
         | 
| 102 | 
            +
                    "-----END OPENSSH PRIVATE KEY-----\n",
         | 
| 103 | 
            +
                    "\"\"\"\n",
         | 
| 104 | 
            +
                    ")\n",
         | 
| 105 | 
            +
                    "\n",
         | 
| 106 | 
            +
                    "!ssh-keyscan -t ed25519 github.com >> ~/.ssh/known_hosts\n",
         | 
| 107 | 
            +
                    "!chmod go-rwx /root/.ssh/id_ed25519\n",
         | 
| 108 | 
            +
                    "!git clone git@github.com:sgoodfriend/rl-algo-impls.git"
         | 
| 109 | 
            +
                  ],
         | 
| 110 | 
            +
                  "metadata": {
         | 
| 111 | 
            +
                    "id": "k5ynTV25hdAf"
         | 
| 112 | 
            +
                  },
         | 
| 113 | 
            +
                  "execution_count": 3,
         | 
| 114 | 
            +
                  "outputs": []
         | 
| 115 | 
            +
                },
         | 
| 116 | 
            +
                {
         | 
| 117 | 
            +
                  "cell_type": "markdown",
         | 
| 118 | 
            +
                  "source": [
         | 
| 119 | 
            +
                    "Installing the correct packages:\n",
         | 
| 120 | 
            +
                    "\n",
         | 
| 121 | 
            +
                    "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
         | 
| 122 | 
            +
                  ],
         | 
| 123 | 
            +
                  "metadata": {
         | 
| 124 | 
            +
                    "id": "jKxGok-ElYQ7"
         | 
| 125 | 
            +
                  }
         | 
| 126 | 
            +
                },
         | 
| 127 | 
            +
                {
         | 
| 128 | 
            +
                  "cell_type": "code",
         | 
| 129 | 
            +
                  "source": [
         | 
| 130 | 
            +
                    "%%capture\n",
         | 
| 131 | 
            +
                    "!apt install python-opengl\n",
         | 
| 132 | 
            +
                    "!apt install ffmpeg\n",
         | 
| 133 | 
            +
                    "!apt install xvfb\n",
         | 
| 134 | 
            +
                    "!apt install swig"
         | 
| 135 | 
            +
                  ],
         | 
| 136 | 
            +
                  "metadata": {
         | 
| 137 | 
            +
                    "id": "nn6EETTc2Ewf"
         | 
| 138 | 
            +
                  },
         | 
| 139 | 
            +
                  "execution_count": 4,
         | 
| 140 | 
            +
                  "outputs": []
         | 
| 141 | 
            +
                },
         | 
| 142 | 
            +
                {
         | 
| 143 | 
            +
                  "cell_type": "code",
         | 
| 144 | 
            +
                  "source": [
         | 
| 145 | 
            +
                    "%%capture\n",
         | 
| 146 | 
            +
                    "%cd /content/rl-algo-impls\n",
         | 
| 147 | 
            +
                    "!pip install -r colab_requirements.txt"
         | 
| 148 | 
            +
                  ],
         | 
| 149 | 
            +
                  "metadata": {
         | 
| 150 | 
            +
                    "id": "AfZh9rH3yQii"
         | 
| 151 | 
            +
                  },
         | 
| 152 | 
            +
                  "execution_count": 5,
         | 
| 153 | 
            +
                  "outputs": []
         | 
| 154 | 
            +
                },
         | 
| 155 | 
            +
                {
         | 
| 156 | 
            +
                  "cell_type": "markdown",
         | 
| 157 | 
            +
                  "source": [
         | 
| 158 | 
            +
                    "## Run Once Per Runtime"
         | 
| 159 | 
            +
                  ],
         | 
| 160 | 
            +
                  "metadata": {
         | 
| 161 | 
            +
                    "id": "4o5HOLjc4wq7"
         | 
| 162 | 
            +
                  }
         | 
| 163 | 
            +
                },
         | 
| 164 | 
            +
                {
         | 
| 165 | 
            +
                  "cell_type": "code",
         | 
| 166 | 
            +
                  "source": [
         | 
| 167 | 
            +
                    "import wandb\n",
         | 
| 168 | 
            +
                    "wandb.login()"
         | 
| 169 | 
            +
                  ],
         | 
| 170 | 
            +
                  "metadata": {
         | 
| 171 | 
            +
                    "id": "PCXa5tdS2qFX"
         | 
| 172 | 
            +
                  },
         | 
| 173 | 
            +
                  "execution_count": null,
         | 
| 174 | 
            +
                  "outputs": []
         | 
| 175 | 
            +
                },
         | 
| 176 | 
            +
                {
         | 
| 177 | 
            +
                  "cell_type": "markdown",
         | 
| 178 | 
            +
                  "source": [
         | 
| 179 | 
            +
                    "## Restart Session beteween runs"
         | 
| 180 | 
            +
                  ],
         | 
| 181 | 
            +
                  "metadata": {
         | 
| 182 | 
            +
                    "id": "AZBZfSUV43JQ"
         | 
| 183 | 
            +
                  }
         | 
| 184 | 
            +
                },
         | 
| 185 | 
            +
                {
         | 
| 186 | 
            +
                  "cell_type": "code",
         | 
| 187 | 
            +
                  "source": [
         | 
| 188 | 
            +
                    "%%capture\n",
         | 
| 189 | 
            +
                    "from pyvirtualdisplay import Display\n",
         | 
| 190 | 
            +
                    "\n",
         | 
| 191 | 
            +
                    "virtual_display = Display(visible=0, size=(1400, 900))\n",
         | 
| 192 | 
            +
                    "virtual_display.start()"
         | 
| 193 | 
            +
                  ],
         | 
| 194 | 
            +
                  "metadata": {
         | 
| 195 | 
            +
                    "id": "VzemeQJP2NO9"
         | 
| 196 | 
            +
                  },
         | 
| 197 | 
            +
                  "execution_count": 7,
         | 
| 198 | 
            +
                  "outputs": []
         | 
| 199 | 
            +
                },
         | 
| 200 | 
            +
                {
         | 
| 201 | 
            +
                  "cell_type": "code",
         | 
| 202 | 
            +
                  "source": [
         | 
| 203 | 
            +
                    "%cd /content/rl-algo-impls\n",
         | 
| 204 | 
            +
                    "!python enjoy.py --wandb-run-path={WANDB_RUN_PATH}"
         | 
| 205 | 
            +
                  ],
         | 
| 206 | 
            +
                  "metadata": {
         | 
| 207 | 
            +
                    "id": "07aHYFH1zfXa"
         | 
| 208 | 
            +
                  },
         | 
| 209 | 
            +
                  "execution_count": null,
         | 
| 210 | 
            +
                  "outputs": []
         | 
| 211 | 
            +
                }
         | 
| 212 | 
            +
              ]
         | 
| 213 | 
            +
            }
         | 
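The WANDB_RUN_PATH set near the top of this notebook is wandb's "entity/project/run_id" triple; enjoy.py (added later in this commit) passes it to wandb.Api().run(...) to recover the run's algo, env, and model files. A minimal sketch of that lookup, assuming the run is accessible from your wandb account:

    import wandb

    # Same run path the notebook uses; enjoy.py reads algo/env out of run.config.
    run = wandb.Api().run("sgoodfriend/rl-algo-impls-benchmarks/rd0lisee")
    print(run.config["algo"], run.config["env"])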
    	
        colab_requirements.txt
    ADDED
    
    | @@ -0,0 +1,7 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            AutoROM.accept-rom-license >= 0.4.2, < 0.5
         | 
| 2 | 
            +
            stable-baselines3[extra] >= 1.7.0, < 1.8
         | 
| 3 | 
            +
            gym[box2d] >= 0.21.0, < 0.22
         | 
| 4 | 
            +
            pyglet == 1.5.27
         | 
| 5 | 
            +
            wandb >= 0.13.9, < 0.14
         | 
| 6 | 
            +
            pyvirtualdisplay == 3.0
         | 
| 7 | 
            +
            pybullet >= 3.2.5, < 3.3
         | 
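After the notebooks' pip install cell runs, a quick sanity check that the pinned ranges above actually resolved (a sketch, not part of the commit):

    import gym, stable_baselines3, wandb

    # Expect gym 0.21.x, stable-baselines3 1.7.x, wandb 0.13.x per the pins above.
    print(gym.__version__, stable_baselines3.__version__, wandb.__version__)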
    	
        colab_train.ipynb
    ADDED
    
    | @@ -0,0 +1,215 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
              "nbformat": 4,
         | 
| 3 | 
            +
              "nbformat_minor": 0,
         | 
| 4 | 
            +
              "metadata": {
         | 
| 5 | 
            +
                "colab": {
         | 
| 6 | 
            +
                  "provenance": [],
         | 
| 7 | 
            +
                  "machine_shape": "hm",
         | 
| 8 | 
            +
                  "authorship_tag": "ABX9TyNGs5TudweZiYKySQxg6H+K",
         | 
| 9 | 
            +
                  "include_colab_link": true
         | 
| 10 | 
            +
                },
         | 
| 11 | 
            +
                "kernelspec": {
         | 
| 12 | 
            +
                  "name": "python3",
         | 
| 13 | 
            +
                  "display_name": "Python 3"
         | 
| 14 | 
            +
                },
         | 
| 15 | 
            +
                "language_info": {
         | 
| 16 | 
            +
                  "name": "python"
         | 
| 17 | 
            +
                },
         | 
| 18 | 
            +
                "gpuClass": "standard",
         | 
| 19 | 
            +
                "accelerator": "GPU"
         | 
| 20 | 
            +
              },
         | 
| 21 | 
            +
              "cells": [
         | 
| 22 | 
            +
                {
         | 
| 23 | 
            +
                  "cell_type": "markdown",
         | 
| 24 | 
            +
                  "metadata": {
         | 
| 25 | 
            +
                    "id": "view-in-github",
         | 
| 26 | 
            +
                    "colab_type": "text"
         | 
| 27 | 
            +
                  },
         | 
| 28 | 
            +
                  "source": [
         | 
| 29 | 
            +
                    "<a href=\"https://colab.research.google.com/github/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
         | 
| 30 | 
            +
                  ]
         | 
| 31 | 
            +
                },
         | 
| 32 | 
            +
                {
         | 
| 33 | 
            +
                  "cell_type": "markdown",
         | 
| 34 | 
            +
                  "source": [
         | 
| 35 | 
            +
                    "# [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) in Google Colaboratory\n",
         | 
| 36 | 
            +
                    "## Parameters\n",
         | 
| 37 | 
            +
                    "\n",
         | 
| 38 | 
            +
                    "\n",
         | 
| 39 | 
            +
                    "1.   Wandb\n",
         | 
| 40 | 
            +
                    "\n"
         | 
| 41 | 
            +
                  ],
         | 
| 42 | 
            +
                  "metadata": {
         | 
| 43 | 
            +
                    "id": "S-tXDWP8WTLc"
         | 
| 44 | 
            +
                  }
         | 
| 45 | 
            +
                },
         | 
| 46 | 
            +
                {
         | 
| 47 | 
            +
                  "cell_type": "code",
         | 
| 48 | 
            +
                  "source": [
         | 
| 49 | 
            +
                    "from getpass import getpass\n",
         | 
| 50 | 
            +
                    "import os\n",
         | 
| 51 | 
            +
                    "os.environ[\"WANDB_API_KEY\"] = getpass(\"Wandb API key to upload metrics, videos, and models: \")"
         | 
| 52 | 
            +
                  ],
         | 
| 53 | 
            +
                  "metadata": {
         | 
| 54 | 
            +
                    "id": "1ZtdYgxWNGwZ"
         | 
| 55 | 
            +
                  },
         | 
| 56 | 
            +
                  "execution_count": null,
         | 
| 57 | 
            +
                  "outputs": []
         | 
| 58 | 
            +
                },
         | 
| 59 | 
            +
                {
         | 
| 60 | 
            +
                  "cell_type": "markdown",
         | 
| 61 | 
            +
                  "source": [
         | 
| 62 | 
            +
                    "2. train run parameters"
         | 
| 63 | 
            +
                  ],
         | 
| 64 | 
            +
                  "metadata": {
         | 
| 65 | 
            +
                    "id": "ao0nAh3MOdN7"
         | 
| 66 | 
            +
                  }
         | 
| 67 | 
            +
                },
         | 
| 68 | 
            +
                {
         | 
| 69 | 
            +
                  "cell_type": "code",
         | 
| 70 | 
            +
                  "source": [
         | 
| 71 | 
            +
                    "ALGO = \"ppo\"\n",
         | 
| 72 | 
            +
                    "ENV = \"CartPole-v1\"\n",
         | 
| 73 | 
            +
                    "SEED = 1"
         | 
| 74 | 
            +
                  ],
         | 
| 75 | 
            +
                  "metadata": {
         | 
| 76 | 
            +
                    "id": "jKL_NFhVOjSc"
         | 
| 77 | 
            +
                  },
         | 
| 78 | 
            +
                  "execution_count": 2,
         | 
| 79 | 
            +
                  "outputs": []
         | 
| 80 | 
            +
                },
         | 
| 81 | 
            +
                {
         | 
| 82 | 
            +
                  "cell_type": "markdown",
         | 
| 83 | 
            +
                  "source": [
         | 
| 84 | 
            +
                    "## Setup\n",
         | 
| 85 | 
            +
                    "Clone [sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) "
         | 
| 86 | 
            +
                  ],
         | 
| 87 | 
            +
                  "metadata": {
         | 
| 88 | 
            +
                    "id": "bsG35Io0hmKG"
         | 
| 89 | 
            +
                  }
         | 
| 90 | 
            +
                },
         | 
| 91 | 
            +
                {
         | 
| 92 | 
            +
                  "cell_type": "code",
         | 
| 93 | 
            +
                  "source": [
         | 
| 94 | 
            +
                    "%%capture\n",
         | 
| 95 | 
            +
                    "!mkdir -p ~/.ssh\n",
         | 
| 96 | 
            +
                    "\n",
         | 
| 97 | 
            +
                    "with open(\"/root/.ssh/id_ed25519\", mode=\"w\") as f:\n",
         | 
| 98 | 
            +
                    "    f.write(\"\"\"-----BEGIN OPENSSH PRIVATE KEY-----\n",
         | 
| 99 | 
            +
                    "b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW\n",
         | 
| 100 | 
            +
                    "QyNTUxOQAAACAkIepH6T90umhyp8+bkmSplqth1/+yxnu/Dax61KlSVAAAAKA4W3D3OFtw\n",
         | 
| 101 | 
            +
                    "9wAAAAtzc2gtZWQyNTUxOQAAACAkIepH6T90umhyp8+bkmSplqth1/+yxnu/Dax61KlSVA\n",
         | 
| 102 | 
            +
                    "AAAEA4SPGDm0/gofiOYXPTAi1Oxmw4mTppG2GdNgdMwMiDaSQh6kfpP3S6aHKnz5uSZKmW\n",
         | 
| 103 | 
            +
                    "q2HX/7LGe78NrHrUqVJUAAAAGmdvb2RmcmllbmQuc2NvdHRAZ21haWwuY29tAQID\n",
         | 
| 104 | 
            +
                    "-----END OPENSSH PRIVATE KEY-----\n",
         | 
| 105 | 
            +
                    "\"\"\"\n",
         | 
| 106 | 
            +
                    ")\n",
         | 
| 107 | 
            +
                    "\n",
         | 
| 108 | 
            +
                    "!ssh-keyscan -t ed25519 github.com >> ~/.ssh/known_hosts\n",
         | 
| 109 | 
            +
                    "!chmod go-rwx /root/.ssh/id_ed25519\n",
         | 
| 110 | 
            +
                    "!git clone git@github.com:sgoodfriend/rl-algo-impls.git"
         | 
| 111 | 
            +
                  ],
         | 
| 112 | 
            +
                  "metadata": {
         | 
| 113 | 
            +
                    "id": "k5ynTV25hdAf"
         | 
| 114 | 
            +
                  },
         | 
| 115 | 
            +
                  "execution_count": 3,
         | 
| 116 | 
            +
                  "outputs": []
         | 
| 117 | 
            +
                },
         | 
| 118 | 
            +
                {
         | 
| 119 | 
            +
                  "cell_type": "markdown",
         | 
| 120 | 
            +
                  "source": [
         | 
| 121 | 
            +
                    "Installing the correct packages:\n",
         | 
| 122 | 
            +
                    "\n",
         | 
| 123 | 
            +
                    "While conda and poetry are generally used for package management, the mismatch in Python versions (3.10 in the project file vs 3.8 in Colab) makes using the package yml files difficult to use. For now, instead I'm going to specify the list of requirements manually below:"
         | 
| 124 | 
            +
                  ],
         | 
| 125 | 
            +
                  "metadata": {
         | 
| 126 | 
            +
                    "id": "jKxGok-ElYQ7"
         | 
| 127 | 
            +
                  }
         | 
| 128 | 
            +
                },
         | 
| 129 | 
            +
                {
         | 
| 130 | 
            +
                  "cell_type": "code",
         | 
| 131 | 
            +
                  "source": [
         | 
| 132 | 
            +
                    "%%capture\n",
         | 
| 133 | 
            +
                    "!apt install python-opengl\n",
         | 
| 134 | 
            +
                    "!apt install ffmpeg\n",
         | 
| 135 | 
            +
                    "!apt install xvfb\n",
         | 
| 136 | 
            +
                    "!apt install swig"
         | 
| 137 | 
            +
                  ],
         | 
| 138 | 
            +
                  "metadata": {
         | 
| 139 | 
            +
                    "id": "nn6EETTc2Ewf"
         | 
| 140 | 
            +
                  },
         | 
| 141 | 
            +
                  "execution_count": 4,
         | 
| 142 | 
            +
                  "outputs": []
         | 
| 143 | 
            +
                },
         | 
| 144 | 
            +
                {
         | 
| 145 | 
            +
                  "cell_type": "code",
         | 
| 146 | 
            +
                  "source": [
         | 
| 147 | 
            +
                    "%%capture\n",
         | 
| 148 | 
            +
                    "%cd /content/rl-algo-impls\n",
         | 
| 149 | 
            +
                    "!pip install -r colab_requirements.txt"
         | 
| 150 | 
            +
                  ],
         | 
| 151 | 
            +
                  "metadata": {
         | 
| 152 | 
            +
                    "id": "AfZh9rH3yQii"
         | 
| 153 | 
            +
                  },
         | 
| 154 | 
            +
                  "execution_count": 5,
         | 
| 155 | 
            +
                  "outputs": []
         | 
| 156 | 
            +
                },
         | 
| 157 | 
            +
                {
         | 
| 158 | 
            +
                  "cell_type": "markdown",
         | 
| 159 | 
            +
                  "source": [
         | 
| 160 | 
            +
                    "## Run Once Per Runtime"
         | 
| 161 | 
            +
                  ],
         | 
| 162 | 
            +
                  "metadata": {
         | 
| 163 | 
            +
                    "id": "4o5HOLjc4wq7"
         | 
| 164 | 
            +
                  }
         | 
| 165 | 
            +
                },
         | 
| 166 | 
            +
                {
         | 
| 167 | 
            +
                  "cell_type": "code",
         | 
| 168 | 
            +
                  "source": [
         | 
| 169 | 
            +
                    "import wandb\n",
         | 
| 170 | 
            +
                    "wandb.login()"
         | 
| 171 | 
            +
                  ],
         | 
| 172 | 
            +
                  "metadata": {
         | 
| 173 | 
            +
                    "id": "PCXa5tdS2qFX"
         | 
| 174 | 
            +
                  },
         | 
| 175 | 
            +
                  "execution_count": null,
         | 
| 176 | 
            +
                  "outputs": []
         | 
| 177 | 
            +
                },
         | 
| 178 | 
            +
                {
         | 
| 179 | 
            +
                  "cell_type": "markdown",
         | 
| 180 | 
            +
                  "source": [
         | 
| 181 | 
            +
                    "## Restart Session beteween runs"
         | 
| 182 | 
            +
                  ],
         | 
| 183 | 
            +
                  "metadata": {
         | 
| 184 | 
            +
                    "id": "AZBZfSUV43JQ"
         | 
| 185 | 
            +
                  }
         | 
| 186 | 
            +
                },
         | 
| 187 | 
            +
                {
         | 
| 188 | 
            +
                  "cell_type": "code",
         | 
| 189 | 
            +
                  "source": [
         | 
| 190 | 
            +
                    "%%capture\n",
         | 
| 191 | 
            +
                    "from pyvirtualdisplay import Display\n",
         | 
| 192 | 
            +
                    "\n",
         | 
| 193 | 
            +
                    "virtual_display = Display(visible=0, size=(1400, 900))\n",
         | 
| 194 | 
            +
                    "virtual_display.start()"
         | 
| 195 | 
            +
                  ],
         | 
| 196 | 
            +
                  "metadata": {
         | 
| 197 | 
            +
                    "id": "VzemeQJP2NO9"
         | 
| 198 | 
            +
                  },
         | 
| 199 | 
            +
                  "execution_count": 7,
         | 
| 200 | 
            +
                  "outputs": []
         | 
| 201 | 
            +
                },
         | 
| 202 | 
            +
                {
         | 
| 203 | 
            +
                  "cell_type": "code",
         | 
| 204 | 
            +
                  "source": [
         | 
| 205 | 
            +
                    "%cd /content/rl-algo-impls\n",
         | 
| 206 | 
            +
                    "!python train.py --algo {ALGO} --env {ENV} --seed {SEED}"
         | 
| 207 | 
            +
                  ],
         | 
| 208 | 
            +
                  "metadata": {
         | 
| 209 | 
            +
                    "id": "07aHYFH1zfXa"
         | 
| 210 | 
            +
                  },
         | 
| 211 | 
            +
                  "execution_count": null,
         | 
| 212 | 
            +
                  "outputs": []
         | 
| 213 | 
            +
                }
         | 
| 214 | 
            +
              ]
         | 
| 215 | 
            +
            }
         | 
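The ALGO/ENV/SEED cell near the top of this notebook selects which block of the hyperparams yml files train.py loads. As a sketch, any algorithm/environment pair defined in those files should work; for example, switching to the DQN MountainCar settings that appear later in hyperparams/dqn.yml:

    ALGO = "dqn"
    ENV = "MountainCar-v0"  # defined in hyperparams/dqn.yml
    SEED = 1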
    	
        dqn/dqn.py
    ADDED
    
    | @@ -0,0 +1,182 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import copy
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            import random
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import torch.nn as nn
         | 
| 6 | 
            +
            import torch.nn.functional as F
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from collections import deque
         | 
| 9 | 
            +
            from torch.optim import Adam
         | 
| 10 | 
            +
            from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
         | 
| 11 | 
            +
            from torch.utils.tensorboard.writer import SummaryWriter
         | 
| 12 | 
            +
            from typing import List, NamedTuple, Optional, TypeVar
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from dqn.policy import DQNPolicy
         | 
| 15 | 
            +
            from shared.algorithm import Algorithm
         | 
| 16 | 
            +
            from shared.callbacks.callback import Callback
         | 
| 17 | 
            +
            from shared.schedule import linear_schedule
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            class Transition(NamedTuple):
         | 
| 21 | 
            +
                obs: np.ndarray
         | 
| 22 | 
            +
                action: np.ndarray
         | 
| 23 | 
            +
                reward: float
         | 
| 24 | 
            +
                done: bool
         | 
| 25 | 
            +
                next_obs: np.ndarray
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            class Batch(NamedTuple):
         | 
| 29 | 
            +
                obs: np.ndarray
         | 
| 30 | 
            +
                actions: np.ndarray
         | 
| 31 | 
            +
                rewards: np.ndarray
         | 
| 32 | 
            +
                dones: np.ndarray
         | 
| 33 | 
            +
                next_obs: np.ndarray
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            class ReplayBuffer:
         | 
| 37 | 
            +
                def __init__(self, num_envs: int, maxlen: int) -> None:
         | 
| 38 | 
            +
                    self.num_envs = num_envs
         | 
| 39 | 
            +
                    self.buffer = deque(maxlen=maxlen)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                def add(
         | 
| 42 | 
            +
                    self,
         | 
| 43 | 
            +
                    obs: VecEnvObs,
         | 
| 44 | 
            +
                    action: np.ndarray,
         | 
| 45 | 
            +
                    reward: np.ndarray,
         | 
| 46 | 
            +
                    done: np.ndarray,
         | 
| 47 | 
            +
                    next_obs: VecEnvObs,
         | 
| 48 | 
            +
                ) -> None:
         | 
| 49 | 
            +
                    assert isinstance(obs, np.ndarray)
         | 
| 50 | 
            +
                    assert isinstance(next_obs, np.ndarray)
         | 
| 51 | 
            +
                    for i in range(self.num_envs):
         | 
| 52 | 
            +
                        self.buffer.append(
         | 
| 53 | 
            +
                            Transition(obs[i], action[i], reward[i], done[i], next_obs[i])
         | 
| 54 | 
            +
                        )
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                def sample(self, batch_size: int) -> Batch:
         | 
| 57 | 
            +
                    ts = random.sample(self.buffer, batch_size)
         | 
| 58 | 
            +
                    return Batch(
         | 
| 59 | 
            +
                        obs=np.array([t.obs for t in ts]),
         | 
| 60 | 
            +
                        actions=np.array([t.action for t in ts]),
         | 
| 61 | 
            +
                        rewards=np.array([t.reward for t in ts]),
         | 
| 62 | 
            +
                        dones=np.array([t.done for t in ts]),
         | 
| 63 | 
            +
                        next_obs=np.array([t.next_obs for t in ts]),
         | 
| 64 | 
            +
                    )
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                def __len__(self) -> int:
         | 
| 67 | 
            +
                    return len(self.buffer)
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            DQNSelf = TypeVar("DQNSelf", bound="DQN")
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            class DQN(Algorithm):
         | 
| 74 | 
            +
                def __init__(
         | 
| 75 | 
            +
                    self,
         | 
| 76 | 
            +
                    policy: DQNPolicy,
         | 
| 77 | 
            +
                    env: VecEnv,
         | 
| 78 | 
            +
                    device: torch.device,
         | 
| 79 | 
            +
                    tb_writer: SummaryWriter,
         | 
| 80 | 
            +
                    learning_rate: float = 1e-4,
         | 
| 81 | 
            +
                    buffer_size: int = 1_000_000,
         | 
| 82 | 
            +
                    learning_starts: int = 50_000,
         | 
| 83 | 
            +
                    batch_size: int = 32,
         | 
| 84 | 
            +
                    tau: float = 1.0,
         | 
| 85 | 
            +
                    gamma: float = 0.99,
         | 
| 86 | 
            +
                    train_freq: int = 4,
         | 
| 87 | 
            +
                    gradient_steps: int = 1,
         | 
| 88 | 
            +
                    target_update_interval: int = 10_000,
         | 
| 89 | 
            +
                    exploration_fraction: float = 0.1,
         | 
| 90 | 
            +
                    exploration_initial_eps: float = 1.0,
         | 
| 91 | 
            +
                    exploration_final_eps: float = 0.05,
         | 
| 92 | 
            +
                    max_grad_norm: float = 10.0,
         | 
| 93 | 
            +
                ) -> None:
         | 
| 94 | 
            +
                    super().__init__(policy, env, device, tb_writer)
         | 
| 95 | 
            +
                    self.policy = policy
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                    self.optimizer = Adam(self.policy.q_net.parameters(), lr=learning_rate)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    self.target_q_net = copy.deepcopy(self.policy.q_net).to(self.device)
         | 
| 100 | 
            +
                    self.target_q_net.train(False)
         | 
| 101 | 
            +
                    self.tau = tau
         | 
| 102 | 
            +
                    self.target_update_interval = target_update_interval
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                    self.replay_buffer = ReplayBuffer(self.env.num_envs, buffer_size)
         | 
| 105 | 
            +
                    self.batch_size = batch_size
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    self.learning_starts = learning_starts
         | 
| 108 | 
            +
                    self.train_freq = train_freq
         | 
| 109 | 
            +
                    self.gradient_steps = gradient_steps
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    self.gamma = gamma
         | 
| 112 | 
            +
                    self.exploration_eps_schedule = linear_schedule(
         | 
| 113 | 
            +
                        exploration_initial_eps,
         | 
| 114 | 
            +
                        exploration_final_eps,
         | 
| 115 | 
            +
                        end_fraction=exploration_fraction,
         | 
| 116 | 
            +
                    )
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    self.max_grad_norm = max_grad_norm
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                def learn(
         | 
| 121 | 
            +
                    self: DQNSelf, total_timesteps: int, callback: Optional[Callback] = None
         | 
| 122 | 
            +
                ) -> DQNSelf:
         | 
| 123 | 
            +
                    self.policy.train(True)
         | 
| 124 | 
            +
                    obs = self.env.reset()
         | 
| 125 | 
            +
                    obs = self._collect_rollout(self.learning_starts, obs, 1)
         | 
| 126 | 
            +
                    learning_steps = total_timesteps - self.learning_starts
         | 
| 127 | 
            +
                    timesteps_elapsed = 0
         | 
| 128 | 
            +
                    steps_since_target_update = 0
         | 
| 129 | 
            +
                    while timesteps_elapsed < learning_steps:
         | 
| 130 | 
            +
                        progress = timesteps_elapsed / learning_steps
         | 
| 131 | 
            +
                        eps = self.exploration_eps_schedule(progress)
         | 
| 132 | 
            +
                        obs = self._collect_rollout(self.train_freq, obs, eps)
         | 
| 133 | 
            +
                        rollout_steps = self.train_freq
         | 
| 134 | 
            +
                        timesteps_elapsed += rollout_steps
         | 
| 135 | 
            +
                        for _ in range(
         | 
| 136 | 
            +
                            self.gradient_steps if self.gradient_steps > 0 else self.train_freq
         | 
| 137 | 
            +
                        ):
         | 
| 138 | 
            +
                            self.train()
         | 
| 139 | 
            +
                        steps_since_target_update += rollout_steps
         | 
| 140 | 
            +
                        if steps_since_target_update >= self.target_update_interval:
         | 
| 141 | 
            +
                            self._update_target()
         | 
| 142 | 
            +
                            steps_since_target_update = 0
         | 
| 143 | 
            +
                        if callback:
         | 
| 144 | 
            +
                            callback.on_step(timesteps_elapsed=rollout_steps)
         | 
| 145 | 
            +
                    return self
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                def train(self) -> None:
         | 
| 148 | 
            +
                    if len(self.replay_buffer) < self.batch_size:
         | 
| 149 | 
            +
                        return
         | 
| 150 | 
            +
                    o, a, r, d, next_o = self.replay_buffer.sample(self.batch_size)
         | 
| 151 | 
            +
                    o = torch.as_tensor(o, device=self.device)
         | 
| 152 | 
            +
                    a = torch.as_tensor(a, device=self.device).unsqueeze(1)
         | 
| 153 | 
            +
                    r = torch.as_tensor(r, dtype=torch.float32, device=self.device)
         | 
| 154 | 
            +
                    d = torch.as_tensor(d, dtype=torch.long, device=self.device)
         | 
| 155 | 
            +
                    next_o = torch.as_tensor(next_o, device=self.device)
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                    with torch.no_grad():
         | 
| 158 | 
            +
                        target = r + (1 - d) * self.gamma * self.target_q_net(next_o).max(1).values
         | 
| 159 | 
            +
                    current = self.policy.q_net(o).gather(dim=1, index=a).squeeze(1)
         | 
| 160 | 
            +
                    loss = F.smooth_l1_loss(current, target)
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    self.optimizer.zero_grad()
         | 
| 163 | 
            +
                    loss.backward()
         | 
| 164 | 
            +
                    if self.max_grad_norm:
         | 
| 165 | 
            +
                        nn.utils.clip_grad_norm_(self.policy.q_net.parameters(), self.max_grad_norm)
         | 
| 166 | 
            +
                    self.optimizer.step()
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                def _collect_rollout(self, timesteps: int, obs: VecEnvObs, eps: float) -> VecEnvObs:
         | 
| 169 | 
            +
                    for _ in range(0, timesteps, self.env.num_envs):
         | 
| 170 | 
            +
                        action = self.policy.act(obs, eps, deterministic=False)
         | 
| 171 | 
            +
                        next_obs, reward, done, _ = self.env.step(action)
         | 
| 172 | 
            +
                        self.replay_buffer.add(obs, action, reward, done, next_obs)
         | 
| 173 | 
            +
                        obs = next_obs
         | 
| 174 | 
            +
                    return obs
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                def _update_target(self) -> None:
         | 
| 177 | 
            +
                    for target_param, param in zip(
         | 
| 178 | 
            +
                        self.target_q_net.parameters(), self.policy.q_net.parameters()
         | 
| 179 | 
            +
                    ):
         | 
| 180 | 
            +
                        target_param.data.copy_(
         | 
| 181 | 
            +
                            self.tau * param.data + (1 - self.tau) * target_param.data
         | 
| 182 | 
            +
                        )
         | 
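DQN.train above computes the one-step Bellman target r + (1 - done) * gamma * max_a Q_target(s', a) and regresses the online Q-network toward it with a smooth L1 loss. A minimal sketch of that target computation with invented values:

    import torch

    gamma = 0.99
    rewards = torch.tensor([1.0, 0.0, 1.0])  # hypothetical batch of 3 transitions
    dones = torch.tensor([0.0, 1.0, 0.0])    # the second transition ended its episode
    next_q = torch.tensor([[0.5, 1.5],       # Q_target(s', a) for 2 discrete actions
                           [2.0, 0.1],
                           [0.3, 0.2]])

    # Same expression as in DQN.train: terminal transitions contribute only the reward.
    target = rewards + (1 - dones) * gamma * next_q.max(1).values
    # -> tensor([2.4850, 0.0000, 1.2970])

With the default tau of 1.0, _update_target is a hard copy of the online weights into the target network; values below 1.0 turn it into a Polyak-averaged soft update.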
    	
        dqn/policy.py
    ADDED
    
    | @@ -0,0 +1,37 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import numpy as np
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
         | 
| 6 | 
            +
            from typing import Sequence, TypeVar
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from dqn.q_net import QNetwork
         | 
| 9 | 
            +
            from shared.policy.policy import Policy
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            DQNPolicySelf = TypeVar("DQNPolicySelf", bound="DQNPolicy")
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class DQNPolicy(Policy):
         | 
| 15 | 
            +
                def __init__(
         | 
| 16 | 
            +
                    self,
         | 
| 17 | 
            +
                    env: VecEnv,
         | 
| 18 | 
            +
                    hidden_sizes: Sequence[int],
         | 
| 19 | 
            +
                    **kwargs,
         | 
| 20 | 
            +
                ) -> None:
         | 
| 21 | 
            +
                    super().__init__(env, **kwargs)
         | 
| 22 | 
            +
                    self.q_net = QNetwork(env.observation_space, env.action_space, hidden_sizes)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def act(
         | 
| 25 | 
            +
                    self, obs: VecEnvObs, eps: float = 0, deterministic: bool = True
         | 
| 26 | 
            +
                ) -> np.ndarray:
         | 
| 27 | 
            +
                    assert eps == 0 if deterministic else eps >= 0
         | 
| 28 | 
            +
                    if not deterministic and np.random.random() < eps:
         | 
| 29 | 
            +
                        return np.array(
         | 
| 30 | 
            +
                            [self.env.action_space.sample() for _ in range(self.env.num_envs)]
         | 
| 31 | 
            +
                        )
         | 
| 32 | 
            +
                    else:
         | 
| 33 | 
            +
                        with torch.no_grad():
         | 
| 34 | 
            +
                            obs_th = torch.as_tensor(np.array(obs))
         | 
| 35 | 
            +
                            if self.device:
         | 
| 36 | 
            +
                                obs_th = obs_th.to(self.device)
         | 
| 37 | 
            +
                            return self.q_net(obs_th).argmax(axis=1).cpu().numpy()
         | 
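DQNPolicy.act is an epsilon-greedy rule: with probability eps it samples a random action per environment, otherwise it takes the argmax over the Q-network's outputs. A self-contained numpy sketch of the same selection logic (values illustrative):

    import numpy as np

    rng = np.random.default_rng(0)
    eps = 0.1
    q_values = np.array([[0.2, 1.3, -0.5]])  # one env, three discrete actions

    if rng.random() < eps:
        # Exploration branch: uniform random action for each env.
        action = rng.integers(q_values.shape[1], size=q_values.shape[0])
    else:
        # Exploitation branch, mirroring q_net(obs).argmax(axis=1).
        action = q_values.argmax(axis=1)
    # action == array([1]) unless the 10% exploration branch fires.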
    	
        dqn/q_net.py
    ADDED
    
    | @@ -0,0 +1,29 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import gym
         | 
| 2 | 
            +
            import torch as th
         | 
| 3 | 
            +
            import torch.nn as nn
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from gym.spaces import Discrete
         | 
| 6 | 
            +
            from typing import Sequence, Type
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from shared.module import FeatureExtractor, mlp
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            class QNetwork(nn.Module):
         | 
| 12 | 
            +
                def __init__(
         | 
| 13 | 
            +
                    self,
         | 
| 14 | 
            +
                    observation_space: gym.Space,
         | 
| 15 | 
            +
                    action_space: gym.Space,
         | 
| 16 | 
            +
                    hidden_sizes: Sequence[int],
         | 
| 17 | 
            +
                    activation: Type[nn.Module] = nn.ReLU,  # Used by stable-baselines3
         | 
| 18 | 
            +
                ) -> None:
         | 
| 19 | 
            +
                    super().__init__()
         | 
| 20 | 
            +
                    assert isinstance(action_space, Discrete)
         | 
| 21 | 
            +
                    self._feature_extractor = FeatureExtractor(observation_space, activation)
         | 
| 22 | 
            +
                    layer_sizes = (
         | 
| 23 | 
            +
                        (self._feature_extractor.out_dim,) + tuple(hidden_sizes) + (action_space.n,)
         | 
| 24 | 
            +
                    )
         | 
| 25 | 
            +
                    self._fc = mlp(layer_sizes, activation)
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                def forward(self, obs: th.Tensor) -> th.Tensor:
         | 
| 28 | 
            +
                    x = self._feature_extractor(obs)
         | 
| 29 | 
            +
                    return self._fc(x)
         | 
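QNetwork delegates to FeatureExtractor and mlp from shared/module.py, which are outside this section. As an assumption-labeled sketch, an mlp helper with this call signature typically stacks Linear layers with the activation between hidden layers and a bare linear output:

    import torch.nn as nn
    from typing import Sequence, Type

    def mlp(layer_sizes: Sequence[int], activation: Type[nn.Module]) -> nn.Sequential:
        # Hypothetical reconstruction; the actual helper lives in shared/module.py.
        layers = []
        for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(activation())
        layers.pop()  # no activation after the output layer (raw Q-values)
        return nn.Sequential(*layers)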
    	
        enjoy.py
    ADDED
    
    | @@ -0,0 +1,105 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import shutil
         | 
| 7 | 
            +
            import yaml
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from dataclasses import dataclass
         | 
| 10 | 
            +
            from typing import Optional
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from runner.env import make_eval_env
         | 
| 13 | 
            +
            from runner.config import Config, RunArgs
         | 
| 14 | 
            +
            from runner.running_utils import (
         | 
| 15 | 
            +
                base_parser,
         | 
| 16 | 
            +
                load_hyperparams,
         | 
| 17 | 
            +
                set_seeds,
         | 
| 18 | 
            +
                get_device,
         | 
| 19 | 
            +
                make_policy,
         | 
| 20 | 
            +
            )
         | 
| 21 | 
            +
            from shared.callbacks.eval_callback import evaluate
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            @dataclass
         | 
| 25 | 
            +
            class EvalArgs(RunArgs):
         | 
| 26 | 
            +
                render: bool = True
         | 
| 27 | 
            +
                best: bool = True
         | 
| 28 | 
            +
                n_envs: int = 1
         | 
| 29 | 
            +
                n_episodes: int = 3
         | 
| 30 | 
            +
                deterministic: Optional[bool] = None
         | 
| 31 | 
            +
                wandb_run_path: Optional[str] = None
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            if __name__ == "__main__":
         | 
| 35 | 
            +
                parser = base_parser()
         | 
| 36 | 
            +
                parser.add_argument("--render", default=True, type=bool)
         | 
| 37 | 
            +
                parser.add_argument("--best", default=True, type=bool)
         | 
| 38 | 
            +
                parser.add_argument("--n_envs", default=1, type=int)
         | 
| 39 | 
            +
                parser.add_argument("--n_episodes", default=3, type=int)
         | 
| 40 | 
            +
                parser.add_argument("--deterministic", default=None, type=bool)
         | 
| 41 | 
            +
                parser.add_argument("--wandb-run-path", default=None, type=str)
         | 
| 42 | 
            +
                parser.set_defaults(
         | 
| 43 | 
            +
                    wandb_run_path="sgoodfriend/rl-algo-impls/sfi78a3t",
         | 
| 44 | 
            +
                )
         | 
| 45 | 
            +
                args = EvalArgs(**vars(parser.parse_args()))
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                if args.wandb_run_path:
         | 
| 48 | 
            +
                    import wandb
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    api = wandb.Api()
         | 
| 51 | 
            +
                    run = api.run(args.wandb_run_path)
         | 
| 52 | 
            +
                    hyperparams = run.config
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                    args.algo = hyperparams["algo"]
         | 
| 55 | 
            +
                    args.env = hyperparams["env"]
         | 
| 56 | 
            +
                    args.use_deterministic_algorithms = hyperparams.get(
         | 
| 57 | 
            +
                        "use_deterministic_algorithms", True
         | 
| 58 | 
            +
                    )
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    config = Config(args, hyperparams, os.path.dirname(__file__))
         | 
| 61 | 
            +
                    model_path = config.model_dir_path(best=args.best, downloaded=True)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                    model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
         | 
| 64 | 
            +
                    run.file(model_archive_name).download()
         | 
| 65 | 
            +
                    if os.path.isdir(model_path):
         | 
| 66 | 
            +
                        shutil.rmtree(model_path)
         | 
| 67 | 
            +
                    shutil.unpack_archive(model_archive_name, model_path)
         | 
| 68 | 
            +
                    os.remove(model_archive_name)
         | 
| 69 | 
            +
                else:
         | 
| 70 | 
            +
                    hyperparams = load_hyperparams(args.algo, args.env, os.path.dirname(__file__))
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    config = Config(args, hyperparams, os.path.dirname(__file__))
         | 
| 73 | 
            +
                    model_path = config.model_dir_path(best=args.best)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                print(args)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                set_seeds(args.seed, args.use_deterministic_algorithms)
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                env = make_eval_env(
         | 
| 80 | 
            +
                    config,
         | 
| 81 | 
            +
                    override_n_envs=args.n_envs,
         | 
| 82 | 
            +
                    render=args.render,
         | 
| 83 | 
            +
                    normalize_load_path=model_path,
         | 
| 84 | 
            +
                    **config.env_hyperparams,
         | 
| 85 | 
            +
                )
         | 
| 86 | 
            +
                device = get_device(config.device, env)
         | 
| 87 | 
            +
                policy = make_policy(
         | 
| 88 | 
            +
                    args.algo,
         | 
| 89 | 
            +
                    env,
         | 
| 90 | 
            +
                    device,
         | 
| 91 | 
            +
                    load_path=model_path,
         | 
| 92 | 
            +
                    **config.policy_hyperparams,
         | 
| 93 | 
            +
                ).eval()
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                if args.deterministic is None:
         | 
| 96 | 
            +
                    deterministic = config.eval_params.get("deterministic", True)
         | 
| 97 | 
            +
                else:
         | 
| 98 | 
            +
                    deterministic = args.deterministic
         | 
| 99 | 
            +
                evaluate(
         | 
| 100 | 
            +
                    env,
         | 
| 101 | 
            +
                    policy,
         | 
| 102 | 
            +
                    args.n_episodes,
         | 
| 103 | 
            +
                    render=args.render,
         | 
| 104 | 
            +
                    deterministic=deterministic,
         | 
| 105 | 
            +
                )
         | 
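One caveat with the boolean flags above: argparse's type=bool applies bool() to the raw argument string, so any non-empty value (including "False") parses as True, and the flags effectively keep their defaults. A common workaround, sketched here rather than taken from the repo, is a small string-to-bool converter:

    def str2bool(value: str) -> bool:
        # Hypothetical helper, not part of this commit.
        return value.lower() in ("true", "1", "yes")

    # e.g. parser.add_argument("--render", default=True, type=str2bool)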
    	
        environment.yml
    ADDED
    
    | @@ -0,0 +1,17 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            name: rl_algo_impls
         | 
| 2 | 
            +
            channels:
         | 
| 3 | 
            +
              - pytorch
         | 
| 4 | 
            +
              - conda-forge
         | 
| 5 | 
            +
              - nodefaults
         | 
| 6 | 
            +
            dependencies:
         | 
| 7 | 
            +
              - python=3.10.*
         | 
| 8 | 
            +
              - mamba
         | 
| 9 | 
            +
              - pip
         | 
| 10 | 
            +
              - poetry
         | 
| 11 | 
            +
              - pytorch
         | 
| 12 | 
            +
              - torchvision
         | 
| 13 | 
            +
              - torchaudio
         | 
| 14 | 
            +
              - cmake
         | 
| 15 | 
            +
              - swig
         | 
| 16 | 
            +
              - ipywidgets
         | 
| 17 | 
            +
              - black
         | 
    	

hyperparams/dqn.yml ADDED
@@ -0,0 +1,117 @@
CartPole-v1: &cartpole-defaults
  n_timesteps: !!float 5e4
  env_hyperparams:
    n_envs: 1
    rolling_length: 50
  policy_hyperparams:
    hidden_sizes: [256, 256]
  algo_hyperparams:
    learning_rate: !!float 2.3e-3
    batch_size: 64
    buffer_size: 100000
    learning_starts: 1000
    gamma: 0.99
    target_update_interval: 10
    train_freq: 256
    gradient_steps: 128
    exploration_fraction: 0.16
    exploration_final_eps: 0.04
  eval_params:
    step_freq: !!float 1e4
    n_episodes: 10
    save_best: true

CartPole-v0:
  <<: *cartpole-defaults
  n_timesteps: !!float 4e4

MountainCar-v0:
  n_timesteps: !!float 1.2e5
  env_hyperparams:
    rolling_length: 50
  policy_hyperparams:
    hidden_sizes: [256, 256]
  algo_hyperparams:
    learning_rate: !!float 4e-3
    batch_size: 128
    buffer_size: 10000
    learning_starts: 1000
    gamma: 0.98
    target_update_interval: 600
    train_freq: 16
    gradient_steps: 8
    exploration_fraction: 0.2
    exploration_final_eps: 0.07

Acrobot-v1:
  n_timesteps: !!float 1e5
  env_hyperparams:
    rolling_length: 10
  policy_hyperparams:
    hidden_sizes: [256, 256]
  algo_hyperparams:
    learning_rate: !!float 6.3e-4
    batch_size: 128
    buffer_size: 50000
    learning_starts: 0
    gamma: 0.99
    target_update_interval: 250
    train_freq: 4
    gradient_steps: -1
    exploration_fraction: 0.12
    exploration_final_eps: 0.1

LunarLander-v2:
  n_timesteps: !!float 5e5
  env_hyperparams:
    rolling_length: 10
  policy_hyperparams:
    hidden_sizes: [256, 256]
  algo_hyperparams:
    learning_rate: !!float 1e-4
    batch_size: 256
    buffer_size: 100000
    learning_starts: 10000
    gamma: 0.99
    target_update_interval: 250
    train_freq: 8
    gradient_steps: -1
    exploration_fraction: 0.12
    exploration_final_eps: 0.1
    max_grad_norm: 0.5
  eval_params:
    step_freq: 25_000
    n_episodes: 10
    save_best: true

SpaceInvadersNoFrameskip-v4: &atari-defaults
  n_timesteps: !!float 1e7
  env_hyperparams:
    frame_stack: 4
    no_reward_timeout_steps: 1_000
    n_envs: 8
    vec_env_class: "subproc"
    rolling_length: 20
  policy_hyperparams:
    hidden_sizes: [512]
  algo_hyperparams:
    buffer_size: 100000
    learning_rate: !!float 1e-4
    batch_size: 32
    learning_starts: 100000
    target_update_interval: 1000
    train_freq: 8
    gradient_steps: 2
    exploration_fraction: 0.1
    exploration_final_eps: 0.01
  eval_params:
    step_freq: 100_000
    n_episodes: 10
    save_best: true

BreakoutNoFrameskip-v4:
  <<: *atari-defaults

PongNoFrameskip-v4:
  <<: *atari-defaults
  n_timesteps: !!float 2.5e6
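
Note (not part of the commit): the hyperparameter files in this commit lean on YAML anchors and merge keys. For example, CartPole-v0 above pulls in everything from the &cartpole-defaults anchor on CartPole-v1 and only overrides n_timesteps. A minimal sketch of how that resolves, assuming the file is read with PyYAML; the repo's own config loader is not shown in this excerpt:

# Illustration only (not part of the commit): how the YAML merge key resolves.
# Assumes PyYAML's safe_load; the repo's actual loading code is not in this diff.
import yaml

with open("hyperparams/dqn.yml") as f:
    hyperparams = yaml.safe_load(f)

cartpole_v0 = hyperparams["CartPole-v0"]
# algo_hyperparams is inherited wholesale from the &cartpole-defaults anchor ...
assert cartpole_v0["algo_hyperparams"]["batch_size"] == 64
# ... while the explicit key on CartPole-v0 overrides the merged n_timesteps.
assert cartpole_v0["n_timesteps"] == 4e4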

hyperparams/ppo.yml ADDED
@@ -0,0 +1,202 @@
CartPole-v1: &cartpole-defaults
  n_timesteps: !!float 1e5
  env_hyperparams:
    n_envs: 8
  algo_hyperparams:
    n_steps: 32
    batch_size: 256
    n_epochs: 20
    gae_lambda: 0.8
    gamma: 0.98
    ent_coef: 0.0
    learning_rate: 0.001
    learning_rate_decay: linear
    clip_range: 0.2
    clip_range_decay: linear
  eval_params:
    step_freq: !!float 2.5e4
    n_episodes: 10
    save_best: true

CartPole-v0:
  <<: *cartpole-defaults
  n_timesteps: !!float 5e4

MountainCar-v0:
  n_timesteps: !!float 1e6
  env_hyperparams:
    normalize: true
    n_envs: 16
  algo_hyperparams:
    n_steps: 16
    n_epochs: 4
    gae_lambda: 0.98
    gamma: 0.99
    ent_coef: 0.0

MountainCarContinuous-v0:
  n_timesteps: !!float 1e5
  env_hyperparams:
    normalize: true
    n_envs: 4
  policy_hyperparams:
    init_layers_orthogonal: false
    # log_std_init: -3.29
  algo_hyperparams:
    n_steps: 512
    batch_size: 256
    n_epochs: 10
    learning_rate: !!float 7.77e-5
    ent_coef: 0.01 # 0.00429
    ent_coef_decay: linear
    clip_range: 0.1
    gae_lambda: 0.9
    max_grad_norm: 5
    vf_coef: 0.19
    # use_sde: true
  eval_params:
    step_freq: 5000
    n_episodes: 10
    save_best: true

Acrobot-v1:
  n_timesteps: !!float 1e6
  env_hyperparams:
    n_envs: 16
    normalize: true
  algo_hyperparams:
    n_steps: 256
    n_epochs: 4
    gae_lambda: 0.94
    gamma: 0.99
    ent_coef: 0.0

LunarLander-v2:
  n_timesteps: !!float 1e6
  env_hyperparams:
    n_envs: 16
  algo_hyperparams:
    n_steps: 1024
    batch_size: 64
    n_epochs: 4
    gae_lambda: 0.98
    gamma: 0.999
    ent_coef: 0.01
    ent_coef_decay: linear
    normalize_advantage: false
  eval_params:
    step_freq: !!float 5e4
    n_episodes: 10
    save_best: true

CarRacing-v0:
  n_timesteps: !!float 4e6
  env_hyperparams:
    n_envs: 8
    frame_stack: 4
  policy_hyperparams:
    use_sde: true
    log_std_init: -2
    init_layers_orthogonal: false
    activation_fn: relu
    share_features_extractor: false
    cnn_feature_dim: 256
  algo_hyperparams:
    n_steps: 512
    batch_size: 128
    n_epochs: 10
    learning_rate: !!float 1e-4
    learning_rate_decay: linear
    gamma: 0.99
    gae_lambda: 0.95
    ent_coef: 0.0
    sde_sample_freq: 4
    max_grad_norm: 0.5
    vf_coef: 0.5
    clip_range: 0.2

# BreakoutNoFrameskip-v4
# PongNoFrameskip-v4
# SpaceInvadersNoFrameskip-v4
# QbertNoFrameskip-v4
atari: &atari-defaults
  n_timesteps: !!float 1e7
  policy_hyperparams:
    activation_fn: relu
  env_hyperparams: &atari-env-defaults
    n_envs: 8
    frame_stack: 4
    no_reward_timeout_steps: 1000
    no_reward_fire_steps: 500
    vec_env_class: subproc
  algo_hyperparams:
    n_steps: 128
    batch_size: 256
    n_epochs: 4
    learning_rate: !!float 2.5e-4
    learning_rate_decay: linear
    clip_range: 0.1
    clip_range_decay: linear
    vf_coef: 0.5
    ent_coef: 0.01
  eval_params:
    deterministic: false

HalfCheetahBulletEnv-v0: &pybullet-defaults
  n_timesteps: !!float 2e6
  env_hyperparams: &pybullet-env-defaults
    n_envs: 16
    normalize: true
  policy_hyperparams: &pybullet-policy-defaults
    pi_hidden_sizes: [256, 256]
    v_hidden_sizes: [256, 256]
    activation_fn: relu
  algo_hyperparams: &pybullet-algo-defaults
    n_steps: 512
    batch_size: 128
    n_epochs: 20
    gamma: 0.99
    gae_lambda: 0.9
    ent_coef: 0.0
    sde_sample_freq: 4
    max_grad_norm: 0.5
    vf_coef: 0.5
    learning_rate: !!float 3e-5
    clip_range: 0.4

AntBulletEnv-v0:
  <<: *pybullet-defaults
  policy_hyperparams:
    <<: *pybullet-policy-defaults
  algo_hyperparams:
    <<: *pybullet-algo-defaults

Walker2DBulletEnv-v0:
  <<: *pybullet-defaults
  algo_hyperparams:
    <<: *pybullet-algo-defaults
    clip_range_decay: linear

HopperBulletEnv-v0:
  <<: *pybullet-defaults
  algo_hyperparams:
    <<: *pybullet-algo-defaults
    clip_range_decay: linear

HumanoidBulletEnv-v0:
  <<: *pybullet-defaults
  n_timesteps: !!float 1e7
  env_hyperparams:
    <<: *pybullet-env-defaults
    n_envs: 8
  policy_hyperparams:
    <<: *pybullet-policy-defaults
    # log_std_init: -1
  algo_hyperparams:
    <<: *pybullet-algo-defaults
    n_steps: 2048
    batch_size: 64
    n_epochs: 10
    gae_lambda: 0.95
    learning_rate: !!float 2.5e-4
    clip_range: 0.2
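
Note (not part of the commit): several PPO entries above set learning_rate_decay, clip_range_decay, or ent_coef_decay to "linear". The schedule helpers live in shared/schedule.py, which this excerpt does not include, so the sketch below is an assumption about their behavior, consistent with how ppo/ppo.py (later in this diff) builds them via linear_schedule(start, 0) and constant_schedule(value) and presumably evaluates them at a training-progress fraction:

# Sketch only: assumed semantics for shared/schedule.py, which is not shown in this diff.
from typing import Callable

Schedule = Callable[[float], float]


def constant_schedule(value: float) -> Schedule:
    # Same value regardless of training progress.
    return lambda progress_fraction: value


def linear_schedule(start_val: float, end_val: float) -> Schedule:
    # Interpolate from start_val (progress 0.0) to end_val (progress 1.0).
    return lambda progress_fraction: start_val + progress_fraction * (end_val - start_val)


# Example: clip_range 0.2 with clip_range_decay: linear is 0.1 halfway through training.
assert abs(linear_schedule(0.2, 0.0)(0.5) - 0.1) < 1e-12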

hyperparams/vpg.yml ADDED
@@ -0,0 +1,157 @@
CartPole-v1: &cartpole-defaults
  n_timesteps: !!float 4e5
  policy_hyperparams:
    hidden_sizes: [32]
  algo_hyperparams:
    steps_per_epoch: 4096
    pi_lr: 0.01
    gamma: 0.99
    lam: 1
    val_lr: 0.01
    train_v_iters: 80
  eval_params:
    step_freq: !!float 2.5e4
    n_episodes: 10
    save_best: true

CartPole-v0:
  <<: *cartpole-defaults
  n_timesteps: !!float 1e5
  algo_hyperparams:
    steps_per_epoch: 1024
    pi_lr: 0.01
    gamma: 0.99
    lam: 1
    val_lr: 0.01
    train_v_iters: 80

Acrobot-v1:
  n_timesteps: !!float 2e5
  policy_hyperparams:
    hidden_sizes: [32, 32]
  algo_hyperparams:
    steps_per_epoch: 2048
    pi_lr: 0.005
    gamma: 0.99
    lam: 0.97
    val_lr: 0.01
    train_v_iters: 80
    max_grad_norm: 0.5
  eval_params:
    step_freq: !!float 4e4
    n_episodes: 10
    save_best: true

LunarLander-v2:
  n_timesteps: !!float 4e6
  policy_hyperparams:
    hidden_sizes: [256, 256]
  algo_hyperparams:
    steps_per_epoch: 2048
    pi_lr: 0.0001
    gamma: 0.999
    lam: 0.97
    val_lr: 0.0001
    train_v_iters: 80
    max_grad_norm: 0.5
  eval_params:
    step_freq: !!float 5e4
    n_episodes: 10
    save_best: true

CarRacing-v0:
  n_timesteps: !!float 4e6
  env_hyperparams:
    frame_stack: 4
    n_envs: 4
    vec_env_class: "dummy"
  policy_hyperparams:
    hidden_sizes: [256, 256]
  algo_hyperparams:
    steps_per_epoch: 4000
    pi_lr: !!float 7e-5
    gamma: 0.99
    lam: 0.95
    val_lr: !!float 1e-4
    train_v_iters: 40
    max_grad_norm: 0.5
  eval_params:
    step_freq: !!float 5e4
    n_episodes: 10
    save_best: true

HalfCheetahBulletEnv-v0: &pybullet-defaults
  n_timesteps: !!float 2e6
  policy_hyperparams:
    hidden_sizes: [64, 64]
    init_layers_orthogonal: true
  algo_hyperparams:
    steps_per_epoch: 4000
    pi_lr: !!float 3e-4
    gamma: 0.99
    lam: 0.97
    val_lr: !!float 1e-3
    train_v_iters: 80
    max_grad_norm: 0.5
  eval_params:
    step_freq: !!float 1e5
    n_episodes: 10
    save_best: true

HopperBulletEnv-v0:
  <<: *pybullet-defaults

AntBulletEnv-v0:
  <<: *pybullet-defaults
  policy_hyperparams:
    hidden_sizes: [400, 300]
  algo_hyperparams:
    pi_lr: !!float 7e-4
    gamma: 0.99
    lam: 0.97
    val_lr: !!float 7e-3
    train_v_iters: 80
    max_grad_norm: 0.5

FrozenLake-v1:
  n_timesteps: !!float 8e5
  env_params:
    make_kwargs:
      map_name: 8x8
      is_slippery: true
  policy_hyperparams:
    hidden_sizes: [64]
  algo_hyperparams:
    steps_per_epoch: 2048
    pi_lr: 0.01
    gamma: 0.99
    lam: 0.98
    val_lr: 0.01
    train_v_iters: 80
    max_grad_norm: 0.5
  eval_params:
    step_freq: !!float 5e4
    n_episodes: 10
    save_best: true

SpaceInvadersNoFrameskip-v4: &atari-defaults
  n_timesteps: !!float 1e7
  env_hyperparams:
    frame_stack: 4
    no_reward_timeout_steps: 1_000
    n_envs: 8
    vec_env_class: "subproc"
  policy_hyperparams:
    hidden_sizes: [256, 256]
  algo_hyperparams:
    steps_per_epoch: 4096
    pi_lr: !!float 1e-4
    gamma: 0.99
    lam: 0.95
    val_lr: !!float 2e-4
    train_v_iters: 80
    max_grad_norm: 0.5
  eval_params:
    step_freq: !!float 1e5
    n_episodes: 10
    save_best: true

lambda_labs/benchmark.sh ADDED
@@ -0,0 +1,33 @@
source benchmarks/train_loop.sh

# export WANDB_PROJECT_NAME="rl-algo-impls"
export VIRTUAL_DISPLAY=1

BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-6}"

ALGOS=(
    # "vpg"
    # "dqn"
    "ppo"
)
ENVS=(
    # Basic
    "CartPole-v1"
    "MountainCar-v0"
    "MountainCarContinuous-v0"
    "Acrobot-v1"
    "LunarLander-v2"
    # PyBullet
    "HalfCheetahBulletEnv-v0"
    "AntBulletEnv-v0"
    "Walker2DBulletEnv-v0"
    "HopperBulletEnv-v0"
    # CarRacing
    "CarRacing-v0"
    # Atari
    "PongNoFrameskip-v4"
    "BreakoutNoFrameskip-v4"
    "SpaceInvadersNoFrameskip-v4"
    "QbertNoFrameskip-v4"
)
train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
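
Note (not part of the commit): the final line above pipes the training commands emitted by train_loop (defined in benchmarks/train_loop.sh, not shown here) into xargs -P, which keeps at most BENCHMARK_MAX_PROCS of them running at once. A rough Python analogue of that fan-out, for readers unfamiliar with the xargs idiom; the command strings below are hypothetical stand-ins, not the exact output of train_loop:

# Illustration only: a Python analogue of `... | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD`.
import subprocess
from concurrent.futures import ThreadPoolExecutor

BENCHMARK_MAX_PROCS = 6

# Hypothetical stand-ins for the command lines train_loop would emit.
commands = [
    "python train.py --algo ppo --env CartPole-v1",
    "python train.py --algo ppo --env MountainCar-v0",
]


def run(cmd: str) -> int:
    # shell=True mirrors `bash -c CMD`: each command gets its own shell.
    return subprocess.run(cmd, shell=True).returncode


with ThreadPoolExecutor(max_workers=BENCHMARK_MAX_PROCS) as pool:
    return_codes = list(pool.map(run, commands))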

lambda_labs/lambda_requirements.txt ADDED
@@ -0,0 +1,9 @@
scipy >= 1.10.0, < 1.11
tensorboard >= 2.11.0, < 2.12
AutoROM.accept-rom-license >= 0.4.2, < 0.5
stable-baselines3[extra] >= 1.7.0, < 1.8
gym[box2d] >= 0.21.0, < 0.22
pyglet == 1.5.27
wandb >= 0.13.9, < 0.14
pyvirtualdisplay == 3.0
pybullet >= 3.2.5, < 3.3

lambda_labs/setup.sh ADDED
@@ -0,0 +1,10 @@
sudo apt update
sudo apt install -y python-opengl
sudo apt install -y ffmpeg
sudo apt install -y xvfb
sudo apt install -y swig

python3 -m pip install --upgrade pip
pip install --upgrade torch torchvision torchaudio

pip install --upgrade -r ~/rl-algo-impls/lambda_labs/lambda_requirements.txt

poetry.lock ADDED
The diff for this file is too large to render. See raw diff

ppo/policy.py ADDED
@@ -0,0 +1,36 @@
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
from typing import Optional, Sequence

from gym.spaces import Box, Discrete
from shared.policy.on_policy import ActorCritic


class PPOActorCritic(ActorCritic):
    def __init__(
        self,
        env: VecEnv,
        pi_hidden_sizes: Optional[Sequence[int]] = None,
        v_hidden_sizes: Optional[Sequence[int]] = None,
        **kwargs,
    ) -> None:
        obs_space = env.observation_space
        if isinstance(obs_space, Box):
            if len(obs_space.shape) == 3:
                pi_hidden_sizes = pi_hidden_sizes or []
                v_hidden_sizes = v_hidden_sizes or []
            elif len(obs_space.shape) == 1:
                pi_hidden_sizes = pi_hidden_sizes or [64, 64]
                v_hidden_sizes = v_hidden_sizes or [64, 64]
            else:
                raise ValueError(f"Unsupported observation space: {obs_space}")
        elif isinstance(obs_space, Discrete):
            pi_hidden_sizes = pi_hidden_sizes or [64]
            v_hidden_sizes = v_hidden_sizes or [64]
        else:
            raise ValueError(f"Unsupported observation space: {obs_space}")
        super().__init__(
            env,
            pi_hidden_sizes,
            v_hidden_sizes,
            **kwargs,
        )
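
Note (not part of the commit): PPOActorCritic only fills in network widths when the caller does not supply them: no extra MLP layers for 3-D (image) Box observations, [64, 64] for flat Box observations, and [64] for Discrete observations. The standalone restatement below is for illustration; the helper name default_hidden_sizes is invented here and does not exist in the repo:

# Illustration only: restates PPOActorCritic's default-selection logic outside the class.
from typing import List

from gym.spaces import Box, Discrete, Space


def default_hidden_sizes(obs_space: Space) -> List[int]:
    if isinstance(obs_space, Box):
        if len(obs_space.shape) == 3:
            return []  # image observations: no extra MLP layers by default
        if len(obs_space.shape) == 1:
            return [64, 64]  # flat observations: two 64-unit hidden layers
        raise ValueError(f"Unsupported observation space: {obs_space}")
    if isinstance(obs_space, Discrete):
        return [64]
    raise ValueError(f"Unsupported observation space: {obs_space}")


# A CartPole-style 4-dimensional observation space gets the [64, 64] default.
assert default_hidden_sizes(Box(low=-1.0, high=1.0, shape=(4,))) == [64, 64]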

ppo/ppo.py ADDED
@@ -0,0 +1,367 @@
import numpy as np
import torch
import torch.nn as nn

from dataclasses import asdict, dataclass
from torch.optim import Adam
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
from torch.utils.tensorboard.writer import SummaryWriter
from typing import List, Optional, Sequence, NamedTuple, TypeVar

from shared.algorithm import Algorithm
from shared.callbacks.callback import Callback
from shared.policy.on_policy import ActorCritic
from shared.schedule import constant_schedule, linear_schedule
from shared.trajectory import Trajectory as BaseTrajectory
from shared.utils import discounted_cumsum


@dataclass
class PPOTrajectory(BaseTrajectory):
    logp_a: List[float]
    next_obs: Optional[np.ndarray]

    def __init__(self) -> None:
        super().__init__()
        self.logp_a = []
        self.next_obs = None

    def add(
        self,
        obs: np.ndarray,
        act: np.ndarray,
        next_obs: np.ndarray,
        rew: float,
        terminated: bool,
        v: float,
        logp_a: float,
    ):
        super().add(obs, act, rew, v)
        self.next_obs = next_obs if not terminated else None
        self.terminated = terminated
        self.logp_a.append(logp_a)


class TrajectoryAccumulator:
    def __init__(self, num_envs: int) -> None:
        self.num_envs = num_envs

        self.trajectories_ = []
        self.current_trajectories_ = [PPOTrajectory() for _ in range(num_envs)]

    def step(
        self,
        obs: VecEnvObs,
        action: np.ndarray,
        next_obs: VecEnvObs,
        reward: np.ndarray,
        done: np.ndarray,
        val: np.ndarray,
        logp_a: np.ndarray,
    ) -> None:
        assert isinstance(obs, np.ndarray)
        assert isinstance(next_obs, np.ndarray)
        for i, trajectory in enumerate(self.current_trajectories_):
            # TODO: Eventually take advantage of terminated/truncated differentiation in
            # later versions of gym.
            trajectory.add(
                obs[i], action[i], next_obs[i], reward[i], done[i], val[i], logp_a[i]
            )
            if done[i]:
                self.trajectories_.append(trajectory)
                self.current_trajectories_[i] = PPOTrajectory()

    @property
    def all_trajectories(self) -> List[PPOTrajectory]:
        return self.trajectories_ + list(
            filter(lambda t: len(t), self.current_trajectories_)
        )


class RtgAdvantage(NamedTuple):
    rewards_to_go: torch.Tensor
    advantage: torch.Tensor


class TrainStepStats(NamedTuple):
    loss: float
    pi_loss: float
    v_loss: float
    entropy_loss: float
    approx_kl: float
    clipped_frac: float


@dataclass
class TrainStats:
    loss: float
    pi_loss: float
    v_loss: float
    entropy_loss: float
    approx_kl: float
    clipped_frac: float

    def __init__(self, step_stats: List[TrainStepStats]) -> None:
        self.loss = np.mean([s.loss for s in step_stats]).item()
        self.pi_loss = np.mean([s.pi_loss for s in step_stats]).item()
        self.v_loss = np.mean([s.v_loss for s in step_stats]).item()
        self.entropy_loss = np.mean([s.entropy_loss for s in step_stats]).item()
        self.approx_kl = np.mean([s.approx_kl for s in step_stats]).item()
        self.clipped_frac = np.mean([s.clipped_frac for s in step_stats]).item()

    def write_to_tensorboard(self, tb_writer: SummaryWriter, global_step: int) -> None:
        tb_writer.add_scalars("losses", asdict(self), global_step=global_step)

    def __repr__(self) -> str:
        return " | ".join(
            [
                f"Loss: {round(self.loss, 2)}",
                f"Pi L: {round(self.pi_loss, 2)}",
                f"V L: {round(self.v_loss, 2)}",
                f"E L: {round(self.entropy_loss, 2)}",
                f"Apx KL Div: {round(self.approx_kl, 2)}",
                f"Clip Frac: {round(self.clipped_frac, 2)}",
            ]
        )


PPOSelf = TypeVar("PPOSelf", bound="PPO")


class PPO(Algorithm):
    def __init__(
        self,
        policy: ActorCritic,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        learning_rate: float = 3e-4,
        learning_rate_decay: str = "none",
        n_steps: int = 2048,
        batch_size: int = 64,
        n_epochs: int = 10,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_range: float = 0.2,
        clip_range_decay: str = "none",
        clip_range_vf: Optional[float] = None,
        clip_range_vf_decay: str = "none",
        normalize_advantage: bool = True,
        ent_coef: float = 0.0,
        ent_coef_decay: str = "none",
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        update_rtg_between_epochs: bool = False,
        sde_sample_freq: int = -1,
    ) -> None:
        super().__init__(policy, env, device, tb_writer)
        self.policy = policy

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.optimizer = Adam(self.policy.parameters(), lr=learning_rate)
        self.lr_schedule = (
            linear_schedule(learning_rate, 0)
            if learning_rate_decay == "linear"
            else constant_schedule(learning_rate)
        )
        self.max_grad_norm = max_grad_norm
        self.clip_range_schedule = (
            linear_schedule(clip_range, 0)
            if clip_range_decay == "linear"
            else constant_schedule(clip_range)
        )
        self.clip_range_vf_schedule = None
        if clip_range_vf:
            self.clip_range_vf_schedule = (
                linear_schedule(clip_range_vf, 0)
                if clip_range_vf_decay == "linear"
                else constant_schedule(clip_range_vf)
            )
        self.normalize_advantage = normalize_advantage
        self.ent_coef_schedule = (
            linear_schedule(ent_coef, 0)
            if ent_coef_decay == "linear"
            else constant_schedule(ent_coef)
        )
        self.vf_coef = vf_coef

        self.n_steps = n_steps
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.sde_sample_freq = sde_sample_freq

        self.update_rtg_between_epochs = update_rtg_between_epochs

    def learn(
        self: PPOSelf,
        total_timesteps: int,
        callback: Optional[Callback] = None,
    ) -> PPOSelf:
        obs = self.env.reset()
        ts_elapsed = 0
         | 
| 203 | 
            +
                    while ts_elapsed < total_timesteps:
         | 
| 204 | 
            +
                        accumulator = self._collect_trajectories(obs)
         | 
| 205 | 
            +
                        progress = ts_elapsed / total_timesteps
         | 
| 206 | 
            +
                        train_stats = self.train(accumulator.all_trajectories, progress)
         | 
| 207 | 
            +
                        rollout_steps = self.n_steps * self.env.num_envs
         | 
| 208 | 
            +
                        ts_elapsed += rollout_steps
         | 
| 209 | 
            +
                        train_stats.write_to_tensorboard(self.tb_writer, ts_elapsed)
         | 
| 210 | 
            +
                        if callback:
         | 
| 211 | 
            +
                            callback.on_step(timesteps_elapsed=rollout_steps)
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    return self
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                def _collect_trajectories(self, obs: VecEnvObs) -> TrajectoryAccumulator:
         | 
| 216 | 
            +
                    self.policy.eval()
         | 
| 217 | 
            +
                    accumulator = TrajectoryAccumulator(self.env.num_envs)
         | 
| 218 | 
            +
                    self.policy.reset_noise()
         | 
| 219 | 
            +
                    for i in range(self.n_steps):
         | 
| 220 | 
            +
                        if self.sde_sample_freq > 0 and i > 0 and i % self.sde_sample_freq == 0:
         | 
| 221 | 
            +
                            self.policy.reset_noise()
         | 
| 222 | 
            +
                        action, value, logp_a, clamped_action = self.policy.step(obs)
         | 
| 223 | 
            +
                        next_obs, reward, done, _ = self.env.step(clamped_action)
         | 
| 224 | 
            +
                        accumulator.step(obs, action, next_obs, reward, done, value, logp_a)
         | 
| 225 | 
            +
                        obs = next_obs
         | 
| 226 | 
            +
                    return accumulator
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                def train(self, trajectories: List[PPOTrajectory], progress: float) -> TrainStats:
         | 
| 229 | 
            +
                    self.policy.train()
         | 
| 230 | 
            +
                    learning_rate = self.lr_schedule(progress)
         | 
| 231 | 
            +
                    self.optimizer.param_groups[0]["lr"] = learning_rate
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    pi_clip = self.clip_range_schedule(progress)
         | 
| 234 | 
            +
                    v_clip = (
         | 
| 235 | 
            +
                        self.clip_range_vf_schedule(progress)
         | 
| 236 | 
            +
                        if self.clip_range_vf_schedule
         | 
| 237 | 
            +
                        else None
         | 
| 238 | 
            +
                    )
         | 
| 239 | 
            +
                    ent_coef = self.ent_coef_schedule(progress)
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                    obs = torch.as_tensor(
         | 
| 242 | 
            +
                        np.concatenate([np.array(t.obs) for t in trajectories]), device=self.device
         | 
| 243 | 
            +
                    )
         | 
| 244 | 
            +
                    act = torch.as_tensor(
         | 
| 245 | 
            +
                        np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
         | 
| 246 | 
            +
                    )
         | 
| 247 | 
            +
                    rtg, adv = self._compute_rtg_and_advantage(trajectories)
         | 
| 248 | 
            +
                    orig_v = torch.as_tensor(
         | 
| 249 | 
            +
                        np.concatenate([np.array(t.v) for t in trajectories]), device=self.device
         | 
| 250 | 
            +
                    )
         | 
| 251 | 
            +
                    orig_logp_a = torch.as_tensor(
         | 
| 252 | 
            +
                        np.concatenate([np.array(t.logp_a) for t in trajectories]),
         | 
| 253 | 
            +
                        device=self.device,
         | 
| 254 | 
            +
                    )
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                    step_stats = []
         | 
| 257 | 
            +
                    for _ in range(self.n_epochs):
         | 
| 258 | 
            +
                        if self.update_rtg_between_epochs:
         | 
| 259 | 
            +
                            rtg, adv = self._compute_rtg_and_advantage(trajectories)
         | 
| 260 | 
            +
                        else:
         | 
| 261 | 
            +
                            adv = self._compute_advantage(trajectories)
         | 
| 262 | 
            +
                        idxs = torch.randperm(len(obs))
         | 
| 263 | 
            +
                        for i in range(0, len(obs), self.batch_size):
         | 
| 264 | 
            +
                            mb_idxs = idxs[i : i + self.batch_size]
         | 
| 265 | 
            +
                            mb_adv = adv[mb_idxs]
         | 
| 266 | 
            +
                            if self.normalize_advantage:
         | 
| 267 | 
            +
                                mb_adv = (mb_adv - mb_adv.mean(-1)) / (mb_adv.std(-1) + 1e-8)
         | 
| 268 | 
            +
                            step_stats.append(
         | 
| 269 | 
            +
                                self._train_step(
         | 
| 270 | 
            +
                                    pi_clip,
         | 
| 271 | 
            +
                                    v_clip,
         | 
| 272 | 
            +
                                    ent_coef,
         | 
| 273 | 
            +
                                    obs[mb_idxs],
         | 
| 274 | 
            +
                                    act[mb_idxs],
         | 
| 275 | 
            +
                                    rtg[mb_idxs],
         | 
| 276 | 
            +
                                    mb_adv,
         | 
| 277 | 
            +
                                    orig_v[mb_idxs],
         | 
| 278 | 
            +
                                    orig_logp_a[mb_idxs],
         | 
| 279 | 
            +
                                )
         | 
| 280 | 
            +
                            )
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    return TrainStats(step_stats)
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                def _train_step(
         | 
| 285 | 
            +
                    self,
         | 
| 286 | 
            +
                    pi_clip: float,
         | 
| 287 | 
            +
                    v_clip: Optional[float],
         | 
| 288 | 
            +
                    ent_coef: float,
         | 
| 289 | 
            +
                    obs: torch.Tensor,
         | 
| 290 | 
            +
                    act: torch.Tensor,
         | 
| 291 | 
            +
                    rtg: torch.Tensor,
         | 
| 292 | 
            +
                    adv: torch.Tensor,
         | 
| 293 | 
            +
                    orig_v: torch.Tensor,
         | 
| 294 | 
            +
                    orig_logp_a: torch.Tensor,
         | 
| 295 | 
            +
                ) -> TrainStepStats:
         | 
| 296 | 
            +
                    logp_a, entropy, v = self.policy(obs, act)
         | 
| 297 | 
            +
                    logratio = logp_a - orig_logp_a
         | 
| 298 | 
            +
                    ratio = torch.exp(logratio)
         | 
| 299 | 
            +
                    clip_ratio = torch.clamp(ratio, min=1 - pi_clip, max=1 + pi_clip)
         | 
| 300 | 
            +
                    pi_loss = torch.maximum(-ratio * adv, -clip_ratio * adv).mean()
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                    v_loss = (v - rtg).pow(2)
         | 
| 303 | 
            +
                    if v_clip:
         | 
| 304 | 
            +
                        v_clipped = (torch.clamp(v, orig_v - v_clip, orig_v + v_clip) - rtg).pow(2)
         | 
| 305 | 
            +
                        v_loss = torch.maximum(v_loss, v_clipped)
         | 
| 306 | 
            +
                    v_loss = v_loss.mean()
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                    entropy_loss = entropy.mean()
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                    loss = pi_loss - ent_coef * entropy_loss + self.vf_coef * v_loss
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                    self.optimizer.zero_grad()
         | 
| 313 | 
            +
                    loss.backward()
         | 
| 314 | 
            +
                    nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
         | 
| 315 | 
            +
                    self.optimizer.step()
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    with torch.no_grad():
         | 
| 318 | 
            +
                        approx_kl = ((ratio - 1) - logratio).mean().cpu().numpy().item()
         | 
| 319 | 
            +
                        clipped_frac = (
         | 
| 320 | 
            +
                            ((ratio - 1).abs() > pi_clip).float().mean().cpu().numpy().item()
         | 
| 321 | 
            +
                        )
         | 
| 322 | 
            +
                    return TrainStepStats(
         | 
| 323 | 
            +
                        loss.item(),
         | 
| 324 | 
            +
                        pi_loss.item(),
         | 
| 325 | 
            +
                        v_loss.item(),
         | 
| 326 | 
            +
                        entropy_loss.item(),
         | 
| 327 | 
            +
                        approx_kl,
         | 
| 328 | 
            +
                        clipped_frac,
         | 
| 329 | 
            +
                    )
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                def _compute_advantage(self, trajectories: Sequence[PPOTrajectory]) -> torch.Tensor:
         | 
| 332 | 
            +
                    advantage = []
         | 
| 333 | 
            +
                    for traj in trajectories:
         | 
| 334 | 
            +
                        last_val = 0
         | 
| 335 | 
            +
                        if not traj.terminated and traj.next_obs is not None:
         | 
| 336 | 
            +
                            last_val = self.policy.value(np.array(traj.next_obs))
         | 
| 337 | 
            +
                        rew = np.append(np.array(traj.rew), last_val)
         | 
| 338 | 
            +
                        v = np.append(np.array(traj.v), last_val)
         | 
| 339 | 
            +
                        deltas = rew[:-1] + self.gamma * v[1:] - v[:-1]
         | 
| 340 | 
            +
                        advantage.append(discounted_cumsum(deltas, self.gamma * self.gae_lambda))
         | 
| 341 | 
            +
                    return torch.as_tensor(
         | 
| 342 | 
            +
                        np.concatenate(advantage), dtype=torch.float32, device=self.device
         | 
| 343 | 
            +
                    )
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                def _compute_rtg_and_advantage(
         | 
| 346 | 
            +
                    self, trajectories: Sequence[PPOTrajectory]
         | 
| 347 | 
            +
                ) -> RtgAdvantage:
         | 
| 348 | 
            +
                    rewards_to_go = []
         | 
| 349 | 
            +
                    advantages = []
         | 
| 350 | 
            +
                    for traj in trajectories:
         | 
| 351 | 
            +
                        last_val = 0
         | 
| 352 | 
            +
                        if not traj.terminated and traj.next_obs is not None:
         | 
| 353 | 
            +
                            last_val = self.policy.value(np.array(traj.next_obs))
         | 
| 354 | 
            +
                        rew = np.append(np.array(traj.rew), last_val)
         | 
| 355 | 
            +
                        v = np.append(np.array(traj.v), last_val)
         | 
| 356 | 
            +
                        deltas = rew[:-1] + self.gamma * v[1:] - v[:-1]
         | 
| 357 | 
            +
                        adv = discounted_cumsum(deltas, self.gamma * self.gae_lambda)
         | 
| 358 | 
            +
                        advantages.append(adv)
         | 
| 359 | 
            +
                        rewards_to_go.append(v[:-1] + adv)
         | 
| 360 | 
            +
                    return RtgAdvantage(
         | 
| 361 | 
            +
                        torch.as_tensor(
         | 
| 362 | 
            +
                            np.concatenate(rewards_to_go), dtype=torch.float32, device=self.device
         | 
| 363 | 
            +
                        ),
         | 
| 364 | 
            +
                        torch.as_tensor(
         | 
| 365 | 
            +
                            np.concatenate(advantages), dtype=torch.float32, device=self.device
         | 
| 366 | 
            +
                        ),
         | 
| 367 | 
            +
                    )
         | 
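As a rough usage sketch of how these pieces compose (not part of this commit): the constructor call `PPOActorCritic(env)` is an assumption inferred from `make_policy` in runner/running_utils.py, and the env and writer path here are placeholders. Each `learn` iteration collects `n_steps * num_envs` transitions, runs `n_epochs` of minibatch clipped-surrogate updates, and writes the resulting `TrainStats` to TensorBoard.

# Hypothetical usage sketch -- not part of this commit.
import gym
import torch
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from torch.utils.tensorboard.writer import SummaryWriter

from ppo.policy import PPOActorCritic  # constructor signature assumed: (env, **kwargs)
from ppo.ppo import PPO

env = DummyVecEnv([lambda: gym.make("Acrobot-v1")])  # placeholder env choice
policy = PPOActorCritic(env)
algo = PPO(policy, env, torch.device("cpu"), SummaryWriter("runs/ppo-sketch"))
algo.learn(total_timesteps=100_000)  # rollout, train, and log losses each iteration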
    	
pyproject.toml ADDED
@@ -0,0 +1,27 @@
[tool.poetry]
name = "rl-algo-impls"
version = "0.1.0"
description = "Implementations of reinforcement learning algorithms"
authors = ["Scott Goodfriend <goodfriend.scott@gmail.com>"]
license = "MIT License"
readme = "README.md"
packages = [{include = "rl_algo_impls"}]

[tool.poetry.dependencies]
python = "~3.10"
"AutoROM.accept-rom-license" = "^0.4.2"
stable-baselines3 = {extras = ["extra"], version = "^1.7.0"}
scipy = "^1.10.0"
gym = {extras = ["box2d"], version = "^0.21.0"}
pyglet = "1.5.27"
PyYAML = "^6.0"
tensorboard = "^2.11.0"
pybullet = "^3.2.5"
wandb = "^0.13.9"
conda-lock = "^1.3.0"
torch-tb-profiler = "^0.4.1"
jupyter = "^1.0.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
    	
replay.meta.json ADDED
@@ -0,0 +1 @@
{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with clang version 14.0.6\\nconfiguration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-pic --enable-shared --disable-static --enable-version3 --enable-zlib --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/pkg-config\\nlibavutil      57. 28.100 / 57. 28.100\\nlibavcodec     59. 37.100 / 59. 37.100\\nlibavformat    59. 27.100 / 59. 27.100\\nlibavdevice    59.  7.100 / 59.  7.100\\nlibavfilter     8. 44.100 /  8. 44.100\\nlibswscale      6.  7.100 /  6.  7.100\\nlibswresample   4.  7.100 /  4.  7.100\\nlibpostproc    56.  6.100 / 56.  6.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "500x500", "-pix_fmt", "rgb24", "-framerate", "30", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "30", "/var/folders/9g/my5557_91xddp6lx00nkzly80000gn/T/tmp7t2v9jcd/ppo-Acrobot-v1/replay.mp4"]}, "episode": {"r": -73.0, "l": 74, "t": 1.272925}}
    	
replay.mp4 ADDED
Binary file (62.5 kB)
    	
runner/config.py ADDED
@@ -0,0 +1,130 @@
import os

from datetime import datetime
from dataclasses import dataclass
from typing import Any, Dict, Optional, TypedDict, Union


@dataclass
class RunArgs:
    algo: str
    env: str
    seed: Optional[int] = None
    use_deterministic_algorithms: bool = True


class Hyperparams(TypedDict, total=False):
    device: str
    n_timesteps: Union[int, float]
    env_hyperparams: Dict[str, Any]
    policy_hyperparams: Dict[str, Any]
    algo_hyperparams: Dict[str, Any]
    eval_params: Dict[str, Any]


@dataclass
class Config:
    args: RunArgs
    hyperparams: Hyperparams
    root_dir: str
    run_id: str = datetime.now().isoformat()

    def seed(self, training: bool = True) -> Optional[int]:
        seed = self.args.seed
        if training or seed is None:
            return seed
        return seed + self.env_hyperparams.get("n_envs", 1)

    @property
    def device(self) -> str:
        return self.hyperparams.get("device", "auto")

    @property
    def n_timesteps(self) -> int:
        return int(self.hyperparams.get("n_timesteps", 100_000))

    @property
    def env_hyperparams(self) -> Dict[str, Any]:
        return self.hyperparams.get("env_hyperparams", {})

    @property
    def policy_hyperparams(self) -> Dict[str, Any]:
        return self.hyperparams.get("policy_hyperparams", {})

    @property
    def algo_hyperparams(self) -> Dict[str, Any]:
        return self.hyperparams.get("algo_hyperparams", {})

    @property
    def eval_params(self) -> Dict[str, Any]:
        return self.hyperparams.get("eval_params", {})

    @property
    def env_id(self) -> str:
        return self.args.env

    @property
    def model_name(self) -> str:
        parts = [self.args.algo, self.env_id]
        if self.args.seed is not None:
            parts.append(f"S{self.args.seed}")
        make_kwargs = self.env_hyperparams.get("make_kwargs", {})
        if make_kwargs:
            for k, v in make_kwargs.items():
                if type(v) == bool and v:
                    parts.append(k)
                elif type(v) == int and v:
                    parts.append(f"{k}{v}")
                else:
                    parts.append(str(v))
        return "-".join(parts)

    @property
    def run_name(self) -> str:
        parts = [self.model_name, self.run_id]
        return "-".join(parts)

    @property
    def saved_models_dir(self) -> str:
        return os.path.join(self.root_dir, "saved_models")

    @property
    def downloaded_models_dir(self) -> str:
        return os.path.join(self.root_dir, "downloaded_models")

    def model_dir_name(
        self,
        best: bool = False,
        extension: str = "",
    ) -> str:
        return self.model_name + ("-best" if best else "") + extension

    def model_dir_path(self, best: bool = False, downloaded: bool = False) -> str:
        return os.path.join(
            self.saved_models_dir if not downloaded else self.downloaded_models_dir,
            self.model_dir_name(best=best),
        )

    @property
    def runs_dir(self) -> str:
        return os.path.join(self.root_dir, "runs")

    @property
    def tensorboard_summary_path(self) -> str:
        return os.path.join(self.runs_dir, self.run_name)

    @property
    def logs_path(self) -> str:
        return os.path.join(self.runs_dir, f"log.yml")

    @property
    def videos_dir(self) -> str:
        return os.path.join(self.root_dir, "videos")

    @property
    def video_prefix(self) -> str:
        return os.path.join(self.videos_dir, self.model_name)

    @property
    def best_videos_dir(self) -> str:
        return os.path.join(self.videos_dir, f"{self.model_name}-best")
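To illustrate how Config derives names and paths from RunArgs and the loaded hyperparameters, here is a small sketch (not part of the commit; the hyperparameter values are placeholders):

# Illustrative sketch -- not part of this commit.
from runner.config import Config, RunArgs

args = RunArgs(algo="ppo", env="Acrobot-v1", seed=4)
hyperparams = {"n_timesteps": 100_000, "env_hyperparams": {"n_envs": 4}}
config = Config(args, hyperparams, root_dir=".")

config.model_name                 # "ppo-Acrobot-v1-S4"
config.model_dir_path(best=True)  # "./saved_models/ppo-Acrobot-v1-S4-best"
config.seed(training=False)       # 8: eval seed is offset by n_envs
config.tensorboard_summary_path   # "./runs/ppo-Acrobot-v1-S4-<run_id>"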
    	
runner/env.py ADDED
@@ -0,0 +1,134 @@
import gym
import os

from gym.wrappers.resize_observation import ResizeObservation
from gym.wrappers.gray_scale_observation import GrayScaleObservation
from gym.wrappers.frame_stack import FrameStack
from stable_baselines3.common.atari_wrappers import (
    MaxAndSkipEnv,
    NoopResetEnv,
)
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from torch.utils.tensorboard.writer import SummaryWriter
from typing import Any, Callable, Dict, Optional, Union

from runner.config import Config
from shared.policy.policy import VEC_NORMALIZE_FILENAME
from wrappers.atari_wrappers import EpisodicLifeEnv, FireOnLifeStarttEnv, ClipRewardEnv
from wrappers.episode_record_video import EpisodeRecordVideo
from wrappers.episode_stats_writer import EpisodeStatsWriter
from wrappers.initial_step_truncate_wrapper import InitialStepTruncateWrapper
from wrappers.video_compat_wrapper import VideoCompatWrapper


def make_env(
    config: Config,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    n_envs: int = 1,
    frame_stack: int = 1,
    make_kwargs: Optional[Dict[str, Any]] = None,
    no_reward_timeout_steps: Optional[int] = None,
    no_reward_fire_steps: Optional[int] = None,
    vec_env_class: str = "dummy",
    normalize: bool = False,
    normalize_kwargs: Optional[Dict[str, Any]] = None,
    tb_writer: Optional[SummaryWriter] = None,
    rolling_length: int = 100,
    train_record_video: bool = False,
    video_step_interval: Union[int, float] = 1_000_000,
    initial_steps_to_truncate: Optional[int] = None,
) -> VecEnv:
    if "BulletEnv" in config.env_id:
        import pybullet_envs

    make_kwargs = make_kwargs if make_kwargs is not None else {}
    if "BulletEnv" in config.env_id and render:
        make_kwargs["render"] = True
    if "CarRacing" in config.env_id:
        make_kwargs["verbose"] = 0

    spec = gym.spec(config.env_id)

    def make(idx: int) -> Callable[[], gym.Env]:
        def _make() -> gym.Env:
            env = gym.make(config.env_id, **make_kwargs)
            env = gym.wrappers.RecordEpisodeStatistics(env)
            env = VideoCompatWrapper(env)
            if training and train_record_video and idx == 0:
                env = EpisodeRecordVideo(
                    env,
                    config.video_prefix,
                    step_increment=n_envs,
                    video_step_interval=int(video_step_interval),
                )
            if training and initial_steps_to_truncate:
                env = InitialStepTruncateWrapper(
                    env, idx * initial_steps_to_truncate // n_envs
                )
            if "AtariEnv" in spec.entry_point:  # type: ignore
                env = NoopResetEnv(env, noop_max=30)
                env = MaxAndSkipEnv(env, skip=4)
                env = EpisodicLifeEnv(env, training=training)
                action_meanings = env.unwrapped.get_action_meanings()
                if "FIRE" in action_meanings:  # type: ignore
                    env = FireOnLifeStarttEnv(env, action_meanings.index("FIRE"))
                env = ClipRewardEnv(env, training=training)
                env = ResizeObservation(env, (84, 84))
                env = GrayScaleObservation(env, keep_dim=False)
                env = FrameStack(env, frame_stack)
            elif "CarRacing" in config.env_id:
                env = ResizeObservation(env, (64, 64))
                env = GrayScaleObservation(env, keep_dim=False)
                env = FrameStack(env, frame_stack)

            if no_reward_timeout_steps:
                from wrappers.no_reward_timeout import NoRewardTimeout

                env = NoRewardTimeout(
                    env, no_reward_timeout_steps, n_fire_steps=no_reward_fire_steps
                )

            seed = config.seed(training=training)
            if seed is not None:
                env.seed(seed + idx)
                env.action_space.seed(seed + idx)
                env.observation_space.seed(seed + idx)

            return env

        return _make

    VecEnvClass = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_class]
    venv = VecEnvClass([make(i) for i in range(n_envs)])
    if training:
        assert tb_writer
        venv = EpisodeStatsWriter(
            venv, tb_writer, training=training, rolling_length=rolling_length
        )
    if normalize:
        if normalize_load_path:
            venv = VecNormalize.load(
                os.path.join(normalize_load_path, VEC_NORMALIZE_FILENAME), venv
            )
        else:
            venv = VecNormalize(venv, training=training, **(normalize_kwargs or {}))
        if not training:
            venv.norm_reward = False
    return venv


def make_eval_env(
    config: Config, override_n_envs: Optional[int] = None, **kwargs
) -> VecEnv:
    kwargs = kwargs.copy()
    kwargs["training"] = False
    if override_n_envs is not None:
        kwargs["n_envs"] = override_n_envs
        if override_n_envs == 1:
            kwargs["vec_env_class"] = "dummy"
    return make_env(config, **kwargs)
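A minimal sketch of how make_env and make_eval_env might be called (not part of this commit; the env choice and writer path are placeholders). Note that training=True requires a SummaryWriter because the vectorized env is wrapped in EpisodeStatsWriter:

# Illustrative sketch -- not part of this commit.
from torch.utils.tensorboard.writer import SummaryWriter

from runner.config import Config, RunArgs
from runner.env import make_env, make_eval_env

config = Config(RunArgs(algo="ppo", env="Acrobot-v1", seed=4), {}, root_dir=".")
tb_writer = SummaryWriter(config.tensorboard_summary_path)

# Training env: 4 copies in a DummyVecEnv, episode stats logged to TensorBoard.
train_env = make_env(config, training=True, n_envs=4, tb_writer=tb_writer)
# Eval env: single env, training=False, so no stats writer is required.
eval_env = make_eval_env(config, override_n_envs=1)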
    	
        runner/running_utils.py
    ADDED
    
    | @@ -0,0 +1,188 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
            import gym
         | 
| 3 | 
            +
            import json
         | 
| 4 | 
            +
            import matplotlib.pyplot as plt
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
            import os
         | 
| 7 | 
            +
            import random
         | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            import torch.backends.cudnn
         | 
| 10 | 
            +
            import yaml
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            from gym.spaces import Box, Discrete
         | 
| 13 | 
            +
            from stable_baselines3.common.vec_env.base_vec_env import VecEnv
         | 
| 14 | 
            +
            from torch.utils.tensorboard.writer import SummaryWriter
         | 
| 15 | 
            +
            from typing import Dict, Optional, Type, Union
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            from runner.config import Hyperparams
         | 
| 18 | 
            +
            from shared.algorithm import Algorithm
         | 
| 19 | 
            +
            from shared.callbacks.eval_callback import EvalCallback
         | 
| 20 | 
            +
            from shared.policy.policy import Policy
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            from dqn.dqn import DQN
         | 
| 23 | 
            +
            from dqn.policy import DQNPolicy
         | 
| 24 | 
            +
            from vpg.vpg import VanillaPolicyGradient
         | 
| 25 | 
            +
            from vpg.policy import VPGActorCritic
         | 
| 26 | 
            +
            from ppo.ppo import PPO
         | 
| 27 | 
            +
            from ppo.policy import PPOActorCritic
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            ALGOS: Dict[str, Type[Algorithm]] = {
         | 
| 30 | 
            +
                "dqn": DQN,
         | 
| 31 | 
            +
                "vpg": VanillaPolicyGradient,
         | 
| 32 | 
            +
                "ppo": PPO,
         | 
| 33 | 
            +
            }
         | 
| 34 | 
            +
            POLICIES: Dict[str, Type[Policy]] = {
         | 
| 35 | 
            +
                "dqn": DQNPolicy,
         | 
| 36 | 
            +
                "vpg": VPGActorCritic,
         | 
| 37 | 
            +
                "ppo": PPOActorCritic,
         | 
| 38 | 
            +
            }
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            HYPERPARAMS_PATH = "hyperparams"
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
            def base_parser() -> argparse.ArgumentParser:
         | 
| 44 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 45 | 
            +
                parser.add_argument(
         | 
| 46 | 
            +
                    "--algo",
         | 
| 47 | 
            +
                    default="dqn",
         | 
| 48 | 
            +
                    type=str,
         | 
| 49 | 
            +
                    choices=list(ALGOS.keys()),
         | 
| 50 | 
            +
                    nargs="+",
         | 
| 51 | 
            +
                    help="Abbreviation(s) of algorithm(s)",
         | 
| 52 | 
            +
                )
         | 
| 53 | 
            +
                parser.add_argument(
         | 
| 54 | 
            +
                    "--env",
         | 
| 55 | 
            +
                    default="CartPole-v1",
         | 
| 56 | 
            +
                    type=str,
         | 
| 57 | 
            +
                    nargs="+",
         | 
| 58 | 
            +
                    help="Name of environment(s) in gym",
         | 
| 59 | 
            +
                )
         | 
| 60 | 
            +
                parser.add_argument(
         | 
| 61 | 
            +
                    "--seed",
         | 
| 62 | 
            +
                    default=1,
         | 
| 63 | 
            +
                    type=int,
         | 
| 64 | 
            +
                    nargs="*",
         | 
| 65 | 
            +
                    help="Seeds to run experiment. Unset will do one run with no set seed",
         | 
| 66 | 
            +
                )
         | 
| 67 | 
            +
                parser.add_argument(
         | 
| 68 | 
            +
                    "--use-deterministic-algorithms",
         | 
| 69 | 
            +
                    default=True,
         | 
| 70 | 
            +
                    type=bool,
         | 
| 71 | 
            +
                    help="If seed set, set torch.use_deterministic_algorithms",
         | 
| 72 | 
            +
                )
         | 
| 73 | 
            +
                return parser
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            def load_hyperparams(algo: str, env_id: str, root_path: str) -> Hyperparams:
         | 
| 77 | 
            +
                hyperparams_path = os.path.join(root_path, HYPERPARAMS_PATH, f"{algo}.yml")
         | 
| 78 | 
            +
                with open(hyperparams_path, "r") as f:
         | 
| 79 | 
            +
                    hyperparams_dict = yaml.safe_load(f)
         | 
| 80 | 
            +
                if "BulletEnv" in env_id:
         | 
| 81 | 
            +
                    import pybullet_envs
         | 
| 82 | 
            +
                spec = gym.spec(env_id)
         | 
| 83 | 
            +
                if env_id in hyperparams_dict:
         | 
| 84 | 
            +
                    return hyperparams_dict[env_id]
         | 
| 85 | 
            +
                elif "AtariEnv" in str(spec.entry_point) and "atari" in hyperparams_dict:
         | 
| 86 | 
            +
                    return hyperparams_dict["atari"]
         | 
| 87 | 
            +
                else:
         | 
| 88 | 
            +
                    raise ValueError(f"{env_id} not specified in {algo} hyperparameters file")
         | 
| 89 | 
            +
             | 
| 90 | 
            +
             | 
| 91 | 
            +
            def get_device(device: str, env: VecEnv) -> torch.device:
         | 
| 92 | 
            +
                # cuda by default
         | 
| 93 | 
            +
                if device == "auto":
         | 
| 94 | 
            +
                    device = "cuda"
         | 
| 95 | 
            +
                # Apple MPS is a second choice (sometimes)
         | 
| 96 | 
            +
                if device == "cuda" and not torch.cuda.is_available():
         | 
| 97 | 
            +
                    device = "mps"
         | 
| 98 | 
            +
                # If no MPS, fallback to cpu
         | 
| 99 | 
            +
                if device == "mps" and not torch.backends.mps.is_available():
         | 
| 100 | 
            +
                    device = "cpu"
         | 
| 101 | 
            +
                # Simple environments like Discreet and 1-D Boxes might also be better
         | 
| 102 | 
            +
                # served with the CPU.
         | 
| 103 | 
            +
                if device == "mps":
         | 
| 104 | 
            +
                    obs_space = env.observation_space
         | 
| 105 | 
            +
                    if isinstance(obs_space, Discrete):
         | 
| 106 | 
            +
                        device = "cpu"
         | 
| 107 | 
            +
                    elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
         | 
| 108 | 
            +
                        device = "cpu"
         | 
| 109 | 
            +
                print(f"Device: {device}")
         | 
| 110 | 
            +
                return torch.device(device)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
             | 
| 113 | 
            +
            def set_seeds(seed: Optional[int], use_deterministic_algorithms: bool) -> None:
         | 
| 114 | 
            +
                if seed is None:
         | 
| 115 | 
            +
                    return
         | 
| 116 | 
            +
                random.seed(seed)
         | 
| 117 | 
            +
                np.random.seed(seed)
         | 
| 118 | 
            +
                torch.manual_seed(seed)
         | 
| 119 | 
            +
                torch.backends.cudnn.benchmark = False
         | 
| 120 | 
            +
                torch.use_deterministic_algorithms(use_deterministic_algorithms)
         | 
| 121 | 
            +
                os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
         | 
| 122 | 
            +
             | 
| 123 | 
            +
             | 
| 124 | 
            +
            def make_policy(
         | 
| 125 | 
            +
                algo: str,
         | 
| 126 | 
            +
                env: VecEnv,
         | 
| 127 | 
            +
                device: torch.device,
         | 
| 128 | 
            +
                load_path: Optional[str] = None,
         | 
| 129 | 
            +
                **kwargs,
         | 
| 130 | 
            +
            ) -> Policy:
         | 
| 131 | 
            +
                policy = POLICIES[algo](env, **kwargs).to(device)
         | 
| 132 | 
            +
                if load_path:
         | 
| 133 | 
            +
                    policy.load(load_path)
         | 
| 134 | 
            +
                return policy
         | 
| 135 | 
            +
             | 
| 136 | 
            +
             | 
| 137 | 
            +
            def plot_eval_callback(callback: EvalCallback, tb_writer: SummaryWriter, run_name: str):
         | 
| 138 | 
            +
                figure = plt.figure()
         | 
| 139 | 
            +
                cumulative_steps = [
         | 
| 140 | 
            +
                    (idx + 1) * callback.step_freq for idx in range(len(callback.stats))
         | 
| 141 | 
            +
                ]
         | 
| 142 | 
            +
                plt.plot(
         | 
| 143 | 
            +
                    cumulative_steps,
         | 
| 144 | 
            +
                    [s.score.mean for s in callback.stats],
         | 
| 145 | 
            +
                    "b-",
         | 
| 146 | 
            +
                    label="mean",
         | 
| 147 | 
            +
                )
         | 
| 148 | 
            +
                plt.plot(
         | 
| 149 | 
            +
                    cumulative_steps,
         | 
| 150 | 
            +
                    [s.score.mean - s.score.std for s in callback.stats],
         | 
| 151 | 
            +
                    "g--",
         | 
| 152 | 
            +
                    label="mean-std",
         | 
| 153 | 
            +
                )
         | 
| 154 | 
            +
                plt.fill_between(
         | 
| 155 | 
            +
                    cumulative_steps,
         | 
| 156 | 
            +
                    [s.score.min for s in callback.stats],  # type: ignore
         | 
| 157 | 
            +
                    [s.score.max for s in callback.stats],  # type: ignore
         | 
| 158 | 
            +
                    facecolor="cyan",
         | 
| 159 | 
            +
                    label="range",
         | 
| 160 | 
            +
                )
         | 
| 161 | 
            +
                plt.xlabel("Steps")
         | 
| 162 | 
            +
                plt.ylabel("Score")
         | 
| 163 | 
            +
                plt.legend()
         | 
| 164 | 
            +
                plt.title(f"Eval {run_name}")
         | 
| 165 | 
            +
                tb_writer.add_figure("eval", figure)
         | 
| 166 | 
            +
             | 
| 167 | 
            +
             | 
| 168 | 
            +
            Scalar = Union[bool, str, float, int, None]
         | 
| 169 | 
            +
             | 
| 170 | 
            +
             | 
| 171 | 
            +
            def flatten_hyperparameters(
         | 
| 172 | 
            +
                hyperparams: Hyperparams, args: Dict[str, Union[Scalar, list]]
         | 
| 173 | 
            +
            ) -> Dict[str, Scalar]:
         | 
| 174 | 
            +
                flattened = args.copy()
         | 
| 175 | 
            +
                for k, v in flattened.items():
         | 
| 176 | 
            +
                    if isinstance(v, list):
         | 
| 177 | 
            +
                        flattened[k] = json.dumps(v)
         | 
| 178 | 
            +
                for k, v in hyperparams.items():
         | 
| 179 | 
            +
                    if isinstance(v, dict):
         | 
| 180 | 
            +
                        for sk, sv in v.items():
         | 
| 181 | 
            +
                            key = f"{k}/{sk}"
         | 
| 182 | 
            +
                            if isinstance(sv, dict) or isinstance(sv, list):
         | 
| 183 | 
            +
                                flattened[key] = str(sv)
         | 
| 184 | 
            +
                            else:
         | 
| 185 | 
            +
                                flattened[key] = sv
         | 
| 186 | 
            +
                    else:
         | 
| 187 | 
            +
                        flattened[k] = v  # type: ignore
         | 
| 188 | 
            +
                return flattened  # type: ignore
         | 
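A quick sketch of what flatten_hyperparameters produces for TensorBoard's hparams table; the dict keys and values below are illustrative only (not taken from the hyperparams files), and plain dicts stand in for the Hyperparams type since only .items() is used:

example_hyperparams = {
    "n_timesteps": 100_000,
    "algo_hyperparams": {"learning_rate": 3e-4, "n_steps": 2048},
}
example_args = {"algo": "ppo", "env": "Acrobot-v1", "seed": 4, "wandb_tags": ["benchmark"]}
flatten_hyperparameters(example_hyperparams, example_args)
# -> {"algo": "ppo", "env": "Acrobot-v1", "seed": 4, "wandb_tags": '["benchmark"]',
#     "n_timesteps": 100000, "algo_hyperparams/learning_rate": 0.0003,
#     "algo_hyperparams/n_steps": 2048}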
    	
        runner/train.py
    ADDED
    
@@ -0,0 +1,126 @@
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import dataclasses
import shutil
import wandb
import yaml

from dataclasses import dataclass
from torch.utils.tensorboard.writer import SummaryWriter
from typing import Any, Dict, Optional, Sequence

from shared.callbacks.eval_callback import EvalCallback
from runner.env import make_env, make_eval_env
from runner.config import Config, RunArgs
from runner.running_utils import (
    ALGOS,
    load_hyperparams,
    set_seeds,
    get_device,
    make_policy,
    plot_eval_callback,
    flatten_hyperparameters,
)
from shared.stats import EpisodesStats


@dataclass
class TrainArgs(RunArgs):
    wandb_project_name: Optional[str] = None
    wandb_entity: Optional[str] = None
    wandb_tags: Sequence[str] = dataclasses.field(default_factory=list)


def train(args: TrainArgs):
    print(args)
    hyperparams = load_hyperparams(args.algo, args.env, os.getcwd())
    print(hyperparams)
    config = Config(args, hyperparams, os.getcwd())

    wandb_enabled = args.wandb_project_name
    if wandb_enabled:
        wandb.tensorboard.patch(
            root_logdir=config.tensorboard_summary_path, pytorch=True
        )
        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            config=hyperparams,  # type: ignore
            name=config.run_name,
            monitor_gym=True,
            save_code=True,
            tags=args.wandb_tags,
        )
        wandb.config.update(args)

    tb_writer = SummaryWriter(config.tensorboard_summary_path)

    set_seeds(args.seed, args.use_deterministic_algorithms)

    env = make_env(config, tb_writer=tb_writer, **config.env_hyperparams)
    device = get_device(config.device, env)
    policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
    algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)

    eval_env = make_eval_env(config, **config.env_hyperparams)
    record_best_videos = config.eval_params.get("record_best_videos", True)
    callback = EvalCallback(
        policy,
        eval_env,
        tb_writer,
        best_model_path=config.model_dir_path(best=True),
        **config.eval_params,
        video_env=make_eval_env(config, override_n_envs=1, **config.env_hyperparams)
        if record_best_videos
        else None,
        best_video_dir=config.best_videos_dir,
    )
    algo.learn(config.n_timesteps, callback=callback)

    policy.save(config.model_dir_path(best=False))

    eval_stats = callback.evaluate(n_episodes=10, print_returns=True)

    plot_eval_callback(callback, tb_writer, config.run_name)

    log_dict: Dict[str, Any] = {
        "eval": eval_stats._asdict(),
    }
    if callback.best:
        log_dict["best_eval"] = callback.best._asdict()
    log_dict.update(hyperparams)
    log_dict.update(vars(args))
    with open(config.logs_path, "a") as f:
        yaml.dump({config.run_name: log_dict}, f)

    best_eval_stats: EpisodesStats = callback.best  # type: ignore
    tb_writer.add_hparams(
        flatten_hyperparameters(hyperparams, vars(args)),
        {
            "hparam/best_mean": best_eval_stats.score.mean,
            "hparam/best_result": best_eval_stats.score.mean
            - best_eval_stats.score.std,
            "hparam/last_mean": eval_stats.score.mean,
            "hparam/last_result": eval_stats.score.mean - eval_stats.score.std,
        },
        None,
        config.run_name,
    )

    tb_writer.close()

    if wandb_enabled:
        shutil.make_archive(
            os.path.join(wandb.run.dir, config.model_dir_name()),
            "zip",
            config.model_dir_path(),
        )
        shutil.make_archive(
            os.path.join(wandb.run.dir, config.model_dir_name(best=True)),
            "zip",
            config.model_dir_path(best=True),
        )
        wandb.finish()
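A hedged sketch of invoking train() programmatically instead of going through the CLI entry point; the RunArgs field names used here (algo, env, seed, use_deterministic_algorithms) are inferred from their use in train() above and are assumptions about RunArgs' constructor:

from runner.train import TrainArgs, train

args = TrainArgs(
    algo="ppo",
    env="Acrobot-v1",
    seed=4,
    use_deterministic_algorithms=True,
    wandb_project_name=None,  # None keeps the run local (no Weights & Biases logging)
)
train(args)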
    	
        saved_models/ppo-Acrobot-v1-S4-best/model.pth
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f409782c624f44a81d24dc84be10b3ea2d373dbfccfe86f782a1a9109a9880de
size 41509
    	
        saved_models/ppo-Acrobot-v1-S4-best/vecnormalize.pkl
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fb89bec8a5fe259e3d4705482ece78d999ac400dcd5b35ddffdd8cea0a284cf5
size 7013
    	
        shared/algorithm.py
    ADDED
    
@@ -0,0 +1,35 @@
import gym
import torch

from abc import ABC, abstractmethod
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
from torch.utils.tensorboard.writer import SummaryWriter
from typing import List, Optional, TypeVar

from shared.callbacks.callback import Callback
from shared.policy.policy import Policy
from shared.stats import EpisodesStats

AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")


class Algorithm(ABC):
    @abstractmethod
    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        **kwargs,
    ) -> None:
        super().__init__()
        self.policy = policy
        self.env = env
        self.device = device
        self.tb_writer = tb_writer

    @abstractmethod
    def learn(
        self: AlgorithmSelf, total_timesteps: int, callback: Optional[Callback] = None
    ) -> AlgorithmSelf:
        ...
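A sketch of the contract a concrete subclass fulfills: take ownership of policy/env/writer in __init__ and drive the environment in learn(), reporting progress through the callback. The random-action body below is purely illustrative and not part of the commit:

import numpy as np
from typing import Optional

class RandomAlgorithm(Algorithm):
    def __init__(self, policy, env, device, tb_writer, **kwargs) -> None:
        super().__init__(policy, env, device, tb_writer)

    def learn(
        self, total_timesteps: int, callback: Optional[Callback] = None
    ) -> "RandomAlgorithm":
        self.env.reset()
        steps = 0
        while steps < total_timesteps:
            # Sample one random action per vectorized sub-environment
            acts = np.array(
                [self.env.action_space.sample() for _ in range(self.env.num_envs)]
            )
            self.env.step(acts)
            steps += self.env.num_envs
            if callback:
                callback.on_step(timesteps_elapsed=self.env.num_envs)
        return self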
    	
        shared/callbacks/callback.py
    ADDED
    
@@ -0,0 +1,12 @@
from abc import ABC, abstractmethod


class Callback(ABC):
    def __init__(self) -> None:
        super().__init__()
        self.timesteps_elapsed = 0

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        self.timesteps_elapsed += timesteps_elapsed
        return True
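A minimal sketch of a custom callback, assuming (as the bool return value suggests) that an algorithm's training loop stops when on_step returns False; the class name and cutoff are hypothetical:

class StopAfterCallback(Callback):
    """Illustrative early-stopping callback; not part of the commit."""

    def __init__(self, max_timesteps: int) -> None:
        super().__init__()
        self.max_timesteps = max_timesteps

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        super().on_step(timesteps_elapsed)
        # Keep training only while under the timestep budget
        return self.timesteps_elapsed < self.max_timesteps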
    	
        shared/callbacks/eval_callback.py
    ADDED
    
@@ -0,0 +1,174 @@
import itertools
import numpy as np
import os

from copy import deepcopy
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from torch.utils.tensorboard.writer import SummaryWriter
from typing import List, Optional, Union

from shared.callbacks.callback import Callback
from shared.policy.policy import Policy
from shared.stats import Episode, EpisodeAccumulator, EpisodesStats
from wrappers.vec_episode_recorder import VecEpisodeRecorder


class EvaluateAccumulator(EpisodeAccumulator):
    def __init__(self, num_envs: int, goal_episodes: int, print_returns: bool = True):
        super().__init__(num_envs)
        self.completed_episodes_by_env_idx = [[] for _ in range(num_envs)]
        self.goal_episodes_per_env = int(np.ceil(goal_episodes / num_envs))
        self.print_returns = print_returns

    def on_done(self, ep_idx: int, episode: Episode) -> None:
        if len(self.completed_episodes_by_env_idx[ep_idx]) >= self.goal_episodes_per_env:
            return
        self.completed_episodes_by_env_idx[ep_idx].append(episode)
        if self.print_returns:
            print(
                f"Episode {len(self)} | "
                f"Score {episode.score} | "
                f"Length {episode.length}"
            )

    def __len__(self) -> int:
        return sum(len(ce) for ce in self.completed_episodes_by_env_idx)

    @property
    def episodes(self) -> List[Episode]:
        return list(itertools.chain(*self.completed_episodes_by_env_idx))

    def is_done(self) -> bool:
        return all(
            len(ce) == self.goal_episodes_per_env
            for ce in self.completed_episodes_by_env_idx
        )


def evaluate(
    env: VecEnv,
    policy: Policy,
    n_episodes: int,
    render: bool = False,
    deterministic: bool = True,
    print_returns: bool = True,
) -> EpisodesStats:
    policy.eval()
    episodes = EvaluateAccumulator(env.num_envs, n_episodes, print_returns)

    obs = env.reset()
    while not episodes.is_done():
        act = policy.act(obs, deterministic=deterministic)
        obs, rew, done, _ = env.step(act)
        episodes.step(rew, done)
        if render:
            env.render()
    stats = EpisodesStats(episodes.episodes)
    if print_returns:
        print(stats)
    return stats


class EvalCallback(Callback):
    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        tb_writer: SummaryWriter,
        best_model_path: Optional[str] = None,
        step_freq: Union[int, float] = 50_000,
        n_episodes: int = 10,
        save_best: bool = True,
        deterministic: bool = True,
        record_best_videos: bool = True,
        video_env: Optional[VecEnv] = None,
        best_video_dir: Optional[str] = None,
        max_video_length: int = 3600,
    ) -> None:
        super().__init__()
        self.policy = policy
        self.env = env
        self.tb_writer = tb_writer
        self.best_model_path = best_model_path
        self.step_freq = int(step_freq)
        self.n_episodes = n_episodes
        self.save_best = save_best
        self.deterministic = deterministic
        self.stats: List[EpisodesStats] = []
        self.best = None

        self.record_best_videos = record_best_videos
        assert video_env or not record_best_videos
        self.video_env = video_env
        assert best_video_dir or not record_best_videos
        self.best_video_dir = best_video_dir
        if best_video_dir:
            os.makedirs(best_video_dir, exist_ok=True)
        self.max_video_length = max_video_length
        self.best_video_base_path = None

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        super().on_step(timesteps_elapsed)
        if self.timesteps_elapsed // self.step_freq >= len(self.stats):
            self.sync_vec_normalize(self.env)
            self.evaluate()
        return True

    def evaluate(
        self, n_episodes: Optional[int] = None, print_returns: Optional[bool] = None
    ) -> EpisodesStats:
        eval_stat = evaluate(
            self.env,
            self.policy,
            n_episodes or self.n_episodes,
            deterministic=self.deterministic,
            print_returns=print_returns or False,
        )
        self.policy.train(True)
        print(f"Eval Timesteps: {self.timesteps_elapsed} | {eval_stat}")

        self.stats.append(eval_stat)

        if not self.best or eval_stat >= self.best:
            strictly_better = not self.best or eval_stat > self.best
            self.best = eval_stat
            if self.save_best:
                assert self.best_model_path
                self.policy.save(self.best_model_path)
                print("Saved best model")
            self.best.write_to_tensorboard(
                self.tb_writer, "best_eval", self.timesteps_elapsed
            )
            if strictly_better and self.record_best_videos:
                assert self.video_env and self.best_video_dir
                self.sync_vec_normalize(self.video_env)
                self.best_video_base_path = os.path.join(
                    self.best_video_dir, str(self.timesteps_elapsed)
                )
                video_wrapped = VecEpisodeRecorder(
                    self.video_env,
                    self.best_video_base_path,
                    max_video_length=self.max_video_length,
                )
                video_stats = evaluate(
                    video_wrapped,
                    self.policy,
                    1,
                    deterministic=self.deterministic,
                    print_returns=False,
                )
                print(f"Saved best video: {video_stats}")

        eval_stat.write_to_tensorboard(self.tb_writer, "eval", self.timesteps_elapsed)

        return eval_stat

    def sync_vec_normalize(self, destination_env: VecEnv) -> None:
        if self.policy.vec_normalize is not None:
            eval_env_wrapper = destination_env
            while isinstance(eval_env_wrapper, VecEnvWrapper):
                if isinstance(eval_env_wrapper, VecNormalize):
                    if hasattr(self.policy.vec_normalize, "obs_rms"):
                        eval_env_wrapper.obs_rms = deepcopy(
                            self.policy.vec_normalize.obs_rms
                        )
                    eval_env_wrapper.ret_rms = deepcopy(
                        self.policy.vec_normalize.ret_rms
                    )
                eval_env_wrapper = eval_env_wrapper.venv
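A minimal sketch of calling the module-level evaluate() outside of training, assuming an evaluation VecEnv and a loaded policy built with the runner helpers (variable names here are illustrative):

stats = evaluate(eval_env, policy, n_episodes=10, deterministic=True, print_returns=False)
print(stats.score.mean, stats.score.std)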
    	
        shared/module.py
    ADDED
    
@@ -0,0 +1,121 @@
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from gym.spaces import Box, Discrete
from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from typing import Sequence, Type


class FeatureExtractor(nn.Module):
    def __init__(
        self,
        obs_space: gym.Space,
        activation: Type[nn.Module],
        init_layers_orthogonal: bool = False,
        cnn_feature_dim: int = 512,
    ) -> None:
        super().__init__()
        if isinstance(obs_space, Box):
            # Conv2D: (channels, height, width)
            if len(obs_space.shape) == 3:
                # CNN from DQN Nature paper: Mnih, Volodymyr, et al.
                # "Human-level control through deep reinforcement learning."
                # Nature 518.7540 (2015): 529-533.
                cnn = nn.Sequential(
                    layer_init(
                        nn.Conv2d(obs_space.shape[0], 32, kernel_size=8, stride=4),
                        init_layers_orthogonal,
                    ),
                    activation(),
                    layer_init(
                        nn.Conv2d(32, 64, kernel_size=4, stride=2),
                        init_layers_orthogonal,
                    ),
                    activation(),
                    layer_init(
                        nn.Conv2d(64, 64, kernel_size=3, stride=1),
                        init_layers_orthogonal,
                    ),
                    activation(),
                    nn.Flatten(),
                )

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    if len(obs.shape) == 3:
                        obs = obs.unsqueeze(0)
                    return obs.float() / 255.0

                with torch.no_grad():
                    cnn_out = cnn(preprocess(torch.as_tensor(obs_space.sample())))
                self.preprocess = preprocess
                self.feature_extractor = nn.Sequential(
                    cnn,
                    layer_init(
                        nn.Linear(cnn_out.shape[1], cnn_feature_dim),
                        init_layers_orthogonal,
                    ),
                    activation(),
                )
                self.out_dim = cnn_feature_dim
            elif len(obs_space.shape) == 1:

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    if len(obs.shape) == 1:
                        obs = obs.unsqueeze(0)
                    return obs.float()

                self.preprocess = preprocess
                self.feature_extractor = nn.Flatten()
                self.out_dim = get_flattened_obs_dim(obs_space)
            else:
                raise ValueError(f"Unsupported observation space: {obs_space}")
        elif isinstance(obs_space, Discrete):
            self.preprocess = lambda x: F.one_hot(x, obs_space.n).float()
            self.feature_extractor = nn.Flatten()
            self.out_dim = obs_space.n
        else:
            raise NotImplementedError

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        if self.preprocess:
            obs = self.preprocess(obs)
        return self.feature_extractor(obs)


def mlp(
    layer_sizes: Sequence[int],
    activation: Type[nn.Module],
    output_activation: Type[nn.Module] = nn.Identity,
    init_layers_orthogonal: bool = False,
    final_layer_gain: float = np.sqrt(2),
) -> nn.Module:
    layers = []
    for i in range(len(layer_sizes) - 2):
        layers.append(
            layer_init(
                nn.Linear(layer_sizes[i], layer_sizes[i + 1]), init_layers_orthogonal
            )
        )
        layers.append(activation())
    layers.append(
        layer_init(
            nn.Linear(layer_sizes[-2], layer_sizes[-1]),
            init_layers_orthogonal,
            std=final_layer_gain,
        )
    )
    layers.append(output_activation())
    return nn.Sequential(*layers)


def layer_init(
    layer: nn.Module, init_layers_orthogonal: bool, std: float = np.sqrt(2)
) -> nn.Module:
    if not init_layers_orthogonal:
        return layer
    nn.init.orthogonal_(layer.weight, std)  # type: ignore
    nn.init.constant_(layer.bias, 0.0)  # type: ignore
    return layer
    	
        shared/policy/actor.py
    ADDED
    
@@ -0,0 +1,304 @@
import gym
import torch
import torch.nn as nn

from abc import ABC, abstractmethod
from gym.spaces import Box, Discrete
from torch.distributions import Categorical, Distribution, Normal
from typing import NamedTuple, Optional, Sequence, Type, TypeVar, Union

from shared.module import FeatureExtractor, mlp


class PiForward(NamedTuple):
    pi: Distribution
    logp_a: Optional[torch.Tensor]
    entropy: Optional[torch.Tensor]


class Actor(nn.Module, ABC):
    @abstractmethod
    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
        ...


class CategoricalActorHead(Actor):
    def __init__(
        self,
        act_dim: int,
        hidden_sizes: Sequence[int] = (32,),
        activation: Type[nn.Module] = nn.Tanh,
        init_layers_orthogonal: bool = True,
    ) -> None:
        super().__init__()
        layer_sizes = tuple(hidden_sizes) + (act_dim,)
        self._fc = mlp(
            layer_sizes,
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            final_layer_gain=0.01,
        )

    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
        logits = self._fc(obs)
        pi = Categorical(logits=logits)
        logp_a = None
        entropy = None
        if a is not None:
            logp_a = pi.log_prob(a)
            entropy = pi.entropy()
        return PiForward(pi, logp_a, entropy)


class GaussianDistribution(Normal):
    def log_prob(self, a: torch.Tensor) -> torch.Tensor:
        return super().log_prob(a).sum(axis=-1)

    def sample(self) -> torch.Tensor:
        return self.rsample()


class GaussianActorHead(Actor):
    def __init__(
        self,
        act_dim: int,
        hidden_sizes: Sequence[int] = (32,),
        activation: Type[nn.Module] = nn.Tanh,
        init_layers_orthogonal: bool = True,
        log_std_init: float = -0.5,
    ) -> None:
        super().__init__()
        layer_sizes = tuple(hidden_sizes) + (act_dim,)
        self.mu_net = mlp(
            layer_sizes,
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            final_layer_gain=0.01,
        )
        self.log_std = nn.Parameter(
            torch.ones(act_dim, dtype=torch.float32) * log_std_init
        )

    def _distribution(self, obs: torch.Tensor) -> Distribution:
        mu = self.mu_net(obs)
        std = torch.exp(self.log_std)
        return GaussianDistribution(mu, std)

    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
        pi = self._distribution(obs)
        logp_a = None
        entropy = None
        if a is not None:
            logp_a = pi.log_prob(a)
            entropy = pi.entropy()
        return PiForward(pi, logp_a, entropy)


class TanhBijector:
    def __init__(self, epsilon: float = 1e-6) -> None:
        self.epsilon = epsilon

    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(x)

    @staticmethod
    def inverse(y: torch.Tensor) -> torch.Tensor:
        eps = torch.finfo(y.dtype).eps
        clamped_y = y.clamp(min=-1.0 + eps, max=1.0 - eps)
        return torch.atanh(clamped_y)

    def log_prob_correction(self, x: torch.Tensor) -> torch.Tensor:
        return torch.log(1.0 - torch.tanh(x) ** 2 + self.epsilon)


class StateDependentNoiseDistribution(Normal):
    def __init__(
        self,
        loc,
        scale,
        latent_sde: torch.Tensor,
        exploration_mat: torch.Tensor,
        exploration_matrices: torch.Tensor,
        bijector: Optional[TanhBijector] = None,
        validate_args=None,
    ):
        super().__init__(loc, scale, validate_args)
        self.latent_sde = latent_sde
        self.exploration_mat = exploration_mat
        self.exploration_matrices = exploration_matrices
        self.bijector = bijector

    def log_prob(self, a: torch.Tensor) -> torch.Tensor:
        gaussian_a = self.bijector.inverse(a) if self.bijector else a
        log_prob = super().log_prob(gaussian_a).sum(axis=-1)
        if self.bijector:
            log_prob -= torch.sum(self.bijector.log_prob_correction(gaussian_a), dim=1)
        return log_prob

    def sample(self) -> torch.Tensor:
        noise = self._get_noise()
        actions = self.mean + noise
        return self.bijector.forward(actions) if self.bijector else actions

    def _get_noise(self) -> torch.Tensor:
        if len(self.latent_sde) == 1 or len(self.latent_sde) != len(
            self.exploration_matrices
        ):
            return torch.mm(self.latent_sde, self.exploration_mat)
        # (batch_size, n_features) -> (batch_size, 1, n_features)
        latent_sde = self.latent_sde.unsqueeze(dim=1)
        # (batch_size, 1, n_actions)
        noise = torch.bmm(latent_sde, self.exploration_matrices)
        return noise.squeeze(dim=1)

    @property
    def mode(self) -> torch.Tensor:
        mean = super().mode
        return self.bijector.forward(mean) if self.bijector else mean


StateDependentNoiseActorHeadSelf = TypeVar(
    "StateDependentNoiseActorHeadSelf", bound="StateDependentNoiseActorHead"
)


class StateDependentNoiseActorHead(Actor):
    def __init__(
        self,
        act_dim: int,
        hidden_sizes: Sequence[int] = (32,),
        activation: Type[nn.Module] = nn.Tanh,
        init_layers_orthogonal: bool = True,
        log_std_init: float = -0.5,
        full_std: bool = True,
        squash_output: bool = False,
        learn_std: bool = False,
         | 
| 177 | 
            +
                ) -> None:
         | 
| 178 | 
            +
                    super().__init__()
         | 
| 179 | 
            +
                    self.act_dim = act_dim
         | 
| 180 | 
            +
                    layer_sizes = tuple(hidden_sizes) + (self.act_dim,)
         | 
| 181 | 
            +
                    if len(layer_sizes) == 2:
         | 
| 182 | 
            +
                        self.latent_net = nn.Identity()
         | 
| 183 | 
            +
                    elif len(layer_sizes) > 2:
         | 
| 184 | 
            +
                        self.latent_net = mlp(
         | 
| 185 | 
            +
                            layer_sizes[:-1],
         | 
| 186 | 
            +
                            activation,
         | 
| 187 | 
            +
                            output_activation=activation,
         | 
| 188 | 
            +
                            init_layers_orthogonal=init_layers_orthogonal,
         | 
| 189 | 
            +
                        )
         | 
| 190 | 
            +
                    else:
         | 
| 191 | 
            +
                        raise ValueError("hidden_sizes must be of at least length 1")
         | 
| 192 | 
            +
                    self.mu_net = mlp(
         | 
| 193 | 
            +
                        layer_sizes[-2:],
         | 
| 194 | 
            +
                        activation,
         | 
| 195 | 
            +
                        init_layers_orthogonal=init_layers_orthogonal,
         | 
| 196 | 
            +
                        final_layer_gain=0.01,
         | 
| 197 | 
            +
                    )
         | 
| 198 | 
            +
                    self.full_std = full_std
         | 
| 199 | 
            +
                    std_dim = (hidden_sizes[-1], act_dim if self.full_std else 1)
         | 
| 200 | 
            +
                    self.log_std = nn.Parameter(
         | 
| 201 | 
            +
                        torch.ones(std_dim, dtype=torch.float32) * log_std_init
         | 
| 202 | 
            +
                    )
         | 
| 203 | 
            +
                    self.bijector = TanhBijector() if squash_output else None
         | 
| 204 | 
            +
                    self.learn_std = learn_std
         | 
| 205 | 
            +
                    self.device = None
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                    self.exploration_mat = None
         | 
| 208 | 
            +
                    self.exploration_matrices = None
         | 
| 209 | 
            +
                    self.sample_weights()
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                def to(
         | 
| 212 | 
            +
                    self: StateDependentNoiseActorHeadSelf,
         | 
| 213 | 
            +
                    device: Optional[torch.device] = None,
         | 
| 214 | 
            +
                    dtype: Optional[Union[torch.dtype, str]] = None,
         | 
| 215 | 
            +
                    non_blocking: bool = False,
         | 
| 216 | 
            +
                ) -> StateDependentNoiseActorHeadSelf:
         | 
| 217 | 
            +
                    super().to(device, dtype, non_blocking)
         | 
| 218 | 
            +
                    self.device = device
         | 
| 219 | 
            +
                    return self
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                def _distribution(self, obs: torch.Tensor) -> Distribution:
         | 
| 222 | 
            +
                    latent = self.latent_net(obs)
         | 
| 223 | 
            +
                    mu = self.mu_net(latent)
         | 
| 224 | 
            +
                    latent_sde = latent if self.learn_std else latent.detach()
         | 
| 225 | 
            +
                    variance = torch.mm(latent_sde**2, self._get_std() ** 2)
         | 
| 226 | 
            +
                    assert self.exploration_mat is not None
         | 
| 227 | 
            +
                    assert self.exploration_matrices is not None
         | 
| 228 | 
            +
                    return StateDependentNoiseDistribution(
         | 
| 229 | 
            +
                        mu,
         | 
| 230 | 
            +
                        torch.sqrt(variance + 1e-6),
         | 
| 231 | 
            +
                        latent_sde,
         | 
| 232 | 
            +
                        self.exploration_mat,
         | 
| 233 | 
            +
                        self.exploration_matrices,
         | 
| 234 | 
            +
                        self.bijector,
         | 
| 235 | 
            +
                    )
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                def _get_std(self) -> torch.Tensor:
         | 
| 238 | 
            +
                    std = torch.exp(self.log_std)
         | 
| 239 | 
            +
                    if self.full_std:
         | 
| 240 | 
            +
                        return std
         | 
| 241 | 
            +
                    ones = torch.ones(self.log_std.shape[0], self.act_dim)
         | 
| 242 | 
            +
                    if self.device:
         | 
| 243 | 
            +
                        ones = ones.to(self.device)
         | 
| 244 | 
            +
                    return ones * std
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
         | 
| 247 | 
            +
                    pi = self._distribution(obs)
         | 
| 248 | 
            +
                    logp_a = None
         | 
| 249 | 
            +
                    entropy = None
         | 
| 250 | 
            +
                    if a is not None:
         | 
| 251 | 
            +
                        logp_a = pi.log_prob(a)
         | 
| 252 | 
            +
                        entropy = -logp_a
         | 
| 253 | 
            +
                    return PiForward(pi, logp_a, entropy)
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                def sample_weights(self, batch_size: int = 1) -> None:
         | 
| 256 | 
            +
                    std = self._get_std()
         | 
| 257 | 
            +
                    weights_dist = Normal(torch.zeros_like(std), std)
         | 
| 258 | 
            +
                    # Reparametrization trick to pass gradients
         | 
| 259 | 
            +
                    self.exploration_mat = weights_dist.rsample()
         | 
| 260 | 
            +
                    self.exploration_matrices = weights_dist.rsample(torch.Size((batch_size,)))
         | 
| 261 | 
            +
             | 
| 262 | 
            +
             | 
| 263 | 
            +
            def actor_head(
         | 
| 264 | 
            +
                action_space: gym.Space,
         | 
| 265 | 
            +
                hidden_sizes: Sequence[int],
         | 
| 266 | 
            +
                init_layers_orthogonal: bool,
         | 
| 267 | 
            +
                activation: Type[nn.Module],
         | 
| 268 | 
            +
                log_std_init: float = -0.5,
         | 
| 269 | 
            +
                use_sde: bool = False,
         | 
| 270 | 
            +
                full_std: bool = True,
         | 
| 271 | 
            +
                squash_output: bool = False,
         | 
| 272 | 
            +
            ) -> Actor:
         | 
| 273 | 
            +
                assert not use_sde or isinstance(
         | 
| 274 | 
            +
                    action_space, Box
         | 
| 275 | 
            +
                ), "use_sde only valid if Box action_space"
         | 
| 276 | 
            +
                assert not squash_output or use_sde, "squash_output only valid if use_sde"
         | 
| 277 | 
            +
                if isinstance(action_space, Discrete):
         | 
| 278 | 
            +
                    return CategoricalActorHead(
         | 
| 279 | 
            +
                        action_space.n,
         | 
| 280 | 
            +
                        hidden_sizes=hidden_sizes,
         | 
| 281 | 
            +
                        activation=activation,
         | 
| 282 | 
            +
                        init_layers_orthogonal=init_layers_orthogonal,
         | 
| 283 | 
            +
                    )
         | 
| 284 | 
            +
                elif isinstance(action_space, Box):
         | 
| 285 | 
            +
                    if use_sde:
         | 
| 286 | 
            +
                        return StateDependentNoiseActorHead(
         | 
| 287 | 
            +
                            action_space.shape[0],
         | 
| 288 | 
            +
                            hidden_sizes=hidden_sizes,
         | 
| 289 | 
            +
                            activation=activation,
         | 
| 290 | 
            +
                            init_layers_orthogonal=init_layers_orthogonal,
         | 
| 291 | 
            +
                            log_std_init=log_std_init,
         | 
| 292 | 
            +
                            full_std=full_std,
         | 
| 293 | 
            +
                            squash_output=squash_output,
         | 
| 294 | 
            +
                        )
         | 
| 295 | 
            +
                    else:
         | 
| 296 | 
            +
                        return GaussianActorHead(
         | 
| 297 | 
            +
                            action_space.shape[0],
         | 
| 298 | 
            +
                            hidden_sizes=hidden_sizes,
         | 
| 299 | 
            +
                            activation=activation,
         | 
| 300 | 
            +
                            init_layers_orthogonal=init_layers_orthogonal,
         | 
| 301 | 
            +
                            log_std_init=log_std_init,
         | 
| 302 | 
            +
                        )
         | 
| 303 | 
            +
                else:
         | 
| 304 | 
            +
                    raise ValueError(f"Unsupported action space: {action_space}")
         | 
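A minimal usage sketch (not part of this commit): actor_head dispatches on the action space, so Acrobot-v1's Discrete(3) space yields a CategoricalActorHead. It assumes CategoricalActorHead (defined earlier in this file) follows the same forward(obs, a) -> PiForward signature shown above, and that, as in ActorCritic, the first entry of hidden_sizes is the incoming feature dimension (6 for Acrobot-v1).

# Hypothetical sketch; the hidden_sizes convention and the CategoricalActorHead
# forward signature are assumed from how the surrounding code uses them.
import gym
import torch
import torch.nn as nn

from shared.policy.actor import actor_head

env = gym.make("Acrobot-v1")
head = actor_head(
    env.action_space,  # Discrete(3) -> CategoricalActorHead
    hidden_sizes=(6, 64),
    init_layers_orthogonal=True,
    activation=nn.Tanh,
)
obs = torch.as_tensor(env.observation_space.sample(), dtype=torch.float32).unsqueeze(0)
pi, logp_a, entropy = head(obs, a=torch.tensor([0]))
print(logp_a, entropy)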
    	
        shared/policy/critic.py
    ADDED
    
@@ -0,0 +1,27 @@
import gym
import torch
import torch.nn as nn

from typing import Sequence, Type
from shared.module import FeatureExtractor, mlp


class CriticHead(nn.Module):
    def __init__(
        self,
        hidden_sizes: Sequence[int] = (32,),
        activation: Type[nn.Module] = nn.Tanh,
        init_layers_orthogonal: bool = True,
    ) -> None:
        super().__init__()
        layer_sizes = tuple(hidden_sizes) + (1,)
        self._fc = mlp(
            layer_sizes,
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            final_layer_gain=1.0,
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        v = self._fc(obs)
        return v.squeeze(-1)
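A minimal sketch (not part of this commit) of what CriticHead computes: the trailing squeeze(-1) drops the singleton output dimension, so a batch of feature vectors maps to one value per batch element. As in ActorCritic below, the first entry of hidden_sizes is treated as the input feature dimension.

import torch

from shared.policy.critic import CriticHead

critic = CriticHead(hidden_sizes=(6, 64))  # 6-dim features -> 64 hidden -> scalar value
values = critic(torch.randn(8, 6))
print(values.shape)  # torch.Size([8]), not (8, 1)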
    	
        shared/policy/on_policy.py
    ADDED
    
@@ -0,0 +1,177 @@
import gym
import numpy as np
import torch

from gym.spaces import Box
from pathlib import Path
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar

from shared.module import FeatureExtractor
from shared.policy.actor import PiForward, StateDependentNoiseActorHead, actor_head
from shared.policy.critic import CriticHead
from shared.policy.policy import ACTIVATION, Policy


class Step(NamedTuple):
    a: np.ndarray
    v: np.ndarray
    logp_a: np.ndarray
    clamped_a: np.ndarray


class ACForward(NamedTuple):
    logp_a: torch.Tensor
    entropy: torch.Tensor
    v: torch.Tensor


FEAT_EXT_FILE_NAME = "feat_ext.pt"
V_FEAT_EXT_FILE_NAME = "v_feat_ext.pt"
PI_FILE_NAME = "pi.pt"
V_FILE_NAME = "v.pt"
ActorCriticSelf = TypeVar("ActorCriticSelf", bound="ActorCritic")


def clamp_actions(
    actions: np.ndarray, action_space: gym.Space, squash_output: bool
) -> np.ndarray:
    if isinstance(action_space, Box):
        low, high = action_space.low, action_space.high  # type: ignore
        if squash_output:
            # Squashed output is already between -1 and 1. Rescale if the actual
            # output needs to be something other than -1 and 1
            return low + 0.5 * (actions + 1) * (high - low)
        else:
            return np.clip(actions, low, high)
    return actions


class ActorCritic(Policy):
    def __init__(
        self,
        env: VecEnv,
        pi_hidden_sizes: Sequence[int],
        v_hidden_sizes: Sequence[int],
        init_layers_orthogonal: bool = True,
        activation_fn: str = "tanh",
        log_std_init: float = -0.5,
        use_sde: bool = False,
        full_std: bool = True,
        squash_output: bool = False,
        share_features_extractor: bool = True,
        cnn_feature_dim: int = 512,
        **kwargs,
    ) -> None:
        super().__init__(env, **kwargs)
        activation = ACTIVATION[activation_fn]
        observation_space = env.observation_space
        self.action_space = env.action_space
        self.squash_output = squash_output
        self.share_features_extractor = share_features_extractor
        self._feature_extractor = FeatureExtractor(
            observation_space,
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            cnn_feature_dim=cnn_feature_dim,
        )
        self._pi = actor_head(
            self.action_space,
            (self._feature_extractor.out_dim,) + tuple(pi_hidden_sizes),
            init_layers_orthogonal,
            activation,
            log_std_init=log_std_init,
            use_sde=use_sde,
            full_std=full_std,
            squash_output=squash_output,
        )

        if not share_features_extractor:
            self._v_feature_extractor = FeatureExtractor(
                observation_space,
                activation,
                init_layers_orthogonal=init_layers_orthogonal,
                cnn_feature_dim=cnn_feature_dim,
            )
            v_hidden_sizes = (self._v_feature_extractor.out_dim,) + tuple(
                v_hidden_sizes
            )
        else:
            self._v_feature_extractor = None
            v_hidden_sizes = (self._feature_extractor.out_dim,) + tuple(v_hidden_sizes)
        self._v = CriticHead(
            hidden_sizes=v_hidden_sizes,
            activation=activation,
            init_layers_orthogonal=init_layers_orthogonal,
        )

    def _pi_forward(
        self, obs: torch.Tensor, action: Optional[torch.Tensor] = None
    ) -> Tuple[PiForward, torch.Tensor]:
        p_fe = self._feature_extractor(obs)
        pi_forward = self._pi(p_fe, action)

        return pi_forward, p_fe

    def _v_forward(self, obs: torch.Tensor, p_fc: torch.Tensor) -> torch.Tensor:
        v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc
        return self._v(v_fe)

    def forward(self, obs: torch.Tensor, action: torch.Tensor) -> ACForward:
        (_, logp_a, entropy), p_fc = self._pi_forward(obs, action)
        v = self._v_forward(obs, p_fc)

        assert logp_a is not None
        assert entropy is not None
        return ACForward(logp_a, entropy, v)

    def _as_tensor(self, obs: VecEnvObs) -> torch.Tensor:
        assert isinstance(obs, np.ndarray)
        o = torch.as_tensor(obs)
        if self.device is not None:
            o = o.to(self.device)
        return o

    def value(self, obs: VecEnvObs) -> np.ndarray:
        o = self._as_tensor(obs)
        with torch.no_grad():
            fe = (
                self._v_feature_extractor(o)
                if self._v_feature_extractor
                else self._feature_extractor(o)
            )
            v = self._v(fe)
        return v.cpu().numpy()

    def step(self, obs: VecEnvObs) -> Step:
        o = self._as_tensor(obs)
        with torch.no_grad():
            (pi, _, _), p_fc = self._pi_forward(o)
            a = pi.sample()
            logp_a = pi.log_prob(a)

            v = self._v_forward(o, p_fc)

        a_np = a.cpu().numpy()
        clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
        return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)

    def act(self, obs: np.ndarray, deterministic: bool = True) -> np.ndarray:
        if not deterministic:
            return self.step(obs).clamped_a
        else:
            o = self._as_tensor(obs)
            with torch.no_grad():
                (pi, _, _), _ = self._pi_forward(o)
                a = pi.mode
            return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)

    def load(self, path: str) -> None:
        super().load(path)
        self.reset_noise()

    def reset_noise(self, batch_size: Optional[int] = None) -> None:
        if isinstance(self._pi, StateDependentNoiseActorHead):
            self._pi.sample_weights(
                batch_size=batch_size if batch_size else self.env.num_envs
            )
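A small worked sketch (not part of this commit) of clamp_actions on a Box space: a squashed (tanh) action already in [-1, 1] is rescaled into [low, high], while an unsquashed action is simply clipped.

import numpy as np

from gym.spaces import Box
from shared.policy.on_policy import clamp_actions

space = Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float32)
print(clamp_actions(np.array([[0.5]]), space, squash_output=True))   # [[1.]]  rescaled
print(clamp_actions(np.array([[3.0]]), space, squash_output=False))  # [[2.]]  clipped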
    	
        shared/policy/policy.py
    ADDED
    
@@ -0,0 +1,60 @@
import numpy as np
import os
import torch
import torch.nn as nn

from abc import ABC, abstractmethod
from stable_baselines3.common.vec_env import unwrap_vec_normalize
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
from typing import Dict, Optional, Type, TypeVar, Union

ACTIVATION: Dict[str, Type[nn.Module]] = {
    "tanh": nn.Tanh,
    "relu": nn.ReLU,
}

VEC_NORMALIZE_FILENAME = "vecnormalize.pkl"
MODEL_FILENAME = "model.pth"

PolicySelf = TypeVar("PolicySelf", bound="Policy")


class Policy(nn.Module, ABC):
    @abstractmethod
    def __init__(self, env: VecEnv, **kwargs) -> None:
        super().__init__()
        self.env = env
        self.vec_normalize = unwrap_vec_normalize(env)
        self.device = None

    def to(
        self: PolicySelf,
        device: Optional[torch.device] = None,
        dtype: Optional[Union[torch.dtype, str]] = None,
        non_blocking: bool = False,
    ) -> PolicySelf:
        super().to(device, dtype, non_blocking)
        self.device = device
        return self

    @abstractmethod
    def act(self, obs: VecEnvObs, deterministic: bool = True) -> np.ndarray:
        ...

    def save(self, path: str) -> None:
        os.makedirs(path, exist_ok=True)

        if self.vec_normalize:
            self.vec_normalize.save(os.path.join(path, VEC_NORMALIZE_FILENAME))
        torch.save(
            self.state_dict(),
            os.path.join(path, MODEL_FILENAME),
        )

    @abstractmethod
    def load(self, path: str) -> None:
        # VecNormalize load occurs in env.py
        self.load_state_dict(torch.load(os.path.join(path, MODEL_FILENAME)))

    def reset_noise(self) -> None:
        pass
    	
        shared/schedule.py
    ADDED
    
@@ -0,0 +1,19 @@
from typing import Callable

Schedule = Callable[[float], float]


def linear_schedule(
    start_val: float, end_val: float, end_fraction: float = 1.0
) -> Schedule:
    def func(progress_fraction: float) -> float:
        if progress_fraction >= end_fraction:
            return end_val
        else:
            return start_val + (end_val - start_val) * progress_fraction / end_fraction

    return func


def constant_schedule(val: float) -> Schedule:
    return lambda f: val
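A minimal sketch (not part of this commit): a Schedule is called with the training progress fraction, so linear_schedule(3e-4, 0.0, end_fraction=0.8) decays a learning rate linearly to 0 over the first 80% of training and keeps it at 0 afterwards.

from shared.schedule import linear_schedule

lr = linear_schedule(3e-4, 0.0, end_fraction=0.8)
print(lr(0.0))  # 0.0003 at the start of training
print(lr(0.4))  # 0.00015 halfway through the decay window
print(lr(0.9))  # 0.0 once progress_fraction >= end_fraction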
    	
        shared/stats.py
    ADDED
    
@@ -0,0 +1,173 @@
import numpy as np

from dataclasses import dataclass
from torch.utils.tensorboard.writer import SummaryWriter
from typing import Dict, List, Optional, Sequence, TypeVar


@dataclass
class Episode:
    score: float = 0
    length: int = 0


StatisticSelf = TypeVar("StatisticSelf", bound="Statistic")


@dataclass
class Statistic:
    values: np.ndarray
    round_digits: int = 2

    @property
    def mean(self) -> float:
        return np.mean(self.values).item()

    @property
    def std(self) -> float:
        return np.std(self.values).item()

    @property
    def min(self) -> float:
        return np.min(self.values).item()

    @property
    def max(self) -> float:
        return np.max(self.values).item()

    def sum(self) -> float:
        return np.sum(self.values).item()

    def __len__(self) -> int:
        return len(self.values)

    def _diff(self: StatisticSelf, o: StatisticSelf) -> float:
        return (self.mean - self.std) - (o.mean - o.std)

    def __gt__(self: StatisticSelf, o: StatisticSelf) -> bool:
        return self._diff(o) > 0

    def __ge__(self: StatisticSelf, o: StatisticSelf) -> bool:
        return self._diff(o) >= 0

    def __repr__(self) -> str:
        mean = round(self.mean, self.round_digits)
        std = round(self.std, self.round_digits)
        if self.round_digits == 0:
            mean = int(mean)
            std = int(std)
        return f"{mean} +/- {std}"

    def to_dict(self) -> Dict[str, float]:
        return {
            "mean": self.mean,
            "std": self.std,
            "min": self.min,
            "max": self.max,
        }


EpisodesStatsSelf = TypeVar("EpisodesStatsSelf", bound="EpisodesStats")


class EpisodesStats:
    episodes: Sequence[Episode]
    simple: bool
    score: Statistic
    length: Statistic

    def __init__(self, episodes: Sequence[Episode], simple: bool = False) -> None:
        self.episodes = episodes
        self.simple = simple
        self.score = Statistic(np.array([e.score for e in episodes]))
        self.length = Statistic(np.array([e.length for e in episodes]), round_digits=0)

    def __gt__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
        return self.score > o.score

    def __ge__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool:
        return self.score >= o.score

    def __repr__(self) -> str:
        return (
            f"Score: {self.score} ({round(self.score.mean - self.score.std, 2)}) | "
            f"Length: {self.length}"
        )

    def _asdict(self) -> dict:
        return {
            "n_episodes": len(self.episodes),
            "score": self.score.to_dict(),
            "length": self.length.to_dict(),
        }

    def write_to_tensorboard(
        self, tb_writer: SummaryWriter, main_tag: str, global_step: Optional[int] = None
    ) -> None:
        stats = {"mean": self.score.mean}
        if not self.simple:
            stats.update(
                {
                    "min": self.score.min,
                    "max": self.score.max,
                    "result": self.score.mean - self.score.std,
                    "n_episodes": len(self.episodes),
                }
            )
        tb_writer.add_scalars(
            main_tag,
            stats,
            global_step=global_step,
        )


class EpisodeAccumulator:
    def __init__(self, num_envs: int):
        self._episodes = []
        self.current_episodes = [Episode() for _ in range(num_envs)]

    @property
    def episodes(self) -> List[Episode]:
        return self._episodes

    def step(self, reward: np.ndarray, done: np.ndarray) -> None:
        for idx, current in enumerate(self.current_episodes):
            current.score += reward[idx]
            current.length += 1
            if done[idx]:
                self._episodes.append(current)
                self.on_done(idx, current)
                self.current_episodes[idx] = Episode()

    def __len__(self) -> int:
        return len(self.episodes)

    def on_done(self, ep_idx: int, episode: Episode) -> None:
        pass

    def stats(self) -> EpisodesStats:
        return EpisodesStats(self.episodes)


class RolloutStats(EpisodeAccumulator):
    def __init__(self, num_envs: int, print_n_episodes: int, tb_writer: SummaryWriter):
        super().__init__(num_envs)
        self.print_n_episodes = print_n_episodes
        self.epochs: List[EpisodesStats] = []
        self.tb_writer = tb_writer

    def on_done(self, ep_idx: int, episode: Episode) -> None:
        if (
            self.print_n_episodes >= 0
            and len(self.episodes) % self.print_n_episodes == 0
        ):
            sample = self.episodes[-self.print_n_episodes :]
            epoch = EpisodesStats(sample)
            self.epochs.append(epoch)
            total_steps = np.sum([e.length for e in self.episodes])
            print(
                f"Episode: {len(self.episodes)} | "
                f"{epoch} | "
                f"Total Steps: {total_steps}"
            )
            epoch.write_to_tensorboard(self.tb_writer, "train", global_step=total_steps)
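A minimal sketch (not part of this commit) of the comparison convention above: Statistic, and therefore EpisodesStats, orders runs by mean - std of the scores, so a steadier run can beat one with a higher mean but much higher variance.

from shared.stats import Episode, EpisodesStats

steady = EpisodesStats([Episode(score=s, length=100) for s in (90, 95, 100)])
erratic = EpisodesStats([Episode(score=s, length=100) for s in (40, 100, 160)])
print(steady.score.mean, erratic.score.mean)  # 95.0 100.0
print(steady > erratic)  # True: 95 - 4.08 > 100 - 48.99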
    	
        shared/trajectory.py
    ADDED
    
@@ -0,0 +1,30 @@
import numpy as np
import torch

from dataclasses import dataclass
from typing import List


@dataclass
class Trajectory:
    obs: List[np.ndarray]
    act: List[np.ndarray]
    rew: List[float]
    v: List[float]
    terminated: bool

    def __init__(self) -> None:
        self.obs = []
        self.act = []
        self.rew = []
        self.v = []
        self.terminated = False

    def add(self, obs: np.ndarray, act: np.ndarray, rew: float, v: float):
        self.obs.append(obs)
        self.act.append(act)
        self.rew.append(rew)
        self.v.append(v)

    def __len__(self) -> int:
        return len(self.obs)
    	
        shared/utils.py
    ADDED
    
@@ -0,0 +1,8 @@
import numpy as np


def discounted_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
    dc = x.copy()
    for i in reversed(range(len(x) - 1)):
        dc[i] += gamma * dc[i + 1]
    return dc
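A worked sketch (not part of this commit): discounted_cumsum accumulates returns-to-go from right to left, so with gamma = 0.5 the rewards [1, 2, 4] become [1 + 0.5*(2 + 0.5*4), 2 + 0.5*4, 4].

import numpy as np

from shared.utils import discounted_cumsum

print(discounted_cumsum(np.array([1.0, 2.0, 4.0]), gamma=0.5))  # [3. 4. 4.]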
    	
train.py
ADDED
@@ -0,0 +1,81 @@
+# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
+import os
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+import itertools
+
+from argparse import Namespace
+from multiprocessing import Pool
+from typing import Any, Dict
+
+from runner.running_utils import base_parser
+from runner.train import train, TrainArgs
+
+
+def args_dict(algo: str, env: str, seed: int, args: Namespace) -> Dict[str, Any]:
+    d = vars(args).copy()
+    d.update(
+        {
+            "algo": algo,
+            "env": env,
+            "seed": seed,
+        }
+    )
+    return d
+
+
+if __name__ == "__main__":
+    parser = base_parser()
+    parser.add_argument(
+        "--wandb-project-name",
+        type=str,
+        default="rl-algo-impls",
+        help="WandB project name to upload training data to. If none, won't upload.",
+    )
+    parser.add_argument(
+        "--wandb-entity",
+        type=str,
+        default=None,
+        help="WandB team of project. None uses default entity",
+    )
+    parser.add_argument(
+        "--wandb-tags", type=str, nargs="*", help="WandB tags to add to run"
+    )
+    parser.add_argument(
+        "--pool-size", type=int, default=1, help="Simultaneous training jobs to run"
+    )
+    parser.add_argument(
+        "--virtual-display",
+        action="store_true",
+        help="Whether to create a virtual display for video rendering",
+    )
+    parser.set_defaults(algo="ppo", env="CartPole-v1", seed=1)
+    args = parser.parse_args()
+    print(args)
+
+    if args.virtual_display:
+        from pyvirtualdisplay import Display
+
+        virtual_display = Display(visible=0, size=(1400, 900))
+        virtual_display.start()
+    delattr(args, "virtual_display")
+
+    # pool_size isn't a TrainArg so must be removed from args
+    pool_size = args.pool_size
+    delattr(args, "pool_size")
+
+    algos = args.algo if isinstance(args.algo, list) else [args.algo]
+    envs = args.env if isinstance(args.env, list) else [args.env]
+    seeds = args.seed if isinstance(args.seed, list) else [args.seed]
+    if all(len(arg) == 1 for arg in [algos, envs, seeds]):
+        train(TrainArgs(**args_dict(algos[0], envs[0], seeds[0], args)))
+    else:
+        # Force a new process for each job to get around wandb not allowing more than one
+        # wandb.tensorboard.patch call per process.
+        with Pool(pool_size, maxtasksperchild=1) as p:
+            train_args = [
+                TrainArgs(**args_dict(algo, env, seed, args))
+                for algo, env, seed in itertools.product(algos, envs, seeds)
+            ]
+            p.map(train, train_args)
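train.py is a thin launcher: it expands --algo, --env, and --seed into their cross product, trains a single combination in-process, and otherwise fans the jobs out to a multiprocessing Pool with maxtasksperchild=1 so every job gets a fresh process (wandb.tensorboard.patch can only be called once per process, as the inline comment notes). Assuming base_parser declares those three flags with nargs so multiple values can be passed, an invocation along the lines of python train.py --algo vpg --env Acrobot-v1 --seed 1 2 3 --pool-size 3 would queue three seeds of the same configuration.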
    	
vpg/policy.py
ADDED
@@ -0,0 +1,119 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from gym.spaces import Box
+from pathlib import Path
+from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
+from typing import NamedTuple, Optional, Sequence, TypeVar
+
+from shared.module import FeatureExtractor
+from shared.policy.actor import (
+    PiForward,
+    Actor,
+    StateDependentNoiseActorHead,
+    actor_head,
+)
+from shared.policy.critic import CriticHead
+from shared.policy.on_policy import Step, clamp_actions
+from shared.policy.policy import ACTIVATION, Policy
+
+PI_FILE_NAME = "pi.pt"
+V_FILE_NAME = "v.pt"
+
+
+class VPGActor(Actor):
+    def __init__(self, feature_extractor: FeatureExtractor, head: Actor) -> None:
+        super().__init__()
+        self.feature_extractor = feature_extractor
+        self.head = head
+
+    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
+        fe = self.feature_extractor(obs)
+        return self.head(fe, a)
+
+
+class VPGActorCritic(Policy):
+    def __init__(
+        self,
+        env: VecEnv,
+        hidden_sizes: Sequence[int],
+        init_layers_orthogonal: bool = True,
+        activation_fn: str = "tanh",
+        log_std_init: float = -0.5,
+        use_sde: bool = False,
+        full_std: bool = True,
+        squash_output: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(env, **kwargs)
+        activation = ACTIVATION[activation_fn]
+        obs_space = env.observation_space
+        self.action_space = env.action_space
+        self.use_sde = use_sde
+        self.squash_output = squash_output
+
+        pi_feature_extractor = FeatureExtractor(
+            obs_space, activation, init_layers_orthogonal=init_layers_orthogonal
+        )
+        pi_head = actor_head(
+            self.action_space,
+            (pi_feature_extractor.out_dim,) + tuple(hidden_sizes),
+            init_layers_orthogonal,
+            activation,
+            log_std_init=log_std_init,
+            use_sde=use_sde,
+            full_std=full_std,
+            squash_output=squash_output,
+        )
+        self.pi = VPGActor(pi_feature_extractor, pi_head)
+
+        v_feature_extractor = FeatureExtractor(
+            obs_space, activation, init_layers_orthogonal=init_layers_orthogonal
+        )
+        v_head = CriticHead(
+            (v_feature_extractor.out_dim,) + tuple(hidden_sizes),
+            activation=activation,
+            init_layers_orthogonal=init_layers_orthogonal,
+        )
+        self.v = nn.Sequential(v_feature_extractor, v_head)
+
+    def _as_tensor(self, obs: VecEnvObs) -> torch.Tensor:
+        assert isinstance(obs, np.ndarray)
+        o = torch.as_tensor(obs)
+        if self.device is not None:
+            o = o.to(self.device)
+        return o
+
+    def step(self, obs: VecEnvObs) -> Step:
+        o = self._as_tensor(obs)
+        with torch.no_grad():
+            pi, _, _ = self.pi(o)
+            a = pi.sample()
+            logp_a = pi.log_prob(a)
+
+            v = self.v(o)
+
+        a_np = a.cpu().numpy()
+        clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
+        return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
+
+    def act(self, obs: np.ndarray, deterministic: bool = True) -> np.ndarray:
+        if not deterministic:
+            return self.step(obs).clamped_a
+        else:
+            o = self._as_tensor(obs)
+            with torch.no_grad():
+                pi, _, _ = self.pi(o)
+                a = pi.mode
+            return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
+
+    def load(self, path: str) -> None:
+        super().load(path)
+        self.reset_noise()
+
+    def reset_noise(self, batch_size: Optional[int] = None) -> None:
+        if isinstance(self.pi.head, StateDependentNoiseActorHead):
+            self.pi.head.sample_weights(
+                batch_size=batch_size if batch_size else self.env.num_envs
+            )
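A design note on VPGActorCritic: the policy (pi) and value (v) networks get separate FeatureExtractor instances rather than a shared torso, so the value regression cannot disturb the policy features. step() is the rollout-time path, sampling an action, its log-probability, and a value estimate in one no_grad pass and returning both raw and clamped actions, while act() defaults to the distribution's mode for deterministic evaluation. A rough usage sketch, assuming env is an already-constructed stable-baselines3 VecEnv and that the base Policy accepts it with default keyword arguments:

    # Hypothetical rollout sketch; env construction and extra Policy kwargs are assumed.
    policy = VPGActorCritic(env, hidden_sizes=[64, 64])
    obs = env.reset()
    step = policy.step(obs)                       # sampled actions, values, log-probs, clamped actions
    obs, rewards, dones, infos = env.step(step.clamped_a)
    eval_actions = policy.act(obs)                # deterministic (mode) actions for evaluation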